From 53384879a9011593f6add237aa0fa55d47dd3b2c Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 23 Aug 2020 12:54:58 +0200 Subject: [PATCH] some cleanup and tests --- .coverage | Bin 0 -> 53248 bytes ormar/fields/base.py | 20 +-- ormar/fields/decorators.py | 27 ---- ormar/fields/foreign_key.py | 43 ++--- ormar/fields/model_fields.py | 288 +++++++++++++++++---------------- ormar/models/fakepydantic.py | 153 +++++++++--------- ormar/models/metaclass.py | 87 ++++------ ormar/models/model.py | 12 +- ormar/queryset/query.py | 74 +++++---- ormar/queryset/queryset.py | 18 +-- ormar/relations.py | 7 +- scripts/clean.sh | 0 scripts/publish.sh | 0 scripts/test.sh | 0 tests/test_foreign_keys.py | 2 + tests/test_model_definition.py | 13 +- tests/test_models.py | 22 +++ 17 files changed, 370 insertions(+), 396 deletions(-) create mode 100644 .coverage delete mode 100644 ormar/fields/decorators.py mode change 100644 => 100755 scripts/clean.sh mode change 100644 => 100755 scripts/publish.sh mode change 100644 => 100755 scripts/test.sh diff --git a/.coverage b/.coverage new file mode 100644 index 0000000000000000000000000000000000000000..f3e5fafb45eb490917ec4b248da6544fb15b365f GIT binary patch literal 53248 zcmeI4du$xV8NhdM&v(Apj^AhEgsh5M=i&1+A!#0jw2ef7Dutp75vR-gyyt9kcYC?r zb7IA@WuT=})PIUv)beVjQi>{71&Y#=5Lco=TSTHlDJW^|5)kI)xaAScB=+{3*<0UT z;IF9qdHw0g2 z^TLkXIe?hY+b_23L2vk~tN6PPm`Yb%HPcE|P1|AU>$h&%eA^c3w#{F=d5dJHNz42aY>|kx zSxW1BrIebMVp?34v}CsyRZK0}Bbhr@CwInByXh4V(%_m$Gd44Mx+iPhAZ4mOP(&)N zC6siZv|H_KVi~fvnd&~%PJlvGttUxyNXxR~%cZm$Q`2fPsv5Qy%e3z0{^fTM`n}=m zYW`l8wMj`w;S(rm4J726vNZ+;MfG$y)xImOB%?c3qe)U`TNu?r&wZxd+O(c%F?A`f zC0R3*hGuGdQd0M+(Tu5f&)L8x3eDL7UVdw!tUGCg?zAPdb>-2?3C|&!&7KvGfpSUR z!I~4KRf90oIhjdx0rFj$M2dC}NZ1vddNIjjvZkJGY{`8yM_=aJThN^ja@wP|R3@qI z$*8lP-z4QEn^-fFdXq#O0T%G(fG7ObYMukDRQDNs;_$;uRx+l}j%Db4xebEgJ(cC& z@Y1FHy?dCUP&{&%W*W0Qz5=6}O)g?O^n7uFZ-EIpJvqlT>YPfb zijhe>6Uls&q$81=C7m&nJp;{&cs=3G4aJ(1%_p~IcglynZg04@mLIZ51f7@Yx2&KM zG$l6#F1Qesu6{}}cS;xQPLo0L#xhTMeQhztsS0u%4Lx_?>S8M2>t*y4FzM5;P@pnR z$*Fvk+860n;xL<}6fF%y7Up343Lw#p-3nw@y7Y|6?3W%g^Q70FoW6DexnqYtpBv$ps7o$zNGK@Z2Mo(`eQ^mQa z(AoVq7;Db3f}KP{Np@#1Q}pU1sS(Da0VDH*GMz%RXX!-cl;4~J z6tvWY(ru4&+822x?%k}juJNYS(kC*cyhyiu61_h~1| zZ8W4_ayk7WGqjZVhF7lShcekQXiqP)0oU?`;;5ZHbQRdf1t)W6vJK{0_>q%u9+X`k z4ti+UK<6~O2fJ(vPOj`v9RG6a;Pd1h2Or#!01`j~NB{{S0VIF~kN^@u0!RP}Ac4z| zfWW(WFZKW3MC8a;vX(4{9o&!r5!)*vRKl1&jKmZz9?8HnX6cP>t*yEAw3^b> z=0;OZCn5>WFyOhu#vOj>Q#W~qBd?H8$h&0pvhQMSAQC_VNB{{S0VIF~kN^@u0!RP} zAc4z*Ku}o1JAYxg%Y|B={(~U|ga$tQcZDk;Eb%%2Q_%bW^?~&q`3LzW*+*_9Vewt@ zNpZh;gIF0l75a6^2wfi%gKq|(2!1cPEx0J~X5h(y2?Dqw0VIF~kN^@u0!RP}Ab~53 zz?yo2%fMZ}*_(X1dwlGUTqg}>1K#_V+1tPJE?DWofE#)n!U8wApkM-)9gyMG<9J+) zsd7v)OeLj(bfiw;loF)7)i|weO>HS!CUnRycf)J-Nw#nZ2zJy6Tw=j`vTKIO@P0Y1 zZB@0v=?e?ljgd13U27rLx1vhmw$CSJs3u*+h1SqqDR9xb^RnBbovbN_9YOw;6#}Sytt5$PT!DpWK%^s76s+BX!b+ev zDikbLNS0|IvOCA^0fE~+pK=Z>cxN5dGbn=9eu3Llf=o==t)}|Am87XfDF@#I9I{pA zX}7w3ftq5P8t=~bq--BRO{=^`P0YrZ70a_skcxPUj+z3gE`{|`TiF85$ySoroF=!x z$@76z_eIr|V`(ddq9$+alUH)HtEkE6O*BzGo=Gq(UcnbF+s>a?HogB}AK1W=K5`SO z7vB?~68puiVpV7)_+;qu(B=>cnZeWKwcz*3qrq^TPX#g{fEyA(0!RP}AOR$R z1dsp{xWWjm6*BzXKlt;0me1k;*KZUC7Zxn$|A(6er4;F`|6f;?E4|A*e-$P{^>jF4B! zU&tTG&xuA9vW7H}YVlL?byyGJA@QJix0r!d0al8h(0@PxHza@rkN^@u0!RP}AOR$R z1dsp{_(BkFC zb>NyhNWGzM1akKuKmN$r{-gTn)R|-Y;X_l`KJv)u`Hx;&Gb%LHLZ-E~z4uQJ&b;`5 zF>~@k9q2jym1TZ7cdP$u?nuk^ zP2BJ0&-{T^);R zYgtrN!=mbH7FAWTsIrnp6%{NZghe9r_W4kV9R-6d3Itf>_p`|7V^MiIi@aVIc|0sC zD`Sz{%_2cyk;~h;%NB{{S0VIF~kN^@u0!RP}AOR%s1t7rQ|0h47&;Oq#|0N%iF>;!mf_DI3 zCCA8L$)CwEc?RAC_#OET`4#yEd63)>?*jY~1aLzFNB{{S0VIF~kN^@u0!RP}AOR$R z1paRVLOm?TFNA4OM}t}#)X<=s230huq(KD@2n|FUglG_?L4XE+8u(~XP6ICuJTxey Pftv;b4O}#!>;L}`MCfTC literal 0 HcmV?d00001 diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 1f55d47..67bf190 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -1,6 +1,5 @@ -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, List, Optional, TYPE_CHECKING -import pydantic import sqlalchemy from pydantic import Field @@ -10,13 +9,6 @@ if TYPE_CHECKING: # pragma no cover from ormar.models import Model -def prepare_validator(type_): - def validate_model_field(value): - return isinstance(value, type_) - - return validate_model_field - - class BaseField: __type__ = None @@ -34,13 +26,7 @@ class BaseField: server_default: Any @classmethod - def is_required(cls) -> bool: - return ( - not cls.nullable and not cls.has_default() and not cls.is_auto_primary_key() - ) - - @classmethod - def default_value(cls): + def default_value(cls) -> Optional[Field]: if cls.is_auto_primary_key(): return Field(default=None) if cls.has_default(): @@ -52,7 +38,7 @@ class BaseField: return None @classmethod - def has_default(cls): + def has_default(cls) -> bool: return cls.default is not None or cls.server_default is not None @classmethod diff --git a/ormar/fields/decorators.py b/ormar/fields/decorators.py deleted file mode 100644 index 842e864..0000000 --- a/ormar/fields/decorators.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Any, TYPE_CHECKING, Type - -from ormar import ModelDefinitionError - -if TYPE_CHECKING: # pragma no cover - from ormar.fields import BaseField - - -class RequiredParams: - def __init__(self, *args: str) -> None: - self._required = list(args) - - def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]: - old_init = model_field_class.__init__ - model_field_class._old_init = old_init - - def __init__(instance: "BaseField", **kwargs: Any) -> None: - super(instance.__class__, instance).__init__(**kwargs) - for arg in self._required: - if arg not in kwargs: - raise ModelDefinitionError( - f"{instance.__class__.__name__} field requires parameter: {arg}" - ) - setattr(instance, arg, kwargs.pop(arg)) - - model_field_class.__init__ = __init__ - return model_field_class diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index ae6745c..679830a 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -1,7 +1,6 @@ -from typing import Any, List, Optional, TYPE_CHECKING, Type, Union, Callable +from typing import Any, Callable, List, Optional, TYPE_CHECKING, Type, Union import sqlalchemy -from pydantic import BaseModel import ormar # noqa I101 from ormar.exceptions import RelationshipInstanceError @@ -13,8 +12,7 @@ if TYPE_CHECKING: # pragma no cover def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": init_dict = { - **{fk.Meta.pkname: pk or -1, - '__pk_only__': True}, + **{fk.Meta.pkname: pk or -1, "__pk_only__": True}, **{ k: create_dummy_instance(v.to) for k, v in fk.Meta.model_fields.items() @@ -24,10 +22,15 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": return fk(**init_dict) -def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = True, - related_name: str = None, - virtual: bool = False, - ) -> Type[object]: +def ForeignKey( + to: "Model", + *, + name: str = None, + unique: bool = False, + nullable: bool = True, + related_name: str = None, + virtual: bool = False, +) -> Type[object]: fk_string = to.Meta.tablename + "." + to.Meta.pkname to_field = to.__fields__[to.Meta.pkname] namespace = dict( @@ -43,7 +46,7 @@ def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = T index=False, pydantic_only=False, default=None, - server_default=None + server_default=None, ) return type("ForeignKey", (ForeignKeyField, BaseField), namespace) @@ -59,21 +62,21 @@ class ForeignKeyField(BaseField): yield cls.validate @classmethod - def validate(cls, v: Any) -> Any: - return v + def validate(cls, value: Any) -> Any: + return value - @property - def __type__(self) -> Type[BaseModel]: - return self.to.__pydantic_model__ + # @property + # def __type__(self) -> Type[BaseModel]: + # return self.to.__pydantic_model__ - @classmethod - def get_column_type(cls) -> sqlalchemy.Column: - to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname] - return to_column.column_type + # @classmethod + # def get_column_type(cls) -> sqlalchemy.Column: + # to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname] + # return to_column.column_type @classmethod def _extract_model_from_sequence( - cls, value: List, child: "Model" + cls, value: List, child: "Model" ) -> Union["Model", List["Model"]]: return [cls.expand_relationship(val, child) for val in value] @@ -109,7 +112,7 @@ class ForeignKeyField(BaseField): @classmethod def expand_relationship( - cls, value: Any, child: "Model" + cls, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: if value is None: return None diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index d790d65..d3f41cf 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -1,41 +1,43 @@ import datetime import decimal import re -from typing import Type, Any, Optional +from typing import Any, Optional, Type import pydantic import sqlalchemy from pydantic import Json -from ormar import ModelDefinitionError +from ormar import ModelDefinitionError # noqa I101 from ormar.fields.base import BaseField # noqa I101 -def is_field_nullable(nullable: Optional[bool], default: Any, server_default: Any) -> bool: +def is_field_nullable( + nullable: Optional[bool], default: Any, server_default: Any +) -> bool: if nullable is None: return default is not None or server_default is not None return nullable def String( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - allow_blank: bool = False, - strip_whitespace: bool = False, - min_length: int = None, - max_length: int = None, - curtail_length: int = None, - regex: str = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + allow_blank: bool = False, + strip_whitespace: bool = False, + min_length: int = None, + max_length: int = None, + curtail_length: int = None, + regex: str = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[str]: if max_length is None or max_length <= 0: - raise ModelDefinitionError(f'Parameter max_length is required for field String') + raise ModelDefinitionError("Parameter max_length is required for field String") namespace = dict( __type__=str, @@ -54,26 +56,26 @@ def String( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("String", (pydantic.ConstrainedStr, BaseField), namespace) def Integer( - *, - name: str = None, - primary_key: bool = False, - autoincrement: bool = None, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: int = None, - maximum: int = None, - multiple_of: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + autoincrement: bool = None, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[int]: namespace = dict( __type__=int, @@ -89,23 +91,23 @@ def Integer( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=autoincrement if autoincrement is not None else primary_key + autoincrement=autoincrement if autoincrement is not None else primary_key, ) return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace) def Text( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - allow_blank: bool = False, - strip_whitespace: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + allow_blank: bool = False, + strip_whitespace: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[str]: namespace = dict( __type__=str, @@ -120,25 +122,25 @@ def Text( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("Text", (pydantic.ConstrainedStr, BaseField), namespace) def Float( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: float = None, - maximum: float = None, - multiple_of: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[int]: namespace = dict( __type__=float, @@ -154,21 +156,21 @@ def Float( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace) def Boolean( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[bool]: namespace = dict( __type__=bool, @@ -181,21 +183,21 @@ def Boolean( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("Boolean", (int, BaseField), namespace) def DateTime( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[datetime.datetime]: namespace = dict( __type__=datetime.datetime, @@ -208,21 +210,21 @@ def DateTime( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("DateTime", (datetime.datetime, BaseField), namespace) def Date( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[datetime.date]: namespace = dict( __type__=datetime.date, @@ -235,21 +237,21 @@ def Date( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("Date", (datetime.date, BaseField), namespace) def Time( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[datetime.time]: namespace = dict( __type__=datetime.time, @@ -262,21 +264,21 @@ def Time( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("Time", (datetime.time, BaseField), namespace) def JSON( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[Json]: namespace = dict( __type__=pydantic.Json, @@ -289,26 +291,26 @@ def JSON( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("JSON", (pydantic.Json, BaseField), namespace) def BigInteger( - *, - name: str = None, - primary_key: bool = False, - autoincrement: bool = None, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: int = None, - maximum: int = None, - multiple_of: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + autoincrement: bool = None, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[int]: namespace = dict( __type__=int, @@ -324,31 +326,33 @@ def BigInteger( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=autoincrement if autoincrement is not None else primary_key + autoincrement=autoincrement if autoincrement is not None else primary_key, ) return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace) def Decimal( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: float = None, - maximum: float = None, - multiple_of: int = None, - precision: int = None, - scale: int = None, - max_digits: int = None, - decimal_places: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None -): + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + precision: int = None, + scale: int = None, + max_digits: int = None, + decimal_places: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, +) -> Type[decimal.Decimal]: if precision is None or precision < 0 or scale is None or scale < 0: - raise ModelDefinitionError(f'Parameters scale and precision are required for field Decimal') + raise ModelDefinitionError( + "Parameters scale and precision are required for field Decimal" + ) namespace = dict( __type__=decimal.Decimal, @@ -368,6 +372,6 @@ def Decimal( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace) diff --git a/ormar/models/fakepydantic.py b/ormar/models/fakepydantic.py index c44309d..d6958fd 100644 --- a/ormar/models/fakepydantic.py +++ b/ormar/models/fakepydantic.py @@ -2,16 +2,17 @@ import inspect import json import uuid from typing import ( + AbstractSet, Any, - Callable, Dict, List, + Mapping, Optional, Set, TYPE_CHECKING, Type, TypeVar, - Union, AbstractSet, Mapping, + Union, ) import databases @@ -20,10 +21,9 @@ import sqlalchemy from pydantic import BaseModel import ormar # noqa I100 -from ormar import ForeignKey from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField -from ormar.models.metaclass import ModelMetaclass, ModelMeta +from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.relations import RelationshipManager if TYPE_CHECKING: # pragma no cover @@ -39,7 +39,8 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): # FakePydantic inherits from list in order to be treated as # request.Body parameter in fastapi routes, # inheriting from pydantic.BaseModel causes metaclass conflicts - __slots__ = ('_orm_id', '_orm_saved') + __slots__ = ("_orm_id", "_orm_saved") + __abstract__ = True if TYPE_CHECKING: # pragma no cover __model_fields__: Dict[str, TypeVar[BaseField]] @@ -63,18 +64,18 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): if "pk" in kwargs: kwargs[self.Meta.pkname] = kwargs.pop("pk") kwargs = { - k: self.Meta.model_fields[k].expand_relationship(v, self) + k: self._convert_json( + k, self.Meta.model_fields[k].expand_relationship(v, self), "dumps" + ) for k, v in kwargs.items() } - values, fields_set, validation_error = pydantic.validate_model( - self, kwargs - ) + values, fields_set, validation_error = pydantic.validate_model(self, kwargs) if validation_error and not pk_only: raise validation_error - object.__setattr__(self, '__dict__', values) - object.__setattr__(self, '__fields_set__', fields_set) + object.__setattr__(self, "__dict__", values) + object.__setattr__(self, "__fields_set__", fields_set) # super().__init__(**kwargs) # self.values = self.__pydantic_model__(**kwargs) @@ -82,58 +83,50 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): def __del__(self) -> None: self.Meta._orm_relationship_manager.deregister(self) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: relation_key = self.get_name(title=True) + "_" + name if name in self.__slots__: object.__setattr__(self, name, value) - elif name == 'pk': - object.__setattr__(self, self.Meta.pkname, value) + elif name == "pk": + object.__setattr__(self, self.Meta.pkname, value) elif self.Meta._orm_relationship_manager.contains(relation_key, self): self.Meta.model_fields[name].expand_relationship(value, self) else: - super().__setattr__(name, value) + value = ( + self._convert_json(name, value, "dumps") + if name in self.__fields__ + else value + ) + super().__setattr__(name, value) - def __getattr__(self, item): + def __getattribute__(self, item: str) -> Any: + if item != "__fields__" and item in self.__fields__: + related = self._extract_related_model_instead_of_field(item) + if related: + return related + value = object.__getattribute__(self, item) + value = self._convert_json(item, value, "loads") + return value + return super().__getattribute__(item) + + def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]: + return self._extract_related_model_instead_of_field(item) + + def _extract_related_model_instead_of_field( + self, item: str + ) -> Optional[Union["Model", List["Model"]]]: relation_key = self.get_name(title=True) + "_" + item if self.Meta._orm_relationship_manager.contains(relation_key, self): return self.Meta._orm_relationship_manager.get(relation_key, self) - # def __setattr__(self, key: str, value: Any) -> None: - # if key in ('_orm_id', '_orm_relationship_manager', '_orm_saved', 'objects', '__model_fields__'): - # return setattr(self, key, value) - # # elif key in self._extract_related_names(): - # # value = self._convert_json(key, value, op="dumps") - # # value = self.Meta.model_fields[key].expand_relationship(value, self) - # # relation_key = self.get_name(title=True) + "_" + key - # # if not self.Meta._orm_relationship_manager.contains(relation_key, self): - # # setattr(self.values, key, value) - # else: - # super().__setattr__(key, value) - - # def __getattribute__(self, key: str) -> Any: - # if key != 'Meta' and key in self.Meta.model_fields: - # relation_key = self.get_name(title=True) + "_" + key - # if self.Meta._orm_relationship_manager.contains(relation_key, self): - # return self.Meta._orm_relationship_manager.get(relation_key, self) - # item = getattr(self.__fields__, key, None) - # item = self._convert_json(key, item, op="loads") - # return item - # return super().__getattribute__(key) - def __same__(self, other: "Model") -> bool: if self.__class__ != other.__class__: # pragma no cover return False - return (self._orm_id == other._orm_id or - self.__dict__ == other.__dict__ or - (self.pk == other.pk and self.pk is not None - )) - - # def __repr__(self) -> str: # pragma no cover - # return self.values.__repr__() - - # @classmethod - # def __get_validators__(cls) -> Callable: # pragma no cover - # yield cls.__pydantic_model__.validate + return ( + self._orm_id == other._orm_id + or self.__dict__ == other.__dict__ + or (self.pk == other.pk and self.pk is not None) + ) @classmethod def get_name(cls, title: bool = False, lower: bool = True) -> str: @@ -148,10 +141,6 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): def pk(self) -> Any: return getattr(self, self.Meta.pkname) - @pk.setter - def pk(self, value: Any) -> None: - setattr(self, self.Meta.pkname, value) - @property def pk_column(self) -> sqlalchemy.Column: return self.Meta.table.primary_key.columns.values()[0] @@ -160,36 +149,38 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): def pk_type(cls) -> Any: return cls.Meta.model_fields[cls.Meta.pkname].__type__ - def dict( - self, - *, - include: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, - exclude: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, - by_alias: bool = False, - skip_defaults: bool = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - nested: bool = False - ) -> 'DictStrAny': # noqa: A003' - dict_instance = super().dict(include=include, - exclude=self._exclude_related_names_not_required(nested), - by_alias=by_alias, - skip_defaults=skip_defaults, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none) + def dict( # noqa A003 + self, + *, + include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + nested: bool = False + ) -> "DictStrAny": # noqa: A003' + dict_instance = super().dict( + include=include, + exclude=self._exclude_related_names_not_required(nested), + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) for field in self._extract_related_names(): nested_model = getattr(self, field) if self.Meta.model_fields[field].virtual and nested: continue if isinstance(nested_model, list) and not isinstance( - nested_model, ormar.Model + nested_model, ormar.Model ): dict_instance[field] = [x.dict(nested=True) for x in nested_model] elif nested_model is not None: - dict_instance[field] = nested_model.dict(nested=True) + dict_instance[field] = nested_model.dict(nested=True) return dict_instance def from_dict(self, value_dict: Dict) -> None: @@ -225,19 +216,21 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): def _extract_related_names(cls) -> Set: related_names = set() for name, field in cls.Meta.model_fields.items(): - if inspect.isclass(field) and issubclass( - field, ForeignKeyField - ): + if inspect.isclass(field) and issubclass(field, ForeignKeyField): related_names.add(name) return related_names @classmethod - def _exclude_related_names_not_required(cls, nested:bool=False) -> Set: + def _exclude_related_names_not_required(cls, nested: bool = False) -> Set: if nested: return cls._extract_related_names() related_names = set() for name, field in cls.Meta.model_fields.items(): - if inspect.isclass(field) and issubclass(field, ForeignKeyField) and field.nullable: + if ( + inspect.isclass(field) + and issubclass(field, ForeignKeyField) + and field.nullable + ): related_names.add(name) return related_names @@ -267,7 +260,7 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": for field in one.Meta.model_fields.keys(): if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), ormar.Model + getattr(one, field), ormar.Model ): setattr(other, field, getattr(one, field) + getattr(other, field)) elif isinstance(getattr(one, field), ormar.Model): diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 56d66d9..0d621bd 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union import databases import pydantic import sqlalchemy -from pydantic import BaseConfig, create_model, Extra -from pydantic.fields import ModelField, FieldInfo +from pydantic import BaseConfig +from pydantic.fields import FieldInfo from ormar import ForeignKey, ModelDefinitionError # noqa I100 from ormar.fields import BaseField @@ -29,23 +29,11 @@ class ModelMeta: _orm_relationship_manager: RelationshipManager -def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: - pydantic_fields = { - field_name: ( - base_field.__type__, - ... if base_field.is_required else base_field.default_value, - ) - for field_name, base_field in object_dict.items() - if isinstance(base_field, BaseField) - } - return pydantic_fields - - def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: child_relation_name = ( - field.to.get_name(title=True) - + "_" - + (field.related_name or (name.lower() + "s")) + field.to.get_name(title=True) + + "_" + + (field.related_name or (name.lower() + "s")) ) reverse_name = child_relation_name relation_name = name.lower().title() + "_" + field.to.get_name() @@ -61,34 +49,28 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ): register_reverse_model_fields(parent_model, child, child_model_name) def register_reverse_model_fields( - model: Type["Model"], child: Type["Model"], child_model_name: str + model: Type["Model"], child: Type["Model"], child_model_name: str ) -> None: - # model.__fields__[child_model_name] = ModelField( - # name=child_model_name, - # type_=Optional[Union[List[child], child]], - # model_config=child.__config__, - # class_validators=child.__validators__, - # ) model.Meta.model_fields[child_model_name] = ForeignKey( child, name=child_model_name, virtual=True ) def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str + name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: columns = [] pkname = None model_fields = { field_name: field - for field_name, field in object_dict['__annotations__'].items() + for field_name, field in object_dict["__annotations__"].items() if issubclass(field, BaseField) } for field_name, field in model_fields.items(): @@ -96,7 +78,7 @@ def sqlalchemy_columns_from_model_fields( if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") if field.pydantic_only: - raise ModelDefinitionError('Primary key column cannot be pydantic only') + raise ModelDefinitionError("Primary key column cannot be pydantic only") pkname = field_name if not field.pydantic_only: columns.append(field.get_column(field_name)) @@ -112,10 +94,10 @@ def populate_pydantic_default_values(attrs: Dict) -> Dict: if type_.name is None: type_.name = field def_value = type_.default_value() - curr_def_value = attrs.get(field, 'NONE') - if curr_def_value == 'NONE' and isinstance(def_value, FieldInfo): + curr_def_value = attrs.get(field, "NONE") + if curr_def_value == "NONE" and isinstance(def_value, FieldInfo): attrs[field] = def_value - elif curr_def_value == 'NONE' and type_.nullable: + elif curr_def_value == "NONE" and type_.nullable: attrs[field] = FieldInfo(default=None) return attrs @@ -132,51 +114,44 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]: class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: - attrs['Config'] = get_pydantic_base_orm_config() + attrs["Config"] = get_pydantic_base_orm_config() new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) - if hasattr(new_model, 'Meta'): + if hasattr(new_model, "Meta"): - if attrs.get("__abstract__"): - return new_model - annotations = attrs.get("__annotations__") or new_model.__annotations__ - attrs["__annotations__"]= annotations + attrs["__annotations__"] = annotations attrs = populate_pydantic_default_values(attrs) tablename = name.lower() + "s" new_model.Meta.tablename = new_model.Meta.tablename or tablename - + # sqlalchemy table creation - + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( - name, attrs, new_model.Meta.tablename + name, attrs, new_model.Meta.tablename ) - + if hasattr(new_model.Meta, "model_fields") and not pkname: - model_fields = new_model.Meta.model_fields - for fieldname, field in new_model.Meta.model_fields.items(): - if field.primary_key: - pkname=fieldname - columns = new_model.Meta.table.columns - + model_fields = new_model.Meta.model_fields + for fieldname, field in new_model.Meta.model_fields.items(): + if field.primary_key: + pkname = fieldname + columns = new_model.Meta.table.columns + if not hasattr(new_model.Meta, "table"): - new_model.Meta.table = sqlalchemy.Table(new_model.Meta.tablename, new_model.Meta.metadata, *columns) - + new_model.Meta.table = sqlalchemy.Table( + new_model.Meta.tablename, new_model.Meta.metadata, *columns + ) + new_model.Meta.columns = columns new_model.Meta.pkname = pkname if not pkname: raise ModelDefinitionError("Table has to have a primary key.") - # pydantic model creation - new_model.Meta.pydantic_fields = parse_pydantic_field_from_model_fields(attrs) - new_model.Meta.pydantic_model = create_model( - name, __config__=get_pydantic_base_orm_config(), **new_model.Meta.pydantic_fields - ) - new_model.Meta.model_fields = model_fields new_model = super().__new__( # type: ignore mcs, name, bases, attrs diff --git a/ormar/models/model.py b/ormar/models/model.py index 4651aa1..5cb9ed0 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -13,10 +13,10 @@ class Model(FakePydantic): @classmethod def from_row( - cls, - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - previous_table: str = None, + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + previous_table: str = None, ) -> "Model": item = {} @@ -66,8 +66,8 @@ class Model(FakePydantic): self_fields.pop(self.Meta.pkname) expr = ( self.Meta.table.update() - .values(**self_fields) - .where(self.pk_column == getattr(self, self.Meta.pkname)) + .values(**self_fields) + .where(self.pk_column == getattr(self, self.Meta.pkname)) ) result = await self.Meta.database.execute(expr) return result diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 30b739e..55e936a 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -20,12 +20,12 @@ class JoinParameters(NamedTuple): class Query: def __init__( - self, - model_cls: Type["Model"], - filter_clauses: List, - select_related: List, - limit_count: int, - offset: int, + self, + model_cls: Type["Model"], + filter_clauses: List, + select_related: List, + limit_count: int, + offset: int, ) -> None: self.query_offset = offset @@ -49,15 +49,15 @@ class Query: self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")] self.select_from = self.table - for key in self.model_cls.Meta.model_fields: - if ( - not self.model_cls.Meta.model_fields[key].nullable - and isinstance( - self.model_cls.Meta.model_fields[key], ForeignKeyField, - ) - and key not in self._select_related - ): - self._select_related = [key] + self._select_related + # for key in self.model_cls.Meta.model_fields: + # if ( + # not self.model_cls.Meta.model_fields[key].nullable + # and isinstance( + # self.model_cls.Meta.model_fields[key], ForeignKeyField, + # ) + # and key not in self._select_related + # ): + # self._select_related = [key] + self._select_related start_params = JoinParameters( self.model_cls, "", self.table.name, self.model_cls @@ -79,7 +79,7 @@ class Query: expr = self._apply_expression_modifiers(expr) - #print(expr.compile(compile_kwargs={"literal_binds": True})) + # print(expr.compile(compile_kwargs={"literal_binds": True})) self._reset_query_parameters() return expr, self._select_related @@ -97,12 +97,12 @@ class Query: @staticmethod def _field_is_a_foreign_key_and_no_circular_reference( - field: BaseField, field_name: str, rel_part: str + field: Type[BaseField], field_name: str, rel_part: str ) -> bool: return issubclass(field, ForeignKeyField) and field_name not in rel_part def _field_qualifies_to_deeper_search( - self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str + self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str ) -> bool: prev_part_of_related = "__".join(rel_part.split("__")[:-1]) partial_match = any( @@ -112,19 +112,19 @@ class Query: [x.startswith(rel_part) for x in (self.auto_related + self.already_checked)] ) return ( - (field.virtual and parent_virtual) - or (partial_match and not already_checked) - ) or not nested + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested def on_clause( - self, previous_alias: str, alias: str, from_clause: str, to_clause: str, + self, previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: left_part = f"{alias}_{to_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" return text(f"{left_part}={right_part}") def _build_join_parameters( - self, part: str, join_params: JoinParameters + self, part: str, join_params: JoinParameters ) -> JoinParameters: model_cls = join_params.model_cls.Meta.model_fields[part].to to_table = model_cls.Meta.table.name @@ -138,7 +138,8 @@ class Query: ( v for k, v in model_cls.Meta.model_fields.items() - if issubclass(v, ForeignKeyField) and v.to == join_params.prev_model + if issubclass(v, ForeignKeyField) + and v.to == join_params.prev_model ), None, ).name @@ -167,27 +168,30 @@ class Query: return JoinParameters(prev_model, previous_alias, from_table, model_cls) def _extract_auto_required_relations( - self, - prev_model: Type["Model"], - rel_part: str = "", - nested: bool = False, - parent_virtual: bool = False, + self, + prev_model: Type["Model"], + rel_part: str = "", + nested: bool = False, + parent_virtual: bool = False, ) -> None: for field_name, field in prev_model.Meta.model_fields.items(): if self._field_is_a_foreign_key_and_no_circular_reference( - field, field_name, rel_part + field, field_name, rel_part ): rel_part = field_name if not rel_part else rel_part + "__" + field_name if not field.nullable: if rel_part not in self._select_related: - new_related = "__".join(rel_part.split("__")[:-1]) if len( - rel_part.split("__")) > 1 else rel_part + new_related = ( + "__".join(rel_part.split("__")[:-1]) + if len(rel_part.split("__")) > 1 + else rel_part + ) self.auto_related.append(new_related) rel_part = "" elif self._field_qualifies_to_deeper_search( - field, parent_virtual, nested, rel_part + field, parent_virtual, nested, rel_part ): - + self._extract_auto_required_relations( prev_model=field.to, rel_part=rel_part, @@ -207,7 +211,7 @@ class Query: self._select_related = new_joins + self.auto_related def _apply_expression_modifiers( - self, expr: sqlalchemy.sql.select + self, expr: sqlalchemy.sql.select ) -> sqlalchemy.sql.select: if self.filter_clauses: if len(self.filter_clauses) == 1: diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 4e5d85e..56d1ee2 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -14,12 +14,12 @@ if TYPE_CHECKING: # pragma no cover class QuerySet: def __init__( - self, - model_cls: Type["Model"] = None, - filter_clauses: List = None, - select_related: List = None, - limit_count: int = None, - offset: int = None, + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses @@ -151,9 +151,9 @@ class QuerySet: pkname = self.model_cls.Meta.pkname pk = self.model_cls.Meta.model_fields[pkname] if ( - pkname in new_kwargs - and new_kwargs.get(pkname) is None - and (pk.nullable or pk.autoincrement) + pkname in new_kwargs + and new_kwargs.get(pkname) is None + and (pk.nullable or pk.autoincrement) ): del new_kwargs[pkname] diff --git a/ormar/relations.py b/ormar/relations.py index 8a422d9..008e566 100644 --- a/ormar/relations.py +++ b/ormar/relations.py @@ -5,7 +5,6 @@ from random import choices from typing import List, TYPE_CHECKING, Union from weakref import proxy -from ormar import ForeignKey from ormar.fields.foreign_key import ForeignKeyField if TYPE_CHECKING: # pragma no cover @@ -22,7 +21,11 @@ class RelationshipManager: self._aliases = dict() def add_relation_type( - self, relations_key: str, reverse_key: str, field: ForeignKeyField, table_name: str + self, + relations_key: str, + reverse_key: str, + field: ForeignKeyField, + table_name: str, ) -> None: if relations_key not in self._relations: self._relations[relations_key] = {"type": "primary"} diff --git a/scripts/clean.sh b/scripts/clean.sh old mode 100644 new mode 100755 diff --git a/scripts/publish.sh b/scripts/publish.sh old mode 100644 new mode 100755 diff --git a/scripts/test.sh b/scripts/test.sh old mode 100644 new mode 100755 diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 17e4a52..f1bd7fa 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -5,6 +5,7 @@ from pydantic import ValidationError import ormar from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError +from ormar.fields.foreign_key import ForeignKeyField from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -120,6 +121,7 @@ async def test_model_crud(): track = await Track.objects.get(title="The Bird") assert track.album.pk == album.pk + assert isinstance(track.album, ormar.Model) assert track.album.name is None await track.album.load() assert track.album.name == "Malibu" diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index 3cfd56c..ebb0619 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -75,9 +75,18 @@ def test_model_attribute_access(example): example.test = 12 assert example.test == 12 + example._orm_saved = True + assert example._orm_saved + + +def test_model_attribute_json_access(example): + example.test_json = dict(aa=12) + assert example.test_json == dict(aa=12) + + def test_non_existing_attr(example): - with pytest.raises(ValueError): - example.new_attr=12 + with pytest.raises(ValueError): + example.new_attr = 12 def test_primary_key_access_and_setting(example): diff --git a/tests/test_models.py b/tests/test_models.py index 42c6b54..f21e70c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -11,6 +11,16 @@ database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() +class JsonSample(ormar.Model): + class Meta: + tablename = "jsons" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + test_json: ormar.JSON(nullable=True) + + class User(ormar.Model): class Meta: tablename = "users" @@ -56,6 +66,18 @@ def test_model_pk(): assert user.id == 1 +@pytest.mark.asyncio +async def test_json_column(): + async with database: + await JsonSample.objects.create(test_json=dict(aa=12)) + await JsonSample.objects.create(test_json='{"aa": 12}') + + items = await JsonSample.objects.all() + assert len(items) == 2 + assert items[0].test_json == dict(aa=12) + assert items[1].test_json == dict(aa=12) + + @pytest.mark.asyncio async def test_model_crud(): async with database: