From a39179bc645567fa447983ead4f0fae0ff867a5b Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 16 Aug 2020 22:27:39 +0200 Subject: [PATCH 1/5] mostly working --- .coverage | Bin 53248 -> 0 bytes ormar/fields/base.py | 108 ++++---- ormar/fields/foreign_key.py | 120 +++++---- ormar/fields/model_fields.py | 394 ++++++++++++++++++++++++---- ormar/models/fakepydantic.py | 182 +++++++++---- ormar/models/metaclass.py | 143 ++++++---- ormar/models/model.py | 62 ++--- ormar/queryset/clause.py | 18 +- ormar/queryset/query.py | 92 +++---- ormar/queryset/queryset.py | 33 +-- ormar/relations.py | 7 +- tests/test_columns.py | 22 +- tests/test_fastapi_usage.py | 24 +- tests/test_foreign_keys.py | 93 ++++--- tests/test_model_definition.py | 98 +++---- tests/test_models.py | 39 +-- tests/test_more_reallife_fastapi.py | 24 +- tests/test_same_table_joins.py | 65 ++--- 18 files changed, 988 insertions(+), 536 deletions(-) delete mode 100644 .coverage diff --git a/.coverage b/.coverage deleted file mode 100644 index bf048747f459554fda48d6da57377636b1799559..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 53248 zcmeI4dvF`Y8Nly!vSht({g$1O=B5u!63emdBtXEJ(55Yq4&^aSr!hUw(#cUG=@i|` zjwgQ9JciCNo$0jGGHocdol=@Upe+L#Lco8tblNh(Qz(x(iQUp6rX{wggu#RmyWj5J zSyE!z`j3hEZlvA2z5RB-{q47V`gC`D+buW64ONP1nWSP!ZCnM%^W3$P#BrPlz5)2! zn;UkV_5s9v-hQ!N57)d&3XnTEzv~Aa`9|PtB;@9})k_n3Q^CjlpLLjp(u z2_OL^zz7WA;&%rd8u+^h45cff8W|;`X3Q|~`CGSdyM4QK`?l+D+Af)C(q&!=wopjg zCS|lt*GU;Qre@SsMAc0#E{jLkdDq=D^T=TtV?1;}@0lWE#HAYoQ)XvHLp$r_r~*pmBZk-p5ex1c+1ty_Mze z;M%qP{d<_9P&{&1+|UR%@fmf(}wwLhCBDHD)oo3Ihb@0y|qQn~Kn(qSH1+v>r91bgRwQ zS#w(WxD5$RC+u+rx34W`ItwNshZlCe!woKXaP?~bZVxjftGBGkU0_0XPcAZz+NYAL zqGvPqM6%Q*X-nj0NqdY~XP`MDw=1}9b+P7H`Q)&*Q$Fl=x`VZ~{IEG9=)6R~Wd)6( zDY+qV-i4ra_0x*6Q@T)hnhc6Jmbrpg)fQ8nsvw7H=(^`t2UGceH>00~NuPd&0+ne> zPURcbzEF>nfY~Ii#4|8tVGcI0020j@RUos{rDY9fzx0rqC%yLM^fe2}9oy_7*E*IR zhBs8CZ?zAk#+`~zCm+30>e94?s-(V28g(<`AsaMgZNJfQ|a0y-4C@j_cw3Fm8 z4XKx0PJhS@ujSpr4IB93tThJB=|$GzTAoyFwXH)(fo+_3GG`_?!#oQ=a`MfCvdhCk z56v3roM!f5mr23SW&OnQuA>eIE2+f>n8f<=GNx$a69i@bGK$c|FEO+hOM2Q z>5R5ZjTm}or=s_zA|ZIV9unKG=uuNWot=qzS7#?YpQnb>N@O=YZ3Ow!eTJ$Vot?6r zRN^UFh9?;so215&-qXY45IhdhPM^`wsO9Nj8G=t(&0GIlIDEnykNt-Qz5n0fb8+M!az77z<4GACtB!C2v01`j~ zNB{|3qy#Q+5V$Pdxm&n#m%DGr?y_~#P}boUZkfHvEAN6u8+5oycXhqM4J|8}gcSs2 zc#AlZh{seprs#%}j)Qb4C~!&%(or=*E4#d|6fKh)WS67xhJ1>x0Rn<;wE~x1wjOKw z4;fw+r?qXW5jbsmA+s@ZR;SA*r2E!a3*3&Sq;%Dw>#@)pTB-!Dd-1&NR%a)ytdZ#h z0dof=!j%FSFG(VjP+*xC5NWF@+V;gNf!e4LDOE_8DQ0%VwpkRo$WqF&p3ys51eI(E z2;7}Z<NsOu2Y%KK7rd)f=o==t)}~;O3H{wKrHNC z2C`LEUent=%kYZD)kIY9jM4>d6hq4}^~#q~&k|zQ$>ss5x5-`9F3k?L#PZA;q(ZKu zL#;rnOJP=djG${=jKQsxslY1{}I0=J}LH#HGva>1A&_YLH`r}0sotUe)6jS7V@Zn zweO_wx4uE&wIG5U5Eav~$ zHw#KB(w6@pEJ;h#T`KVZ>oy6=) z=9+^}Rjoq2G>HZOzp{AS7puhd|0_xs%5$a2Mj^7Ka<+?3|Db`!h)u$sOXkn_{{wBp z?j0uQe@8a|9#6rwyMf&x_3EVmb+Zw|9h5G&k`~X+B^Vk<;Cr?z?082 zXOMCiA8G|smjCZsj=_cge_2S7m!w_b|2vD16H}i&xpX zPaK@-@0~n&;#mLw(NS&i%v<|s_+TC6XsJ8&;{53R2M?W|IC=8Qkq+2~*4M&`J8G%u zc;D%luId<_d~tr@s`**%^b@_)$B*rs>z(|l_ubyfsnfi_2J)_|8UNAcGY=haaaF^S zSnd5x6~tez8hfF8n0of+@5_U&ZWo~y z5pSgk@n^+duZ_R|di#lUvx8@^sR}^SrGdd04$nV7H}%f6JTp;M?}xNjf6s&cFV21t zn*a0A{9EsyI(72-ukex&&RpmF6gU1r{WoWJ&wMz(<(Z-0`PZJxj-8!)Zt(1!$K{0_ zqPNcDfw;!AetIxaQ4YJGE`R!&qlqp1kHnrldc4XFDVMs(UU=lxQ19$?`}8yIq{2mW zxqkIV`@wzV1M?k69iB2;M48CV4SAe&*BOh&TwVco>jW2fZhmdR0sEIY4jnn%(J?Vo zR>nI$9?`+x|0lO`fCP{L5Yz5kE?|H|)Y zY#|ar0!RP}AOR$R1dsp{Kmter34F2%u=oGTJ@oni_sRdrDKbgkBq!h None: - self.name = None - self._populate_from_kwargs(kwargs) + column_type: sqlalchemy.Column + constraints: List = [] - def _populate_from_kwargs(self, kwargs: Dict) -> None: - self.primary_key = kwargs.pop("primary_key", False) - self.autoincrement = kwargs.pop( - "autoincrement", self.primary_key and self.__type__ == int - ) + primary_key: bool + autoincrement: bool + nullable: bool + index: bool + unique: bool + pydantic_only: bool - self.nullable = kwargs.pop("nullable", not self.primary_key) - self.default = kwargs.pop("default", None) - self.server_default = kwargs.pop("server_default", None) + default: Any + server_default: Any - self.index = kwargs.pop("index", None) - self.unique = kwargs.pop("unique", None) - - self.pydantic_only = kwargs.pop("pydantic_only", False) - if self.pydantic_only and self.primary_key: - raise ModelDefinitionError("Primary key column cannot be pydantic only.") - - @property - def is_required(self) -> bool: + @classmethod + def is_required(cls) -> bool: return ( - not self.nullable and not self.has_default and not self.is_auto_primary_key + not cls.nullable and not cls.has_default() and not cls.is_auto_primary_key() ) - @property - def default_value(self) -> Any: - default = self.default - return default() if callable(default) else default + @classmethod + def default_value(cls): + if cls.is_auto_primary_key(): + return Field(default=None) + if cls.has_default(): + default = cls.default if cls.default is not None else cls.server_default + if callable(default): + return Field(default_factory=default) + else: + return Field(default=default) + return None - @property - def has_default(self) -> bool: - return self.default is not None or self.server_default is not None + @classmethod + def has_default(cls): + return cls.default is not None or cls.server_default is not None - @property - def is_auto_primary_key(self) -> bool: - if self.primary_key: - return self.autoincrement + @classmethod + def is_auto_primary_key(cls) -> bool: + if cls.primary_key: + return cls.autoincrement return False - def get_column(self, name: str = None) -> sqlalchemy.Column: - self.name = name - constraints = self.get_constraints() + @classmethod + def get_column(cls, name: str) -> sqlalchemy.Column: return sqlalchemy.Column( - self.name, - self.get_column_type(), - *constraints, - primary_key=self.primary_key, - autoincrement=self.autoincrement, - nullable=self.nullable, - index=self.index, - unique=self.unique, - default=self.default, - server_default=self.server_default, + name, + cls.column_type, + *cls.constraints, + primary_key=cls.primary_key, + nullable=cls.nullable and not cls.primary_key, + index=cls.index, + unique=cls.unique, + default=cls.default, + server_default=cls.server_default, ) - def get_column_type(self) -> sqlalchemy.types.TypeEngine: - raise NotImplementedError() # pragma: no cover - - def get_constraints(self) -> Optional[List]: - return [] - - def expand_relationship(self, value: Any, child: "Model") -> Any: + @classmethod + def expand_relationship(cls, value: Any, child: "Model") -> Any: return value - - def __repr__(self): # pragma no cover - return str(self.__dict__) diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 77c2da7..87a4f4d 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, TYPE_CHECKING, Type, Union +from typing import Any, List, Optional, TYPE_CHECKING, Type, Union, Callable import sqlalchemy from pydantic import BaseModel @@ -13,87 +13,115 @@ if TYPE_CHECKING: # pragma no cover def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": init_dict = { - **{fk.__pkname__: pk or -1}, + **{fk.Meta.pkname: pk or -1, + '__pk_only__': True}, **{ k: create_dummy_instance(v.to) - for k, v in fk.__model_fields__.items() - if isinstance(v, ForeignKey) and not v.nullable and not v.virtual + for k, v in fk.Meta.model_fields.items() + if isinstance(v, ForeignKeyField) and not v.nullable and not v.virtual }, } return fk(**init_dict) -class ForeignKey(BaseField): - def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, - ) -> None: - super().__init__(nullable=nullable, name=name) - self.virtual = virtual - self.related_name = related_name - self.to = to +def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = True, + related_name: str = None, + virtual: bool = False, + ) -> Type[object]: + fk_string = to.Meta.tablename + "." + to.Meta.pkname + to_field = to.__fields__[to.Meta.pkname] + namespace = dict( + to=to, + name=name, + nullable=nullable, + constraints=[sqlalchemy.schema.ForeignKey(fk_string)], + unique=unique, + column_type=to_field.type_.column_type, + related_name=related_name, + virtual=virtual, + primary_key=False, + index=False, + pydantic_only=False, + default=None, + server_default=None + ) + + return type("ForeignKey", (ForeignKeyField, BaseField), namespace) + + +class ForeignKeyField(BaseField): + to: Type["Model"] + related_name: str + virtual: bool + + @classmethod + def __get_validators__(cls) -> Callable: + yield cls.validate + + @classmethod + def validate(cls, v: Any) -> Any: + return v @property def __type__(self) -> Type[BaseModel]: return self.to.__pydantic_model__ - def get_constraints(self) -> List[sqlalchemy.schema.ForeignKey]: - fk_string = self.to.__tablename__ + "." + self.to.__pkname__ - return [sqlalchemy.schema.ForeignKey(fk_string)] - - def get_column_type(self) -> sqlalchemy.Column: - to_column = self.to.__model_fields__[self.to.__pkname__] - return to_column.get_column_type() + @classmethod + def get_column_type(cls) -> sqlalchemy.Column: + to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname] + return to_column.column_type + @classmethod def _extract_model_from_sequence( - self, value: List, child: "Model" + cls, value: List, child: "Model" ) -> Union["Model", List["Model"]]: - return [self.expand_relationship(val, child) for val in value] + return [cls.expand_relationship(val, child) for val in value] - def _register_existing_model(self, value: "Model", child: "Model") -> "Model": - self.register_relation(value, child) + @classmethod + def _register_existing_model(cls, value: "Model", child: "Model") -> "Model": + cls.register_relation(value, child) return value - def _construct_model_from_dict(self, value: dict, child: "Model") -> "Model": - model = self.to(**value) - self.register_relation(model, child) + @classmethod + def _construct_model_from_dict(cls, value: dict, child: "Model") -> "Model": + model = cls.to(**value) + cls.register_relation(model, child) return model - def _construct_model_from_pk(self, value: Any, child: "Model") -> "Model": - if not isinstance(value, self.to.pk_type()): + @classmethod + def _construct_model_from_pk(cls, value: Any, child: "Model") -> "Model": + if not isinstance(value, cls.to.pk_type()): raise RelationshipInstanceError( - f"Relationship error - ForeignKey {self.to.__name__} " - f"is of type {self.to.pk_type()} " + f"Relationship error - ForeignKey {cls.to.__name__} " + f"is of type {cls.to.pk_type()} " f"while {type(value)} passed as a parameter." ) - model = create_dummy_instance(fk=self.to, pk=value) - self.register_relation(model, child) + model = create_dummy_instance(fk=cls.to, pk=value) + cls.register_relation(model, child) return model - def register_relation(self, model: "Model", child: "Model") -> None: - child_model_name = self.related_name or child.get_name() - model._orm_relationship_manager.add_relation( - model, child, child_model_name, virtual=self.virtual + @classmethod + def register_relation(cls, model: "Model", child: "Model") -> None: + child_model_name = cls.related_name or child.get_name() + model.Meta._orm_relationship_manager.add_relation( + model, child, child_model_name, virtual=cls.virtual ) + @classmethod def expand_relationship( - self, value: Any, child: "Model" + cls, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: if value is None: return None constructors = { - f"{self.to.__name__}": self._register_existing_model, - "dict": self._construct_model_from_dict, - "list": self._extract_model_from_sequence, + f"{cls.to.__name__}": cls._register_existing_model, + "dict": cls._construct_model_from_dict, + "list": cls._extract_model_from_sequence, } model = constructors.get( - value.__class__.__name__, self._construct_model_from_pk + value.__class__.__name__, cls._construct_model_from_pk )(value, child) return model diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index 4f9be11..30946bf 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -1,87 +1,373 @@ import datetime import decimal +import re +from typing import Type, Any, Optional +import pydantic import sqlalchemy from pydantic import Json +from ormar import ModelDefinitionError from ormar.fields.base import BaseField # noqa I101 -from ormar.fields.decorators import RequiredParams -@RequiredParams("length") -class String(BaseField): - __type__ = str - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.String(self.length) +def is_field_nullable(nullable: Optional[bool], default: Any, server_default: Any) -> bool: + if nullable is None: + return default is not None or server_default is not None + return False -class Integer(BaseField): - __type__ = int +def String( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + allow_blank: bool = False, + strip_whitespace: bool = False, + min_length: int = None, + max_length: int = None, + curtail_length: int = None, + regex: str = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[str]: + if max_length is None or max_length <= 0: + raise ModelDefinitionError(f'Parameter max_length is required for field String') - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Integer() + namespace = dict( + __type__=str, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + allow_blank=allow_blank, + strip_whitespace=strip_whitespace, + min_length=min_length, + max_length=max_length, + curtail_length=curtail_length, + regex=regex and re.compile(regex), + column_type=sqlalchemy.String(length=max_length), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + + return type("String", (pydantic.ConstrainedStr, BaseField), namespace) -class Text(BaseField): - __type__ = str - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Text() +def Integer( + *, + name: str = None, + primary_key: bool = False, + autoincrement: bool = None, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[int]: + namespace = dict( + __type__=int, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.Integer(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=autoincrement if autoincrement is not None else primary_key + ) + return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace) -class Float(BaseField): - __type__ = float +def Text( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + allow_blank: bool = False, + strip_whitespace: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[str]: + namespace = dict( + __type__=str, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + allow_blank=allow_blank, + strip_whitespace=strip_whitespace, + column_type=sqlalchemy.Text(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Float() + return type("Text", (pydantic.ConstrainedStr, BaseField), namespace) -class Boolean(BaseField): - __type__ = bool - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Boolean() +def Float( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[int]: + namespace = dict( + __type__=float, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.Float(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace) -class DateTime(BaseField): - __type__ = datetime.datetime - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.DateTime() +def Boolean( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[bool]: + namespace = dict( + __type__=bool, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.Boolean(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Boolean", (int, BaseField), namespace) -class Date(BaseField): - __type__ = datetime.date - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Date() +def DateTime( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[datetime.datetime]: + namespace = dict( + __type__=datetime.datetime, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.DateTime(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("DateTime", (datetime.datetime, BaseField), namespace) -class Time(BaseField): - __type__ = datetime.time - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Time() +def Date( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[datetime.date]: + namespace = dict( + __type__=datetime.date, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.Date(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Date", (datetime.date, BaseField), namespace) -class JSON(BaseField): - __type__ = Json - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.JSON() +def Time( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[datetime.time]: + namespace = dict( + __type__=datetime.time, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.Time(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Time", (datetime.time, BaseField), namespace) -class BigInteger(BaseField): - __type__ = int +def JSON( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[Json]: + namespace = dict( + __type__=pydantic.Json, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.JSON(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.BigInteger() + return type("JSON", (pydantic.Json, BaseField), namespace) -@RequiredParams("length", "precision") -class Decimal(BaseField): - __type__ = decimal.Decimal +def BigInteger( + *, + name: str = None, + primary_key: bool = False, + autoincrement: bool = None, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[int]: + namespace = dict( + __type__=int, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.BigInteger(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=autoincrement if autoincrement is not None else primary_key + ) + return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace) - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.DECIMAL(self.length, self.precision) + +def Decimal( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + precision: int = None, + scale: int = None, + max_digits: int = None, + decimal_places: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +): + if precision is None or precision < 0 or scale is None or scale < 0: + raise ModelDefinitionError(f'Parameters scale and precision are required for field Decimal') + + namespace = dict( + __type__=decimal.Decimal, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.types.DECIMAL(precision=precision, scale=scale), + precision=precision, + scale=scale, + max_digits=max_digits, + decimal_places=decimal_places, + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace) diff --git a/ormar/models/fakepydantic.py b/ormar/models/fakepydantic.py index d5109b4..e31f3f7 100644 --- a/ormar/models/fakepydantic.py +++ b/ormar/models/fakepydantic.py @@ -11,7 +11,7 @@ from typing import ( TYPE_CHECKING, Type, TypeVar, - Union, + Union, AbstractSet, Mapping, ) import databases @@ -20,19 +20,27 @@ import sqlalchemy from pydantic import BaseModel import ormar # noqa I100 +from ormar import ForeignKey from ormar.fields import BaseField -from ormar.models.metaclass import ModelMetaclass +from ormar.fields.foreign_key import ForeignKeyField +from ormar.models.metaclass import ModelMetaclass, ModelMeta from ormar.relations import RelationshipManager if TYPE_CHECKING: # pragma no cover from ormar.models.model import Model + IntStr = Union[int, str] + DictStrAny = Dict[str, Any] + AbstractSetIntStr = AbstractSet[IntStr] + MappingIntStrAny = Mapping[IntStr, Any] -class FakePydantic(list, metaclass=ModelMetaclass): + +class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): # FakePydantic inherits from list in order to be treated as # request.Body parameter in fastapi routes, # inheriting from pydantic.BaseModel causes metaclass conflicts - __abstract__ = True + __slots__ = ('_orm_id', '_orm_saved') + if TYPE_CHECKING: # pragma no cover __model_fields__: Dict[str, TypeVar[BaseField]] __table__: sqlalchemy.Table @@ -43,62 +51,88 @@ class FakePydantic(list, metaclass=ModelMetaclass): __metadata__: sqlalchemy.MetaData __database__: databases.Database _orm_relationship_manager: RelationshipManager + Meta: ModelMeta + # noinspection PyMissingConstructor def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__() - self._orm_id: str = uuid.uuid4().hex - self._orm_saved: bool = False - self.values: Optional[BaseModel] = None + object.__setattr__(self, "_orm_id", uuid.uuid4().hex) + object.__setattr__(self, "_orm_saved", False) + + pk_only = kwargs.pop("__pk_only__", False) if "pk" in kwargs: - kwargs[self.__pkname__] = kwargs.pop("pk") + kwargs[self.Meta.pkname] = kwargs.pop("pk") kwargs = { - k: self.__model_fields__[k].expand_relationship(v, self) + k: self.Meta.model_fields[k].expand_relationship(v, self) for k, v in kwargs.items() } - self.values = self.__pydantic_model__(**kwargs) + + values, fields_set, validation_error = pydantic.validate_model( + self, kwargs + ) + if validation_error and not pk_only: + raise validation_error + + object.__setattr__(self, '__dict__', values) + object.__setattr__(self, '__fields_set__', fields_set) + + # super().__init__(**kwargs) + # self.values = self.__pydantic_model__(**kwargs) def __del__(self) -> None: - self._orm_relationship_manager.deregister(self) + self.Meta._orm_relationship_manager.deregister(self) - def __setattr__(self, key: str, value: Any) -> None: - if key in self.__fields__: - value = self._convert_json(key, value, op="dumps") - value = self.__model_fields__[key].expand_relationship(value, self) + def __setattr__(self, name, value): + if name in self.__slots__: + object.__setattr__(self, name, value) + elif name == 'pk': + object.__setattr__(self, self.Meta.pkname, value) + relation_key = self.get_name(title=True) + "_" + name + if self.Meta._orm_relationship_manager.contains(relation_key, self): + self.Meta.model_fields[name].expand_relationship(value, self) + return + super().__setattr__(name, value) - relation_key = self.get_name(title=True) + "_" + key - if not self._orm_relationship_manager.contains(relation_key, self): - setattr(self.values, key, value) - else: - super().__setattr__(key, value) + def __getattr__(self, item): + relation_key = self.get_name(title=True) + "_" + item + if self.Meta._orm_relationship_manager.contains(relation_key, self): + return self.Meta._orm_relationship_manager.get(relation_key, self) - def __getattribute__(self, key: str) -> Any: - if key != "__fields__" and key in self.__fields__: - relation_key = self.get_name(title=True) + "_" + key - if self._orm_relationship_manager.contains(relation_key, self): - return self._orm_relationship_manager.get(relation_key, self) + # def __setattr__(self, key: str, value: Any) -> None: + # if key in ('_orm_id', '_orm_relationship_manager', '_orm_saved', 'objects', '__model_fields__'): + # return setattr(self, key, value) + # # elif key in self._extract_related_names(): + # # value = self._convert_json(key, value, op="dumps") + # # value = self.Meta.model_fields[key].expand_relationship(value, self) + # # relation_key = self.get_name(title=True) + "_" + key + # # if not self.Meta._orm_relationship_manager.contains(relation_key, self): + # # setattr(self.values, key, value) + # else: + # super().__setattr__(key, value) - item = getattr(self.values, key, None) - item = self._convert_json(key, item, op="loads") - return item - return super().__getattribute__(key) - - def __eq__(self, other: "Model") -> bool: - return self.values.dict() == other.values.dict() + # def __getattribute__(self, key: str) -> Any: + # if key != 'Meta' and key in self.Meta.model_fields: + # relation_key = self.get_name(title=True) + "_" + key + # if self.Meta._orm_relationship_manager.contains(relation_key, self): + # return self.Meta._orm_relationship_manager.get(relation_key, self) + # item = getattr(self.__fields__, key, None) + # item = self._convert_json(key, item, op="loads") + # return item + # return super().__getattribute__(key) def __same__(self, other: "Model") -> bool: if self.__class__ != other.__class__: # pragma no cover return False return self._orm_id == other._orm_id or ( - self.values is not None and other.values is not None and self.pk == other.pk + self.__dict__ is not None and other.__dict__ is not None and self.pk == other.pk ) - def __repr__(self) -> str: # pragma no cover - return self.values.__repr__() + # def __repr__(self) -> str: # pragma no cover + # return self.values.__repr__() - @classmethod - def __get_validators__(cls) -> Callable: # pragma no cover - yield cls.__pydantic_model__.validate + # @classmethod + # def __get_validators__(cls) -> Callable: # pragma no cover + # yield cls.__pydantic_model__.validate @classmethod def get_name(cls, title: bool = False, lower: bool = True) -> str: @@ -109,25 +143,57 @@ class FakePydantic(list, metaclass=ModelMetaclass): name = name.title() return name + @property + def pk(self) -> Any: + return getattr(self, self.Meta.pkname) + + @pk.setter + def pk(self, value: Any) -> None: + setattr(self, self.Meta.pkname, value) + @property def pk_column(self) -> sqlalchemy.Column: - return self.__table__.primary_key.columns.values()[0] + return self.Meta.table.primary_key.columns.values()[0] @classmethod def pk_type(cls) -> Any: - return cls.__model_fields__[cls.__pkname__].__type__ + return cls.Meta.model_fields[cls.Meta.pkname].__type__ - def dict(self, nested=False) -> Dict: # noqa: A003 - dict_instance = self.values.dict() + def dict( + self, + *, + include: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, + exclude: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + nested: bool = False + ) -> 'DictStrAny': # noqa: A003 + print('callin super', self.__class__) + print('to exclude', self._exclude_related_names_not_required(nested)) + dict_instance = super().dict(include=include, + exclude=self._exclude_related_names_not_required(nested), + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none) + print('after super') for field in self._extract_related_names(): + print(self.__class__, field, nested) nested_model = getattr(self, field) - if self.__model_fields__[field].virtual and nested: + + if self.Meta.model_fields[field].virtual and nested: continue if isinstance(nested_model, list) and not isinstance( - nested_model, ormar.Model + nested_model, ormar.Model ): + print('nested list') dict_instance[field] = [x.dict(nested=True) for x in nested_model] else: + print('instance') dict_instance[field] = ( nested_model.dict(nested=True) if nested_model is not None else {} ) @@ -155,7 +221,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): return value def _is_conversion_to_json_needed(self, column_name: str) -> bool: - return self.__model_fields__.get(column_name).__type__ == pydantic.Json + return self.Meta.model_fields.get(column_name).__type__ == pydantic.Json def _extract_own_model_fields(self) -> Dict: related_names = self._extract_related_names() @@ -165,22 +231,32 @@ class FakePydantic(list, metaclass=ModelMetaclass): @classmethod def _extract_related_names(cls) -> Set: related_names = set() - for name, field in cls.__fields__.items(): - if inspect.isclass(field.type_) and issubclass( - field.type_, pydantic.BaseModel + for name, field in cls.Meta.model_fields.items(): + if inspect.isclass(field) and issubclass( + field, ForeignKeyField ): related_names.add(name) return related_names + @classmethod + def _exclude_related_names_not_required(cls, nested:bool=False) -> Set: + if nested: + return cls._extract_related_names() + related_names = set() + for name, field in cls.Meta.model_fields.items(): + if inspect.isclass(field) and issubclass(field, ForeignKeyField) and field.nullable: + related_names.add(name) + return related_names + def _extract_model_db_fields(self) -> Dict: self_fields = self._extract_own_model_fields() self_fields = { - k: v for k, v in self_fields.items() if k in self.__table__.columns + k: v for k, v in self_fields.items() if k in self.Meta.table.columns } for field in self._extract_related_names(): if getattr(self, field) is not None: self_fields[field] = getattr( - getattr(self, field), self.__model_fields__[field].to.__pkname__ + getattr(self, field), self.Meta.model_fields[field].to.Meta.pkname ) return self_fields @@ -196,9 +272,9 @@ class FakePydantic(list, metaclass=ModelMetaclass): @classmethod def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": - for field in one.__model_fields__.keys(): + for field in one.Meta.model_fields.keys(): if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), ormar.Model + getattr(one, field), ormar.Model ): setattr(other, field, getattr(one, field) + getattr(other, field)) elif isinstance(getattr(one, field), ormar.Model): diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 7ea2fe7..76088cc 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -1,12 +1,15 @@ -import copy -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union +import databases +import pydantic import sqlalchemy -from pydantic import BaseConfig, create_model -from pydantic.fields import ModelField +from pydantic import BaseConfig, create_model, Extra +from pydantic.fields import ModelField, FieldInfo from ormar import ForeignKey, ModelDefinitionError # noqa I100 from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField +from ormar.queryset import QuerySet from ormar.relations import RelationshipManager if TYPE_CHECKING: # pragma no cover @@ -15,6 +18,17 @@ if TYPE_CHECKING: # pragma no cover relationship_manager = RelationshipManager() +class ModelMeta: + tablename: str + table: sqlalchemy.Table + metadata: sqlalchemy.MetaData + database: databases.Database + columns: List[sqlalchemy.Column] + pkname: str + model_fields: Dict[str, Union[BaseField, ForeignKey]] + _orm_relationship_manager: RelationshipManager + + def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: pydantic_fields = { field_name: ( @@ -29,9 +43,9 @@ def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: child_relation_name = ( - field.to.get_name(title=True) - + "_" - + (field.related_name or (name.lower() + "s")) + field.to.get_name(title=True) + + "_" + + (field.related_name or (name.lower() + "s")) ) reverse_name = child_relation_name relation_name = name.lower().title() + "_" + field.to.get_name() @@ -41,104 +55,125 @@ def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> def expand_reverse_relationships(model: Type["Model"]) -> None: - for model_field in model.__model_fields__.values(): - if isinstance(model_field, ForeignKey): + for model_field in model.Meta.model_fields.values(): + if issubclass(model_field, ForeignKeyField): child_model_name = model_field.related_name or model.get_name() + "s" parent_model = model_field.to child = model if ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ): register_reverse_model_fields(parent_model, child, child_model_name) def register_reverse_model_fields( - model: Type["Model"], child: Type["Model"], child_model_name: str + model: Type["Model"], child: Type["Model"], child_model_name: str ) -> None: - model.__fields__[child_model_name] = ModelField( - name=child_model_name, - type_=Optional[child.__pydantic_model__], - model_config=child.__pydantic_model__.__config__, - class_validators=child.__pydantic_model__.__validators__, - ) - model.__model_fields__[child_model_name] = ForeignKey( + # model.__fields__[child_model_name] = ModelField( + # name=child_model_name, + # type_=Optional[Union[List[child], child]], + # model_config=child.__config__, + # class_validators=child.__validators__, + # ) + model.Meta.model_fields[child_model_name] = ForeignKey( child, name=child_model_name, virtual=True ) def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str + name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: columns = [] pkname = None model_fields = { field_name: field - for field_name, field in object_dict.items() - if isinstance(field, BaseField) + for field_name, field in object_dict['__annotations__'].items() + if issubclass(field, BaseField) } for field_name, field in model_fields.items(): if field.primary_key: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") + if field.pydantic_only: + raise ModelDefinitionError('Primary key column cannot be pydantic only') pkname = field_name if not field.pydantic_only: columns.append(field.get_column(field_name)) - if isinstance(field, ForeignKey): + if issubclass(field, ForeignKeyField): register_relation_on_build(table_name, field, name) return pkname, columns, model_fields +def populate_pydantic_default_values(attrs: Dict) -> Dict: + for field, type_ in attrs['__annotations__'].items(): + if issubclass(type_, BaseField): + if type_.name is None: + type_.name = field + def_value = type_.default_value() + curr_def_value = attrs.get(field, 'NONE') + if curr_def_value == 'NONE' and isinstance(def_value, FieldInfo): + attrs[field] = def_value + elif curr_def_value == 'NONE' and type_.nullable: + attrs[field] = FieldInfo(default=None) + return attrs + + def get_pydantic_base_orm_config() -> Type[BaseConfig]: class Config(BaseConfig): orm_mode = True + arbitrary_types_allowed = True + # extra = Extra.allow return Config -class ModelMetaclass(type): +class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: + + attrs['Config'] = get_pydantic_base_orm_config() new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) - if attrs.get("__abstract__"): - return new_model + if hasattr(new_model, 'Meta'): - tablename = attrs.get("__tablename__", name.lower() + "s") - attrs["__tablename__"] = tablename - metadata = attrs["__metadata__"] + if attrs.get("__abstract__"): + return new_model - # sqlalchemy table creation - pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( - name, attrs, tablename - ) - attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns) - attrs["__columns__"] = columns - attrs["__pkname__"] = pkname + attrs = populate_pydantic_default_values(attrs) - if not pkname: - raise ModelDefinitionError("Table has to have a primary key.") + tablename = name.lower() + "s" + new_model.Meta.tablename = new_model.Meta.tablename or tablename - # pydantic model creation - pydantic_fields = parse_pydantic_field_from_model_fields(attrs) - pydantic_model = create_model( - name, __config__=get_pydantic_base_orm_config(), **pydantic_fields - ) - attrs["__pydantic_fields__"] = pydantic_fields - attrs["__pydantic_model__"] = pydantic_model - attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__) - attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__) - attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__) + # sqlalchemy table creation + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( + name, attrs, new_model.Meta.tablename + ) + new_model.Meta.table = sqlalchemy.Table(new_model.Meta.tablename, new_model.Meta.metadata, *columns) + new_model.Meta.columns = columns + new_model.Meta.pkname = pkname - attrs["__model_fields__"] = model_fields - attrs["_orm_relationship_manager"] = relationship_manager + if not pkname: + breakpoint() + raise ModelDefinitionError("Table has to have a primary key.") - new_model = super().__new__( # type: ignore - mcs, name, bases, attrs - ) + # pydantic model creation + new_model.Meta.pydantic_fields = parse_pydantic_field_from_model_fields(attrs) + new_model.Meta.pydantic_model = create_model( + name, __config__=get_pydantic_base_orm_config(), **new_model.Meta.pydantic_fields + ) - expand_reverse_relationships(new_model) + new_model.Meta.model_fields = model_fields + new_model = super().__new__( # type: ignore + mcs, name, bases, attrs + ) + expand_reverse_relationships(new_model) + + new_model.Meta._orm_relationship_manager = relationship_manager + new_model.objects = QuerySet(new_model) + + # breakpoint() return new_model diff --git a/ormar/models/model.py b/ormar/models/model.py index b16d3e7..4651aa1 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -7,39 +7,39 @@ from ormar.models import FakePydantic # noqa I100 class Model(FakePydantic): - __abstract__ = True + __abstract__ = False - objects = ormar.queryset.QuerySet() + # objects = ormar.queryset.QuerySet() @classmethod def from_row( - cls, - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - previous_table: str = None, + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + previous_table: str = None, ) -> "Model": item = {} select_related = select_related or [] - table_prefix = cls._orm_relationship_manager.resolve_relation_join( - previous_table, cls.__table__.name + table_prefix = cls.Meta._orm_relationship_manager.resolve_relation_join( + previous_table, cls.Meta.table.name ) - previous_table = cls.__table__.name + previous_table = cls.Meta.table.name for related in select_related: if "__" in related: first_part, remainder = related.split("__", 1) - model_cls = cls.__model_fields__[first_part].to + model_cls = cls.Meta.model_fields[first_part].to child = model_cls.from_row( row, select_related=[remainder], previous_table=previous_table ) item[first_part] = child else: - model_cls = cls.__model_fields__[related].to + model_cls = cls.Meta.model_fields[related].to child = model_cls.from_row(row, previous_table=previous_table) item[related] = child - for column in cls.__table__.columns: + for column in cls.Meta.table.columns: if column.name not in item: item[column.name] = row[ f'{table_prefix + "_" if table_prefix else ""}{column.name}' @@ -47,22 +47,14 @@ class Model(FakePydantic): return cls(**item) - @property - def pk(self) -> str: - return getattr(self.values, self.__pkname__) - - @pk.setter - def pk(self, value: Any) -> None: - setattr(self.values, self.__pkname__, value) - async def save(self) -> "Model": self_fields = self._extract_model_db_fields() - if self.__model_fields__.get(self.__pkname__).autoincrement: - self_fields.pop(self.__pkname__, None) - expr = self.__table__.insert() + if self.Meta.model_fields.get(self.Meta.pkname).autoincrement: + self_fields.pop(self.Meta.pkname, None) + expr = self.Meta.table.insert() expr = expr.values(**self_fields) - item_id = await self.__database__.execute(expr) - self.pk = item_id + item_id = await self.Meta.database.execute(expr) + setattr(self, self.Meta.pkname, item_id) return self async def update(self, **kwargs: Any) -> int: @@ -71,23 +63,23 @@ class Model(FakePydantic): self.from_dict(new_values) self_fields = self._extract_model_db_fields() - self_fields.pop(self.__pkname__) + self_fields.pop(self.Meta.pkname) expr = ( - self.__table__.update() - .values(**self_fields) - .where(self.pk_column == getattr(self, self.__pkname__)) + self.Meta.table.update() + .values(**self_fields) + .where(self.pk_column == getattr(self, self.Meta.pkname)) ) - result = await self.__database__.execute(expr) + result = await self.Meta.database.execute(expr) return result async def delete(self) -> int: - expr = self.__table__.delete() - expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) - result = await self.__database__.execute(expr) + expr = self.Meta.table.delete() + expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname))) + result = await self.Meta.database.execute(expr) return result async def load(self) -> "Model": - expr = self.__table__.select().where(self.pk_column == self.pk) - row = await self.__database__.fetch_one(expr) + expr = self.Meta.table.select().where(self.pk_column == self.pk) + row = await self.Meta.database.fetch_one(expr) self.from_dict(dict(row)) return self diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index 3d5d14b..dc94e6f 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -32,7 +32,7 @@ class QueryClause: self.filter_clauses = filter_clauses self.model_cls = model_cls - self.table = self.model_cls.__table__ + self.table = self.model_cls.Meta.table def filter( # noqa: A003 self, **kwargs: Any @@ -41,7 +41,7 @@ class QueryClause: select_related = list(self._select_related) if kwargs.get("pk"): - pk_name = self.model_cls.__pkname__ + pk_name = self.model_cls.Meta.pkname kwargs[pk_name] = kwargs.pop("pk") for key, value in kwargs.items(): @@ -65,8 +65,8 @@ class QueryClause: related_parts, select_related ) - table = model_cls.__table__ - column = model_cls.__table__.columns[field_name] + table = model_cls.Meta.table + column = model_cls.Meta.table.columns[field_name] else: op = "exact" @@ -106,12 +106,12 @@ class QueryClause: # Walk the relationships to the actual model class # against which the comparison is being made. - previous_table = model_cls.__tablename__ + previous_table = model_cls.Meta.tablename for part in related_parts: - current_table = model_cls.__model_fields__[part].to.__tablename__ - manager = model_cls._orm_relationship_manager + current_table = model_cls.Meta.model_fields[part].to.Meta.tablename + manager = model_cls.Meta._orm_relationship_manager table_prefix = manager.resolve_relation_join(previous_table, current_table) - model_cls = model_cls.__model_fields__[part].to + model_cls = model_cls.Meta.model_fields[part].to previous_table = current_table return select_related, table_prefix, model_cls @@ -128,7 +128,7 @@ class QueryClause: clause_text = str( clause.compile( - dialect=self.model_cls.__database__._backend._dialect, + dialect=self.model_cls.Meta.database._backend._dialect, compile_kwargs={"literal_binds": True}, ) ) diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 202b249..798502a 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -4,8 +4,8 @@ import sqlalchemy from sqlalchemy import text import ormar # noqa I100 -from ormar import ForeignKey from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -20,12 +20,12 @@ class JoinParameters(NamedTuple): class Query: def __init__( - self, - model_cls: Type["Model"], - filter_clauses: List, - select_related: List, - limit_count: int, - offset: int, + self, + model_cls: Type["Model"], + filter_clauses: List, + select_related: List, + limit_count: int, + offset: int, ) -> None: self.query_offset = offset @@ -34,7 +34,7 @@ class Query: self.filter_clauses = filter_clauses self.model_cls = model_cls - self.table = self.model_cls.__table__ + self.table = self.model_cls.Meta.table self.auto_related = [] self.used_aliases = [] @@ -46,16 +46,16 @@ class Query: def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: self.columns = list(self.table.columns) - self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] + self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")] self.select_from = self.table - for key in self.model_cls.__model_fields__: + for key in self.model_cls.Meta.model_fields: if ( - not self.model_cls.__model_fields__[key].nullable - and isinstance( - self.model_cls.__model_fields__[key], ormar.fields.ForeignKey, - ) - and key not in self._select_related + not self.model_cls.Meta.model_fields[key].nullable + and isinstance( + self.model_cls.Meta.model_fields[key], ForeignKeyField, + ) + and key not in self._select_related ): self._select_related = [key] + self._select_related @@ -79,7 +79,7 @@ class Query: expr = self._apply_expression_modifiers(expr) - # print(expr.compile(compile_kwargs={"literal_binds": True})) + print(expr.compile(compile_kwargs={"literal_binds": True})) self._reset_query_parameters() return expr, self._select_related @@ -97,12 +97,12 @@ class Query: @staticmethod def _field_is_a_foreign_key_and_no_circular_reference( - field: BaseField, field_name: str, rel_part: str + field: BaseField, field_name: str, rel_part: str ) -> bool: - return isinstance(field, ForeignKey) and field_name not in rel_part + return issubclass(field, ForeignKeyField) and field_name not in rel_part def _field_qualifies_to_deeper_search( - self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str + self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str ) -> bool: prev_part_of_related = "__".join(rel_part.split("__")[:-1]) partial_match = any( @@ -112,39 +112,39 @@ class Query: [x.startswith(rel_part) for x in (self.auto_related + self.already_checked)] ) return ( - (field.virtual and parent_virtual) - or (partial_match and not already_checked) - ) or not nested + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested def on_clause( - self, previous_alias: str, alias: str, from_clause: str, to_clause: str, + self, previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: left_part = f"{alias}_{to_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" return text(f"{left_part}={right_part}") def _build_join_parameters( - self, part: str, join_params: JoinParameters + self, part: str, join_params: JoinParameters ) -> JoinParameters: - model_cls = join_params.model_cls.__model_fields__[part].to - to_table = model_cls.__table__.name + model_cls = join_params.model_cls.Meta.model_fields[part].to + to_table = model_cls.Meta.table.name - alias = model_cls._orm_relationship_manager.resolve_relation_join( + alias = model_cls.Meta._orm_relationship_manager.resolve_relation_join( join_params.from_table, to_table ) if alias not in self.used_aliases: - if join_params.prev_model.__model_fields__[part].virtual: + if join_params.prev_model.Meta.model_fields[part].virtual: to_key = next( ( v - for k, v in model_cls.__model_fields__.items() - if isinstance(v, ForeignKey) and v.to == join_params.prev_model + for k, v in model_cls.Meta.model_fields.items() + if issubclass(v, ForeignKeyField) and v.to == join_params.prev_model ), None, ).name - from_key = model_cls.__pkname__ + from_key = model_cls.Meta.pkname else: - to_key = model_cls.__pkname__ + to_key = model_cls.Meta.pkname from_key = part on_clause = self.on_clause( @@ -157,8 +157,8 @@ class Query: self.select_from = sqlalchemy.sql.outerjoin( self.select_from, target_table, on_clause ) - self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}")) - self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) + self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}")) + self.columns.extend(self.prefixed_columns(alias, model_cls.Meta.table)) self.used_aliases.append(alias) previous_alias = alias @@ -167,24 +167,28 @@ class Query: return JoinParameters(prev_model, previous_alias, from_table, model_cls) def _extract_auto_required_relations( - self, - prev_model: Type["Model"], - rel_part: str = "", - nested: bool = False, - parent_virtual: bool = False, + self, + prev_model: Type["Model"], + rel_part: str = "", + nested: bool = False, + parent_virtual: bool = False, ) -> None: - for field_name, field in prev_model.__model_fields__.items(): + for field_name, field in prev_model.Meta.model_fields.items(): if self._field_is_a_foreign_key_and_no_circular_reference( - field, field_name, rel_part + field, field_name, rel_part ): rel_part = field_name if not rel_part else rel_part + "__" + field_name if not field.nullable: + print('add', rel_part, field) if rel_part not in self._select_related: - self.auto_related.append("__".join(rel_part.split("__")[:-1])) + new_related = "__".join(rel_part.split("__")[:-1]) if len( + rel_part.split("__")) > 1 else rel_part + self.auto_related.append(new_related) rel_part = "" elif self._field_qualifies_to_deeper_search( - field, parent_virtual, nested, rel_part + field, parent_virtual, nested, rel_part ): + print('deeper', rel_part, field, field.to) self._extract_auto_required_relations( prev_model=field.to, rel_part=rel_part, @@ -204,7 +208,7 @@ class Query: self._select_related = new_joins + self.auto_related def _apply_expression_modifiers( - self, expr: sqlalchemy.sql.select + self, expr: sqlalchemy.sql.select ) -> sqlalchemy.sql.select: if self.filter_clauses: if len(self.filter_clauses) == 1: diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 65192b8..4e5d85e 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -14,12 +14,12 @@ if TYPE_CHECKING: # pragma no cover class QuerySet: def __init__( - self, - model_cls: Type["Model"] = None, - filter_clauses: List = None, - select_related: List = None, - limit_count: int = None, - offset: int = None, + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses @@ -33,11 +33,11 @@ class QuerySet: @property def database(self) -> databases.Database: - return self.model_cls.__database__ + return self.model_cls.Meta.database @property def table(self) -> sqlalchemy.Table: - return self.model_cls.__table__ + return self.model_cls.Meta.table def build_select_expression(self) -> sqlalchemy.sql.select: qry = Query( @@ -148,12 +148,12 @@ class QuerySet: new_kwargs = dict(**kwargs) # Remove primary key when None to prevent not null constraint in postgresql. - pkname = self.model_cls.__pkname__ - pk = self.model_cls.__model_fields__[pkname] + pkname = self.model_cls.Meta.pkname + pk = self.model_cls.Meta.model_fields[pkname] if ( - pkname in new_kwargs - and new_kwargs.get(pkname) is None - and (pk.nullable or pk.autoincrement) + pkname in new_kwargs + and new_kwargs.get(pkname) is None + and (pk.nullable or pk.autoincrement) ): del new_kwargs[pkname] @@ -163,11 +163,11 @@ class QuerySet: if isinstance(new_kwargs.get(field), ormar.Model): new_kwargs[field] = getattr( new_kwargs.get(field), - self.model_cls.__model_fields__[field].to.__pkname__, + self.model_cls.Meta.model_fields[field].to.Meta.pkname, ) else: new_kwargs[field] = new_kwargs.get(field).get( - self.model_cls.__model_fields__[field].to.__pkname__ + self.model_cls.Meta.model_fields[field].to.Meta.pkname ) # Build the insert expression. @@ -176,5 +176,6 @@ class QuerySet: # Execute the insert, and return a new model instance. instance = self.model_cls(**kwargs) - instance.pk = await self.database.execute(expr) + pk = await self.database.execute(expr) + setattr(instance, self.model_cls.Meta.pkname, pk) return instance diff --git a/ormar/relations.py b/ormar/relations.py index 45c3ddd..8a422d9 100644 --- a/ormar/relations.py +++ b/ormar/relations.py @@ -6,6 +6,7 @@ from typing import List, TYPE_CHECKING, Union from weakref import proxy from ormar import ForeignKey +from ormar.fields.foreign_key import ForeignKeyField if TYPE_CHECKING: # pragma no cover from ormar.models import FakePydantic, Model @@ -21,14 +22,14 @@ class RelationshipManager: self._aliases = dict() def add_relation_type( - self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str + self, relations_key: str, reverse_key: str, field: ForeignKeyField, table_name: str ) -> None: if relations_key not in self._relations: self._relations[relations_key] = {"type": "primary"} - self._aliases[f"{table_name}_{field.to.__tablename__}"] = get_table_alias() + self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias() if reverse_key not in self._relations: self._relations[reverse_key] = {"type": "reverse"} - self._aliases[f"{field.to.__tablename__}_{table_name}"] = get_table_alias() + self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias() def deregister(self, model: "FakePydantic") -> None: for rel_type in self._relations.keys(): diff --git a/tests/test_columns.py b/tests/test_columns.py index e75cb0a..c8c9d3b 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -16,17 +16,19 @@ def time(): class Example(ormar.Model): - __tablename__ = "example" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "example" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - created = ormar.DateTime(default=datetime.datetime.now) - created_day = ormar.Date(default=datetime.date.today) - created_time = ormar.Time(default=time) - description = ormar.Text(nullable=True) - value = ormar.Float(nullable=True) - data = ormar.JSON(default={}) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=200, default='aaa') + created: ormar.DateTime(default=datetime.datetime.now) + created_day: ormar.Date(default=datetime.date.today) + created_time: ormar.Time(default=time) + description: ormar.Text(nullable=True) + value: ormar.Float(nullable=True) + data: ormar.JSON(default={}) @pytest.fixture(autouse=True, scope="module") diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index 25b1310..f7f2625 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -13,22 +13,24 @@ metadata = sqlalchemy.MetaData() class Category(ormar.Model): - __tablename__ = "categories" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "categories" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Item(ormar.Model): - __tablename__ = "items" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "items" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + category: ormar.ForeignKey(Category, nullable=True) @app.post("/items/", response_model=Item) diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index bffa783..17e4a52 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -1,6 +1,7 @@ import databases import pytest import sqlalchemy +from pydantic import ValidationError import ormar from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError @@ -11,62 +12,68 @@ metadata = sqlalchemy.MetaData() class Album(ormar.Model): - __tablename__ = "album" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "albums" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Track(ormar.Model): - __tablename__ = "track" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "tracks" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - album = ormar.ForeignKey(Album) - title = ormar.String(length=100) - position = ormar.Integer() + id: ormar.Integer(primary_key=True) + album: ormar.ForeignKey(Album) + title: ormar.String(max_length=100) + position: ormar.Integer() class Cover(ormar.Model): - __tablename__ = "covers" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "covers" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - album = ormar.ForeignKey(Album, related_name="cover_pictures") - title = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + album: ormar.ForeignKey(Album, related_name="cover_pictures") + title: ormar.String(max_length=100) class Organisation(ormar.Model): - __tablename__ = "org" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "org" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - ident = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + ident: ormar.String(max_length=100) class Team(ormar.Model): - __tablename__ = "team" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "teams" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - org = ormar.ForeignKey(Organisation) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + org: ormar.ForeignKey(Organisation) + name: ormar.String(max_length=100) class Member(ormar.Model): - __tablename__ = "member" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "members" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - team = ormar.ForeignKey(Team) - email = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + team: ormar.ForeignKey(Team) + email: ormar.String(max_length=100) @pytest.fixture(autouse=True, scope="module") @@ -171,8 +178,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -180,8 +187,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -223,8 +230,8 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: @@ -243,8 +250,8 @@ async def test_pk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1 diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index f6dc722..adf20ad 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -1,4 +1,5 @@ import datetime +import decimal import pydantic import pytest @@ -12,19 +13,21 @@ metadata = sqlalchemy.MetaData() class ExampleModel(Model): - __tablename__ = "example" - __metadata__ = metadata - test = fields.Integer(primary_key=True) - test_string = fields.String(length=250) - test_text = fields.Text(default="") - test_bool = fields.Boolean(nullable=False) - test_float = fields.Float() - test_datetime = fields.DateTime(default=datetime.datetime.now) - test_date = fields.Date(default=datetime.date.today) - test_time = fields.Time(default=datetime.time) - test_json = fields.JSON(default={}) - test_bigint = fields.BigInteger(default=0) - test_decimal = fields.Decimal(length=10, precision=2) + class Meta: + tablename = "example" + metadata = metadata + + test: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250) + test_text: fields.Text(default="") + test_bool: fields.Boolean(nullable=False) + test_float: fields.Float() = None + test_datetime: fields.DateTime(default=datetime.datetime.now) + test_date: fields.Date(default=datetime.date.today) + test_time: fields.Time(default=datetime.time) + test_json: fields.JSON(default={}) + test_bigint: fields.BigInteger(default=0) + test_decimal: fields.Decimal(scale=10, precision=2) fields_to_check = [ @@ -41,15 +44,17 @@ fields_to_check = [ class ExampleModel2(Model): - __tablename__ = "example2" - __metadata__ = metadata - test = fields.Integer(primary_key=True) - test_string = fields.String(length=250) + class Meta: + tablename = "example2" + metadata = metadata + + test: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250) @pytest.fixture() def example(): - return ExampleModel(pk=1, test_string="test", test_bool=True) + return ExampleModel(pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5)) def test_not_nullable_field_is_required(): @@ -83,60 +88,65 @@ def test_primary_key_access_and_setting(example): def test_pydantic_model_is_created(example): - assert issubclass(example.values.__class__, pydantic.BaseModel) - assert all([field in example.values.__fields__ for field in fields_to_check]) - assert example.values.test == 1 + assert issubclass(example.__class__, pydantic.BaseModel) + assert all([field in example.__fields__ for field in fields_to_check]) + assert example.test == 1 def test_sqlalchemy_table_is_created(example): - assert issubclass(example.__table__.__class__, sqlalchemy.Table) - assert all([field in example.__table__.columns for field in fields_to_check]) + assert issubclass(example.Meta.table.__class__, sqlalchemy.Table) + assert all([field in example.Meta.table.columns for field in fields_to_check]) def test_no_pk_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example3" - __metadata__ = metadata - test_string = fields.String(length=250) + class Meta: + tablename = "example3" + metadata = metadata + + test_string: fields.String(max_length=250) def test_two_pks_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example3" - __metadata__ = metadata - id = fields.Integer(primary_key=True) - test_string = fields.String(length=250, primary_key=True) + class Meta: + tablename = "example3" + metadata = metadata + + id: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250, primary_key=True) def test_setting_pk_column_as_pydantic_only_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example4" - __metadata__ = metadata - test = fields.Integer(primary_key=True, pydantic_only=True) + class Meta: + tablename = "example4" + metadata = metadata + + test: fields.Integer(primary_key=True, pydantic_only=True) def test_decimal_error_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example4" - __metadata__ = metadata - test = fields.Decimal(primary_key=True) + class Meta: + tablename = "example5" + metadata = metadata + + test: fields.Decimal(primary_key=True) def test_string_error_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example4" - __metadata__ = metadata - test = fields.String(primary_key=True) + class Meta: + tablename = "example6" + metadata = metadata + + test: fields.String(primary_key=True) def test_json_conversion_in_model(): diff --git a/tests/test_models.py b/tests/test_models.py index 69d421c..42c6b54 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,4 +1,5 @@ import databases +import pydantic import pytest import sqlalchemy @@ -11,23 +12,25 @@ metadata = sqlalchemy.MetaData() class User(ormar.Model): - __tablename__ = "users" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "users" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100, default='') class Product(ormar.Model): - __tablename__ = "product" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "product" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - rating = ormar.Integer(minimum=1, maximum=5) - in_stock = ormar.Boolean(default=False) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + rating: ormar.Integer(minimum=1, maximum=5) + in_stock: ormar.Boolean(default=False) @pytest.fixture(autouse=True, scope="module") @@ -39,12 +42,12 @@ def create_test_database(): def test_model_class(): - assert list(User.__model_fields__.keys()) == ["id", "name"] - assert isinstance(User.__model_fields__["id"], ormar.Integer) - assert User.__model_fields__["id"].primary_key is True - assert isinstance(User.__model_fields__["name"], ormar.String) - assert User.__model_fields__["name"].length == 100 - assert isinstance(User.__table__, sqlalchemy.Table) + assert list(User.Meta.model_fields.keys()) == ["id", "name"] + assert issubclass(User.Meta.model_fields["id"], pydantic.ConstrainedInt) + assert User.Meta.model_fields["id"].primary_key is True + assert issubclass(User.Meta.model_fields["name"], pydantic.ConstrainedStr) + assert User.Meta.model_fields["name"].max_length == 100 + assert isinstance(User.Meta.table, sqlalchemy.Table) def test_model_pk(): diff --git a/tests/test_more_reallife_fastapi.py b/tests/test_more_reallife_fastapi.py index d7670cc..53ff9d4 100644 --- a/tests/test_more_reallife_fastapi.py +++ b/tests/test_more_reallife_fastapi.py @@ -30,22 +30,24 @@ async def shutdown() -> None: class Category(ormar.Model): - __tablename__ = "categories" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "categories" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Item(ormar.Model): - __tablename__ = "items" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "items" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + category: ormar.ForeignKey(Category, nullable=True) @pytest.fixture(autouse=True, scope="module") diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 60ddddc..166c943 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -12,53 +12,58 @@ metadata = sqlalchemy.MetaData() class Department(ormar.Model): - __tablename__ = "departments" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "departments" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True, autoincrement=False) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True, autoincrement=False) + name: ormar.String(max_length=100) class SchoolClass(ormar.Model): - __tablename__ = "schoolclasses" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "schoolclasses" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - department = ormar.ForeignKey(Department, nullable=False) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + department: ormar.ForeignKey(Department, nullable=False) class Category(ormar.Model): - __tablename__ = "categories" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "categories" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Student(ormar.Model): - __tablename__ = "students" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "students" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - schoolclass = ormar.ForeignKey(SchoolClass) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + schoolclass: ormar.ForeignKey(SchoolClass) + category: ormar.ForeignKey(Category, nullable=True) class Teacher(ormar.Model): - __tablename__ = "teachers" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "teachers" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - schoolclass = ormar.ForeignKey(SchoolClass) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + schoolclass: ormar.ForeignKey(SchoolClass) + category: ormar.ForeignKey(Category, nullable=True) @pytest.fixture(scope="module") From 0b156caf0a2d967e0b9dea6f8399b0b8f58237b1 Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 19 Aug 2020 18:40:57 +0700 Subject: [PATCH 2/5] version with pydantic inheritance passing all the tests --- .codecov.yml | 18 ++ .flake8 | 6 + .gitignore | 8 + .travis.yml | 19 ++ LICENSE.md | 21 ++ README.md | 208 ++++++++++++++++ docs/fastapi.md | 0 docs/fields.md | 206 +++++++++++++++ docs/index.md | 213 ++++++++++++++++ docs/models.md | 179 +++++++++++++ docs/pydantic.md | 0 docs/queries.md | 156 ++++++++++++ docs/relations.md | 206 +++++++++++++++ docs_src/fields/docs001.py | 36 +++ docs_src/fields/docs002.py | 36 +++ docs_src/fields/docs003.py | 41 +++ docs_src/models/docs001.py | 16 ++ docs_src/models/docs002.py | 19 ++ docs_src/models/docs003.py | 33 +++ docs_src/models/docs004.py | 22 ++ docs_src/models/docs005.py | 51 ++++ docs_src/models/docs006.py | 41 +++ docs_src/models/docs007.py | 22 ++ docs_src/relations/docs001.py | 26 ++ docs_src/relations/docs002.py | 39 +++ docs_src/relations/docs003.py | 44 ++++ mkdocs.yml | 29 +++ ormar/__init__.py | 37 +++ ormar/exceptions.py | 26 ++ ormar/fields/__init__.py | 31 +++ ormar/fields/base.py | 80 ++++++ ormar/fields/decorators.py | 27 ++ ormar/fields/foreign_key.py | 127 ++++++++++ ormar/fields/model_fields.py | 373 ++++++++++++++++++++++++++++ ormar/models/__init__.py | 4 + ormar/models/fakepydantic.py | 290 +++++++++++++++++++++ ormar/models/metaclass.py | 196 +++++++++++++++ ormar/models/model.py | 85 +++++++ ormar/queryset/__init__.py | 3 + ormar/queryset/clause.py | 178 +++++++++++++ ormar/queryset/query.py | 236 ++++++++++++++++++ ormar/queryset/queryset.py | 181 ++++++++++++++ ormar/relations.py | 104 ++++++++ requirements.txt | 21 ++ scripts/clean.sh | 19 ++ scripts/publish.sh | 23 ++ scripts/test.sh | 12 + setup.cfg | 2 + setup.py | 67 +++++ tests/__init__.py | 0 tests/settings.py | 3 + tests/test_columns.py | 60 +++++ tests/test_fastapi_usage.py | 55 ++++ tests/test_foreign_keys.py | 296 ++++++++++++++++++++++ tests/test_model_definition.py | 160 ++++++++++++ tests/test_models.py | 211 ++++++++++++++++ tests/test_more_reallife_fastapi.py | 118 +++++++++ tests/test_same_table_joins.py | 133 ++++++++++ 58 files changed, 4853 insertions(+) create mode 100644 .codecov.yml create mode 100644 .flake8 create mode 100644 .gitignore create mode 100644 .travis.yml create mode 100644 LICENSE.md create mode 100644 README.md create mode 100644 docs/fastapi.md create mode 100644 docs/fields.md create mode 100644 docs/index.md create mode 100644 docs/models.md create mode 100644 docs/pydantic.md create mode 100644 docs/queries.md create mode 100644 docs/relations.md create mode 100644 docs_src/fields/docs001.py create mode 100644 docs_src/fields/docs002.py create mode 100644 docs_src/fields/docs003.py create mode 100644 docs_src/models/docs001.py create mode 100644 docs_src/models/docs002.py create mode 100644 docs_src/models/docs003.py create mode 100644 docs_src/models/docs004.py create mode 100644 docs_src/models/docs005.py create mode 100644 docs_src/models/docs006.py create mode 100644 docs_src/models/docs007.py create mode 100644 docs_src/relations/docs001.py create mode 100644 docs_src/relations/docs002.py create mode 100644 docs_src/relations/docs003.py create mode 100644 mkdocs.yml create mode 100644 ormar/__init__.py create mode 100644 ormar/exceptions.py create mode 100644 ormar/fields/__init__.py create mode 100644 ormar/fields/base.py create mode 100644 ormar/fields/decorators.py create mode 100644 ormar/fields/foreign_key.py create mode 100644 ormar/fields/model_fields.py create mode 100644 ormar/models/__init__.py create mode 100644 ormar/models/fakepydantic.py create mode 100644 ormar/models/metaclass.py create mode 100644 ormar/models/model.py create mode 100644 ormar/queryset/__init__.py create mode 100644 ormar/queryset/clause.py create mode 100644 ormar/queryset/query.py create mode 100644 ormar/queryset/queryset.py create mode 100644 ormar/relations.py create mode 100644 requirements.txt create mode 100644 scripts/clean.sh create mode 100644 scripts/publish.sh create mode 100644 scripts/test.sh create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/__init__.py create mode 100644 tests/settings.py create mode 100644 tests/test_columns.py create mode 100644 tests/test_fastapi_usage.py create mode 100644 tests/test_foreign_keys.py create mode 100644 tests/test_model_definition.py create mode 100644 tests/test_models.py create mode 100644 tests/test_more_reallife_fastapi.py create mode 100644 tests/test_same_table_joins.py diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 0000000..6d50415 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,18 @@ +coverage: + precision: 2 + round: down + range: "80...100" + + status: + project: yes + patch: yes + changes: yes + +comment: + layout: "reach, diff, flags, files" + behavior: default + require_changes: false # if true: only post the comment if coverage changes + require_base: no # [yes :: must have a base report to post] + require_head: yes # [yes :: must have a head report to post] + branches: # branch names that can post comment + - "master" \ No newline at end of file diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..b804c49 --- /dev/null +++ b/.flake8 @@ -0,0 +1,6 @@ +[flake8] +ignore = ANN101, ANN102, W503, S101 +max-complexity = 8 +max-line-length = 88 +import-order-style = pycharm +exclude = p38venv,.pytest_cache diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f58fa36 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +p38venv +.idea +.pytest_cache +*.pyc +*.log +test.db +dist +/ormar.egg-info/ diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..416d999 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,19 @@ +language: python + +dist: xenial + +cache: pip + +python: + - "3.6" + - "3.7" + - "3.8" + +install: + - pip install -U -r requirements.txt + +script: + - scripts/test.sh + +after_script: + - codecov \ No newline at end of file diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..79eb022 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Radosław Drążkiewicz + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..c3ae7af --- /dev/null +++ b/README.md @@ -0,0 +1,208 @@ +# ORMar +

+ + Pypi version + + + Pypi version + + + Build Status + + + Coverage + + +CodeFactor + + +Codacy + +

+ +The `ormar` package is an async ORM for Python, with support for Postgres, +MySQL, and SQLite. Ormar is built with: + + * [`SQLAlchemy core`][sqlalchemy-core] for query building. + * [`databases`][databases] for cross-database async support. + * [`pydantic`][pydantic] for data validation. + +Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide +database migrations. + +The goal was to create a simple ORM that can be used directly with [`fastapi`][fastapi] that bases it's data validation on pydantic. +Initial work was inspired by [`encode/orm`][encode/orm]. +The encode package was too simple (i.e. no ability to join two times to the same table) and used typesystem for data checks. + +To avoid too high coupling with pydantic and sqlalchemy ormar uses them by **composition** rather than by **inheritance**. + +**ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.1.1` + +**Note**: Use `ipython` to try this from the console, since it supports `await`. + +```python +import databases +import ormar +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Note(ormar.Model): + __tablename__ = "notes" + __database__ = database + __metadata__ = metadata + + # primary keys of type int by dafault are set to autoincrement + id = ormar.Integer(primary_key=True) + text = ormar.String(length=100) + completed = ormar.Boolean(default=False) + +# Create the database +engine = sqlalchemy.create_engine(str(database.url)) +metadata.create_all(engine) + +# .create() +await Note.objects.create(text="Buy the groceries.", completed=False) +await Note.objects.create(text="Call Mum.", completed=True) +await Note.objects.create(text="Send invoices.", completed=True) + +# .all() +notes = await Note.objects.all() + +# .filter() +notes = await Note.objects.filter(completed=True).all() + +# exact, iexact, contains, icontains, lt, lte, gt, gte, in +notes = await Note.objects.filter(text__icontains="mum").all() + +# .get() +note = await Note.objects.get(id=1) + +# .update() +await note.update(completed=True) + +# .delete() +await note.delete() + +# 'pk' always refers to the primary key +note = await Note.objects.get(pk=2) +note.pk # 2 +``` + +Ormar supports loading and filtering across foreign keys... + +```python +import databases +import ormar +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Album(ormar.Model): + __tablename__ = "album" + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + + +class Track(ormar.Model): + __tablename__ = "track" + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + album = ormar.ForeignKey(Album) + title = ormar.String(length=100) + position = ormar.Integer() + + +# Create some records to work with. +malibu = await Album.objects.create(name="Malibu") +await Track.objects.create(album=malibu, title="The Bird", position=1) +await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2) +await Track.objects.create(album=malibu, title="The Waters", position=3) + +fantasies = await Album.objects.create(name="Fantasies") +await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) +await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + + +# Fetch an instance, without loading a foreign key relationship on it. +track = await Track.objects.get(title="The Bird") + +# We have an album instance, but it only has the primary key populated +print(track.album) # Album(id=1) [sparse] +print(track.album.pk) # 1 +print(track.album.name) # Raises AttributeError + +# Load the relationship from the database +await track.album.load() +assert track.album.name == "Malibu" + +# This time, fetch an instance, loading the foreign key relationship. +track = await Track.objects.select_related("album").get(title="The Bird") +assert track.album.name == "Malibu" + +# By default you also get a second side of the relation +# constructed as lowercase source model name +'s' (tracks in this case) +# you can also provide custom name with parameter related_name +album = await Album.objects.select_related("tracks").all() +assert len(album.tracks) == 3 + +# Fetch instances, with a filter across an FK relationship. +tracks = Track.objects.filter(album__name="Fantasies") +assert len(tracks) == 2 + +# Fetch instances, with a filter and operator across an FK relationship. +tracks = Track.objects.filter(album__name__iexact="fantasies") +assert len(tracks) == 2 + +# Limit a query +tracks = await Track.objects.limit(1).all() +assert len(tracks) == 1 +``` + +## Data types + +The following keyword arguments are supported on all field types. + + * `primary_key` + * `nullable` + * `default` + * `server_default` + * `index` + * `unique` + +All fields are required unless one of the following is set: + + * `nullable` - Creates a nullable column. Sets the default to `None`. + * `default` - Set a default value for the field. + * `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`). + * `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column. +Autoincrement is set by default on int primary keys. + +Available Model Fields: +* `String(length)` +* `Text()` +* `Boolean()` +* `Integer()` +* `Float()` +* `Date()` +* `Time()` +* `DateTime()` +* `JSON()` +* `BigInteger()` +* `Decimal(lenght, precision)` + +[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ +[databases]: https://github.com/encode/databases +[pydantic]: https://pydantic-docs.helpmanual.io/ +[encode/orm]: https://github.com/encode/orm/ +[alembic]: https://alembic.sqlalchemy.org/en/latest/ +[fastapi]: https://fastapi.tiangolo.com/ \ No newline at end of file diff --git a/docs/fastapi.md b/docs/fastapi.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/fields.md b/docs/fields.md new file mode 100644 index 0000000..66d4dd7 --- /dev/null +++ b/docs/fields.md @@ -0,0 +1,206 @@ +# Fields + + +There are 11 basic model field types and a special `ForeignKey` field to establish relationships between models. + +Each of the `Fields` has assigned both `sqlalchemy` column class and python type that is used to create `pydantic` model. + + +## Common Parameters + +All `Field` types have a set of common parameters. + +### primary_key + +`primary_key`: `bool` = `False` -> by default False. + +Sets the primary key column on a table, foreign keys always refer to the pk of the `Model`. + +Used in sql only. + +### autoincrement + +`autoincrement`: `bool` = `primary_key and type == int` -> defaults to True if column is a primary key and of type Integer, otherwise False. + +Can be only used with int fields. + +If a field has autoincrement it becomes optional. + +Used only in sql. + +### nullable + +`nullable`: `bool` = `not primary_key` -> defaults to False for primary key column, and True for all other. + +Specifies if field is optional or required, used both with sql and pydantic. + +!!!note + By default all `ForeignKeys` are also nullable, meaning the related `Model` is not required. + + If you change the `ForeignKey` column to `nullable`, it not only becomes required, it changes also the way in which data is loaded in queries. + + If you select `Model` without explicitly adding related `Model` assigned by not nullable `ForeignKey`, the `Model` is still gona be appended automatically, see example below. + +```Python hl_lines="24 32 33 34 35 37 38 39 40 41" +--8<-- "../docs_src/fields/docs003.py" +``` + +!!!info + If you want to know more about how you can preload related models during queries and how the relations work read the [queries][queries] and [relations][relations] sections. + + +### default + +`default`: `Any` = `None` -> defaults to None. + +A default value used if no other value is passed. + +In sql invoked on an insert, used during pydantic model definition. + +If the field has a default value it becomes optional. + +You can pass a static value or a Callable (function etc.) + +Used both in sql and pydantic. + +### server default + +`server_default`: `Any` = `None` -> defaults to None. + +A default value used if no other value is passed. + +In sql invoked on the server side so you can pass i.e. sql function (like now() wrapped in sqlalchemy text() clause). + +If the field has a server_default value it becomes optional. + +You can pass a static value or a Callable (function etc.) + +Used in sql only. + +### index + +`index`: `bool` = `False` -> by default False, + +Sets the index on a table's column. + +Used in sql only. + +### unique + +`unique`: `bool` = `False` + +Sets the unique constraint on a table's column. + +Used in sql only. + +## Fields Types + +### String + +`String(length)` has a required `length` parameter. + +* Sqlalchemy column: `sqlalchemy.String` +* Type (used for pydantic): `str` + +### Text + +`Text()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.Text` +* Type (used for pydantic): `str` + +### Boolean + +`Boolean()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.Boolean` +* Type (used for pydantic): `bool` + +### Integer + +`Integer()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.Integer` +* Type (used for pydantic): `int` + +### BigInteger + +`BigInteger()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.BigInteger` +* Type (used for pydantic): `int` + +### Float + +`Float()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.Float` +* Type (used for pydantic): `float` + +### Decimal + +`Decimal(lenght, precision)` has required `length` and `precision` parameters. + +* Sqlalchemy column: `sqlalchemy.DECIMAL` +* Type (used for pydantic): `decimal.Decimal` + +### Date + +`Date()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.Date` +* Type (used for pydantic): `datetime.date` + +### Time + +`Time()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.Time` +* Type (used for pydantic): `datetime.time` + +### DateTime + +`DateTime()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.DateTime` +* Type (used for pydantic): `datetime.datetime` + +### JSON + +`JSON()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.JSON` +* Type (used for pydantic): `pydantic.Json` + +### ForeignKey + +`ForeignKey(to, related_name=None)` has required parameters `to` that takes target `Model` class. + +Sqlalchemy column and Type are automatically taken from target `Model`. + +* Sqlalchemy column: class of a target `Model` primary key column +* Type (used for pydantic): type of a target `Model` primary key column + +`ForeignKey` fields are automatically registering reverse side of the relation. + +By default it's child (source) `Model` name + s, like courses in snippet below: + +```Python hl_lines="25 31" +--8<-- "../docs_src/fields/docs001.py" +``` + +But you can overwrite this name by providing `related_name` parameter like below: + +```Python hl_lines="25 30" +--8<-- "../docs_src/fields/docs002.py" +``` + +!!!tip + Since related models are coming from Relationship Manager the reverse relation on access returns list of `wekref.proxy` to avoid circular references. + +!!!info + All relations are stored in lists, but when you access parent `Model` the ormar is unpacking the value for you. + Read more in [relations][relations]. + +[relations]: ./relations.md +[queries]: ./queries.md \ No newline at end of file diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..bf035bf --- /dev/null +++ b/docs/index.md @@ -0,0 +1,213 @@ +# ORMar + +

+ + Pypi version + + + Pypi version + + + Build Status + + + Coverage + + +CodeFactor + + +Codacy + +

+ +The `ormar` package is an async ORM for Python, with support for Postgres, +MySQL, and SQLite. Ormar is built with: + + * [`SQLAlchemy core`][sqlalchemy-core] for query building. + * [`databases`][databases] for cross-database async support. + * [`pydantic`][pydantic] for data validation. + +Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide +database migrations. + +The goal was to create a simple ORM that can be used directly with [`fastapi`][fastapi] that bases it's data validation on pydantic. +Initial work was inspired by [`encode/orm`][encode/orm]. +The encode package was too simple (i.e. no ability to join two times to the same table) and used typesystem for data checks. + +**ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.0.1` + +**Note**: Use `ipython` to try this from the console, since it supports `await`. + +```python +import databases +import ormar +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Note(ormar.Model): + __tablename__ = "notes" + __database__ = database + __metadata__ = metadata + + # primary keys of type int by dafault are set to autoincrement + id = ormar.Integer(primary_key=True) + text = ormar.String(length=100) + completed = ormar.Boolean(default=False) + +# Create the database +engine = sqlalchemy.create_engine(str(database.url)) +metadata.create_all(engine) + +# .create() +await Note.objects.create(text="Buy the groceries.", completed=False) +await Note.objects.create(text="Call Mum.", completed=True) +await Note.objects.create(text="Send invoices.", completed=True) + +# .all() +notes = await Note.objects.all() + +# .filter() +notes = await Note.objects.filter(completed=True).all() + +# exact, iexact, contains, icontains, lt, lte, gt, gte, in +notes = await Note.objects.filter(text__icontains="mum").all() + +# .get() +note = await Note.objects.get(id=1) + +# .update() +await note.update(completed=True) + +# .delete() +await note.delete() + +# 'pk' always refers to the primary key +note = await Note.objects.get(pk=2) +note.pk # 2 +``` + +Ormar supports loading and filtering across foreign keys... + +```python +import databases +import ormar +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Album(ormar.Model): + __tablename__ = "album" + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + + +class Track(ormar.Model): + __tablename__ = "track" + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + album = ormar.ForeignKey(Album) + title = ormar.String(length=100) + position = ormar.Integer() + + +# Create some records to work with. +malibu = await Album.objects.create(name="Malibu") +await Track.objects.create(album=malibu, title="The Bird", position=1) +await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2) +await Track.objects.create(album=malibu, title="The Waters", position=3) + +fantasies = await Album.objects.create(name="Fantasies") +await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) +await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + + +# Fetch an instance, without loading a foreign key relationship on it. +track = await Track.objects.get(title="The Bird") + +# We have an album instance, but it only has the primary key populated +print(track.album) # Album(id=1) [sparse] +print(track.album.pk) # 1 +print(track.album.name) # Raises AttributeError + +# Load the relationship from the database +await track.album.load() +assert track.album.name == "Malibu" + +# This time, fetch an instance, loading the foreign key relationship. +track = await Track.objects.select_related("album").get(title="The Bird") +assert track.album.name == "Malibu" + +# By default you also get a second side of the relation +# constructed as lowercase source model name +'s' (tracks in this case) +# you can also provide custom name with parameter related_name +album = await Album.objects.select_related("tracks").all() +assert len(album.tracks) == 3 + +# Fetch instances, with a filter across an FK relationship. +tracks = Track.objects.filter(album__name="Fantasies") +assert len(tracks) == 2 + +# Fetch instances, with a filter and operator across an FK relationship. +tracks = Track.objects.filter(album__name__iexact="fantasies") +assert len(tracks) == 2 + +# Limit a query +tracks = await Track.objects.limit(1).all() +assert len(tracks) == 1 +``` + +## Data types + +The following keyword arguments are supported on all field types. + + * `primary_key` + * `nullable` + * `default` + * `server_default` + * `index` + * `unique` + +## Model Fields + +### Common parameters + +All fields are required unless one of the following is set: + + * `nullable` - Creates a nullable column. Sets the default to `None`. + * `default` - Set a default value for the field. + * `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`). + * `primary key` - Set a primary key on a column. + * `autoincrement` - When a column is set to primary key and autoincrement is set on this column. + Autoincrement is set by default on int primary keys. + +### Fields Types + +* `String(length)` +* `Text()` +* `Boolean()` +* `Integer()` +* `Float()` +* `Date()` +* `Time()` +* `DateTime()` +* `JSON()` +* `BigInteger()` +* `Decimal(lenght, precision)` + +[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ +[databases]: https://github.com/encode/databases +[pydantic]: https://pydantic-docs.helpmanual.io/ +[encode/orm]: https://github.com/encode/orm/ +[alembic]: https://alembic.sqlalchemy.org/en/latest/ +[fastapi]: https://fastapi.tiangolo.com/ \ No newline at end of file diff --git a/docs/models.md b/docs/models.md new file mode 100644 index 0000000..5fd3cd6 --- /dev/null +++ b/docs/models.md @@ -0,0 +1,179 @@ +# Models + +## Defining models + +By defining an ormar Model you get corresponding **Pydantic model** as well as **Sqlalchemy table** for free. +They are being managed in the background and you do not have to create them on your own. + +### Model Class + +To build an ormar model you simply need to inherit a `ormar.Model` class. + +```Python hl_lines="10" +--8<-- "../docs_src/models/docs001.py" +``` + +### Defining Fields + +Next assign one or more of the [Fields][fields] as a class level variables. + +Each table **has to** have a primary key column, which you specify by setting `primary_key=True` on selected field. + +Only one primary key column is allowed. + +```Python hl_lines="14 15 16" +--8<-- "../docs_src/models/docs001.py" +``` + +!!! warning + Not assigning `primary_key` column or assigning more than one column per `Model` will raise `ModelDefinitionError` + exception. + +By default if you assign primary key to `Integer` field, the `autoincrement` option is set to true. + +You can disable by passing `autoincremant=False`. + +```Python +id = ormar.Integer(primary_key=True, autoincrement=False) +``` + +Names of the fields will be used for both the underlying `pydantic` model and `sqlalchemy` table. + +### Dependencies + +Since ormar depends on [`databases`][databases] and [`sqlalchemy-core`][sqlalchemy-core] for database connection +and table creation you need to assign each `Model` with two special parameters. + +#### Databases + +One is `Database` instance created with your database url in [sqlalchemy connection string][sqlalchemy connection string] format. + +Created instance needs to be passed to every `Model` with `__database__` parameter. + +```Python hl_lines="1 6 11" +--8<-- "../docs_src/models/docs001.py" +``` + +!!! tip + You need to create the `Database` instance **only once** and use it for all models. + You can create several ones if you want to use multiple databases. + +#### Sqlalchemy + +Second dependency is sqlalchemy `MetaData` instance. + +Created instance needs to be passed to every `Model` with `__metadata__` parameter. + +```Python hl_lines="2 7 12" +--8<-- "../docs_src/models/docs001.py" +``` + +!!! tip + You need to create the `MetaData` instance **only once** and use it for all models. + You can create several ones if you want to use multiple databases. + +### Table Names + +By default table name is created from Model class name as lowercase name plus 's'. + +You can overwrite this parameter by providing `__tablename__` argument. + +```Python hl_lines="11 12 13" +--8<-- "../docs_src/models/docs002.py" +``` + +## Initialization + +There are two ways to create and persist the `Model` instance in the database. + +!!!tip + Use `ipython` to try this from the console, since it supports `await`. + +If you plan to modify the instance in the later execution of your program you can initiate your `Model` as a normal class and later await a `save()` call. + +```Python hl_lines="19 20" +--8<-- "../docs_src/models/docs007.py" +``` + +If you want to initiate your `Model` and at the same time save in in the database use a QuerySet's method `create()`. + +Each model has a `QuerySet` initialised as `objects` parameter + +```Python hl_lines="22" +--8<-- "../docs_src/models/docs007.py" +``` + +!!!info + To read more about `QuerySets` and available methods visit [queries][queries] + +## Attributes Delegation + +Each call to `Model` fields parameter under the hood is delegated to either the `pydantic` model +or other related `Model` in case of relations. + +The fields and relations are not stored on the `Model` itself + +```Python hl_lines="31 32 33 34 35 36 37 38 39 40 41" +--8<-- "../docs_src/models/docs006.py" +``` + +!!! warning + In example above model instances are created but not persisted that's why `id` of `department` is None! + +!!!info + To read more about `ForeignKeys` and `Model` relations visit [relations][relations] + +## Internals + +Apart from special parameters defined in the `Model` during definition (tablename, metadata etc.) the `Model` provides you with useful internals. + +### Pydantic Model + +To access auto created pydantic model you can use `Model.__pydantic_model__` parameter + +For example to list model fields you can: + +```Python hl_lines="18" +--8<-- "../docs_src/models/docs003.py" +``` + +!!!tip + Note how the primary key `id` field is optional as `Integer` primary key by default has `autoincrement` set to `True`. + +!!!info + For more options visit official [pydantic][pydantic] documentation. + +### Sqlalchemy Table + +To access auto created sqlalchemy table you can use `Model.__table__` parameter + +For example to list table columns you can: + +```Python hl_lines="18" +--8<-- "../docs_src/models/docs004.py" +``` + +!!!tip + You can access table primary key name by `Course.__pkname__` + +!!!info + For more options visit official [sqlalchemy-metadata][sqlalchemy-metadata] documentation. + +### Fields Definition + +To access ormar `Fields` you can use `Model.__model_fields__` parameter + +For example to list table model fields you can: + +```Python hl_lines="18" +--8<-- "../docs_src/models/docs005.py" +``` + +[fields]: ./fields.md +[relations]: ./relations.md +[queries]: ./queries.md +[pydantic]: https://pydantic-docs.helpmanual.io/ +[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ +[sqlalchemy-metadata]: https://docs.sqlalchemy.org/en/13/core/metadata.html +[databases]: https://github.com/encode/databases +[sqlalchemy connection string]: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls \ No newline at end of file diff --git a/docs/pydantic.md b/docs/pydantic.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/queries.md b/docs/queries.md new file mode 100644 index 0000000..474a98e --- /dev/null +++ b/docs/queries.md @@ -0,0 +1,156 @@ +# Queries + +## QuerySet + +Each Model is auto registered with a QuerySet that represents the underlaying query and it's options. + +Given the Models like this + +```Python +--8<-- "../docs_src/relations/docs001.py" +``` + +we can demonstrate available methods to fetch and save the data into the database. + +### create(**kwargs) + +Creates the model instance, saves it in a database and returns the updates model (with pk populated). +The allowed kwargs are `Model` fields names and proper value types. + +```python +malibu = await Album.objects.create(name="Malibu") +await Track.objects.create(album=malibu, title="The Bird", position=1) +``` + +The alternative is a split creation and persistence of the `Model`. +```python +malibu = Album(name="Malibu") +await malibu.save() +``` + +### load() + +By default when you query a table without prefetching related models, the ormar will still construct +your related models, but populate them only with the pk value. + +```python +track = await Track.objects.get(name='The Bird') +track.album.pk # will return malibu album pk (1) +track.album.name # will return None + +# you need to actually load the data first +await track.album.load() +track.album.name # will return 'Malibu' +``` + +### get(**kwargs) + +Get's the first row from the db meeting the criteria set by kwargs. + +If no criteria set it will return the first row in db. + +Passing a criteria is actually calling filter(**kwargs) method described below. + +```python +track = await Track.objects.get(name='The Bird') +track2 = track = await Track.objects.get() +track == track2 # True since it's the only row in db +``` + +### all() + +Returns all rows from a database for given model + +```python +tracks = await Track.objects.select_related("album").all() +# will return a list of all Tracks +``` + +### filter(**kwargs) + +Allows you to filter by any `Model` attribute/field +as well as to fetch instances, with a filter across an FK relationship. + +```python +track = Track.objects.filter(name="The Bird").get() +# will return a track with name equal to 'The Bird' + +tracks = Track.objects.filter(album__name="Fantasies").all() +# will return all tracks where the related album name = 'Fantasies' +``` + +You can use special filter suffix to change the filter operands: + +* exact - like `album__name__exact='Malibu'` (exact match) +* iexact - like `album__name__iexact='malibu'` (exact match case insensitive) +* contains - like `album__name__conatins='Mal'` (sql like) +* icontains - like `album__name__icontains='mal'` (sql like case insensitive) +* in - like `album__name__in=['Malibu', 'Barclay']` (sql in) +* gt - like `position__gt=3` (sql >) +* gte - like `position__gte=3` (sql >=) +* lt - like `position__lt=3` (sql <) +* lte - like `position__lte=3` (sql <=) + +!!!note + `filter()`, `select_related()`, `limit()` and `offset()` returns a QueySet instance so you can chain them together. + + Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` + +### select_related(*args) + +Allows to prefetch related models. + +To fetch related model use `ForeignKey` names. + +To chain related `Models` relation use double underscore. + +```python +album = await Album.objects.select_related("tracks").all() +# will return album will all related tracks +``` + +You can provide a string or a list of strings + +```python +classes = await SchoolClass.objects.select_related( +["teachers__category", "students"]).all() +# will return classes with teachers and teachers categories +# as well as classes students +``` + +!!!warning + If you set `ForeignKey` field as not nullable (so required) during + all queries the not nullable `Models` will be auto prefetched, even if you do not include them in select_related. + +!!!note + `filter()`, `select_related()`, `limit()` and `offset()` returns a QueySet instance so you can chain them together. + + Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` + +### limit(int) + +You can limit the results to desired number of rows. + +```python +tracks = await Track.objects.limit(1).all() +# will return just one Track +``` + +!!!note + `filter()`, `select_related()`, `limit()` and `offset()` returns a QueySet instance so you can chain them together. + + Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` + +### offset(int) + +You can also offset the results by desired number of rows. + +```python +tracks = await Track.objects.offset(1).limit(1).all() +# will return just one Track, but this time the second one +``` + +!!!note + `filter()`, `select_related()`, `limit()` and `offset()` returns a QueySet instance so you can chain them together. + + Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` \ No newline at end of file diff --git a/docs/relations.md b/docs/relations.md new file mode 100644 index 0000000..43a6e20 --- /dev/null +++ b/docs/relations.md @@ -0,0 +1,206 @@ +# Relations + +## Defining a relationship + +### Foreign Key + +To define a relationship you simply need to create a ForeignKey field on one `Model` and point it to another `Model`. + +```Python hl_lines="24" +--8<-- "../docs_src/relations/docs001.py" +``` + +It automatically creates an sql foreign key constraint on a underlying table as well as nested pydantic model in the definition. + + +```Python hl_lines="29 33" +--8<-- "../docs_src/relations/docs002.py" +``` + +Of course it's handled for you so you don't have to delve deep into this but you can. + +!!!tip + Note how by default the relation is optional, you can require the related `Model` by setting `nullable=False` on the `ForeignKey` field. + +### Reverse Relation + +At the same time the reverse relationship is registered automatically on parent model (target of `ForeignKey`). + +By default it's child (source) `Model` name + 's', like courses in snippet below: + +```Python hl_lines="25 31" +--8<-- "../docs_src/fields/docs001.py" +``` + +But you can overwrite this name by providing `related_name` parameter like below: + +```Python hl_lines="25 30" +--8<-- "../docs_src/fields/docs002.py" +``` + +!!!tip + Since related models are coming from Relationship Manager the reverse relation on access returns list of `wekref.proxy` to avoid circular references. + +## Relationship Manager + +!!!tip + This section is more technical so you might want to skip it if you are not interested in implementation details. + +### Need for a manager? + +Since orm uses Sqlalchemy core under the hood to prepare the queries, +the orm needs a way to uniquely identify each relationship between the tables to construct working queries. + +Imagine that you have models as following: + +```Python +--8<-- "../docs_src/relations/docs003.py" +``` + +Now imagine that you want to go from school class to student and his category and to teacher and his category. + +```Python +classes = await SchoolClass.objects.select_related( +["teachers__category", "students__category"]).all() +``` + +!!!tip + To query a chain of models use double underscores between the relation names (`ForeignKeys` or reverse `ForeignKeys`) + +!!!note + To select related models use `select_related` method from `Model` `QuerySet`. + + Note that you use relation (`ForeignKey`) names and not the table names. + +Since you join two times to the same table (categories) it won't work by default -> you would need to use aliases for category tables and columns. + +But don't worry - ormar can handle situations like this, as it uses the Relationship Manager which has it's aliases defined for all relationships. + +Each class is registered with the same instance of the RelationshipManager that you can access like this: + +```python +SchoolClass._orm_relationship_manager +``` + +It's the same object for all `Models` + +```python +print(Teacher._orm_relationship_manager == Student._orm_relationship_manager) +# will produce: True +``` + +### Table aliases + +You can even preview the alias used for any relation by passing two tables names. + +```python +print(Teacher._orm_relationship_manager.resolve_relation_join( +'students', 'categories')) +# will produce: KId1c6 (sample value) + +print(Teacher._orm_relationship_manager.resolve_relation_join( +'categories', 'students')) +# will produce: EFccd5 (sample value) +``` + +!!!note + The order that you pass the names matters -> as those are 2 different relationships depending on join order. + + As aliases are produced randomly you can be presented with different results. + +### Query automatic construction + +Ormar is using those aliases during queries to both construct a meaningful and valid sql, +as well as later use it to extract proper columns for proper nested models. + +Running a previously mentioned query to select school classes and related teachers and students: + +```Python +classes = await SchoolClass.objects.select_related( +["teachers__category", "students__category"]).all() +``` + +Will result in a query like this (run under the hood): + +```sql +SELECT schoolclasses.id, + schoolclasses.name, + schoolclasses.department, + NZc8e2_students.id as NZc8e2_id, + NZc8e2_students.name as NZc8e2_name, + NZc8e2_students.schoolclass as NZc8e2_schoolclass, + NZc8e2_students.category as NZc8e2_category, + MYfe53_categories.id as MYfe53_id, + MYfe53_categories.name as MYfe53_name, + WA49a3_teachers.id as WA49a3_id, + WA49a3_teachers.name as WA49a3_name, + WA49a3_teachers.schoolclass as WA49a3_schoolclass, + WA49a3_teachers.category as WA49a3_category, + WZa13b_categories.id as WZa13b_id, + WZa13b_categories.name as WZa13b_name +FROM schoolclasses + LEFT OUTER JOIN students NZc8e2_students ON NZc8e2_students.schoolclass = schoolclasses.id + LEFT OUTER JOIN categories MYfe53_categories ON MYfe53_categories.id = NZc8e2_students.category + LEFT OUTER JOIN teachers WA49a3_teachers ON WA49a3_teachers.schoolclass = schoolclasses.id + LEFT OUTER JOIN categories WZa13b_categories ON WZa13b_categories.id = WA49a3_teachers.category +ORDER BY schoolclasses.id, NZc8e2_students.id, MYfe53_categories.id, WA49a3_teachers.id, WZa13b_categories.id +``` + +!!!note + As mentioned before the aliases are produced dynamically so the actual result might differ. + + Note that aliases are assigned to relations and not the tables, therefore the first table is always without an alias. + +### Returning related Models + +Each object in Relationship Manager is identified by orm_id which you can preview like this + +```python +category = Category(name='Math') +print(category._orm_id) +# will produce: c76046d9410c4582a656bf12a44c892c (sample value) +``` + +Each call to related `Model` is actually coming through the Manager which stores all +the relations in a dictionary and returns related `Models` by relation type (name) and by object _orm_id. + +Since we register both sides of the relation the side registering the relation +is always registering the other side as concrete model, +while the reverse relation is a weakref.proxy to avoid circular references. + +Sounds complicated but in reality it means something like this: + +```python +test_class = await SchoolClass.objects.create(name='Test') +student = await Student.objects.create(name='John', schoolclass=test_class) +# the relation to schoolsclass from student (i.e. when you call student.schoolclass) +# is a concrete one, meaning directy relating the schoolclass `Model` object +# On the other side calling test_class.students will result in a list of wekref.proxy objects +``` + +!!!tip + To learn more about queries and available methods please review [queries][queries] section. + +All relations are kept in lists, meaning that when you access related object the Relationship Manager is +searching itself for related models and get a list of them. + +But since child to parent relation is a many to one type, +the Manager is unpacking the first (and only) related model from a list and you get an actual `Model` instance instead of a list. + +Coming from parent to child relation (one to many) you always get a list of results. + +Translating this into concrete sample, the same as above: + +```python +test_class = await SchoolClass.objects.create(name='Test') +student = await Student.objects.create(name='John', schoolclass=test_class) + +student.schoolclass # return a test_class instance extracted from relationship list +test_class.students # return a list of related wekref.proxy refering related students `Models` + +``` + +!!!tip + You can preview all relations currently registered by accessing Relationship Manager on any class/instance `Student._orm_relationship_manager._relations` + +[queries]: ./queries.md \ No newline at end of file diff --git a/docs_src/fields/docs001.py b/docs_src/fields/docs001.py new file mode 100644 index 0000000..047690d --- /dev/null +++ b/docs_src/fields/docs001.py @@ -0,0 +1,36 @@ +import databases +import sqlalchemy + +import ormar + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Department(ormar.Model): + __database__ = database + __metadata__ = metadata + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + + +class Course(ormar.Model): + __database__ = database + __metadata__ = metadata + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) + department = ormar.ForeignKey(Department) + + +department = Department(name='Science') +course = Course(name='Math', completed=False, department=department) + +print(department.courses[0]) +# Will produce: +# Course(id=None, +# name='Math', +# completed=False, +# department=Department(id=None, name='Science')) diff --git a/docs_src/fields/docs002.py b/docs_src/fields/docs002.py new file mode 100644 index 0000000..7fa6ccd --- /dev/null +++ b/docs_src/fields/docs002.py @@ -0,0 +1,36 @@ +import databases +import sqlalchemy + +import ormar + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Department(ormar.Model): + __database__ = database + __metadata__ = metadata + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + + +class Course(ormar.Model): + __database__ = database + __metadata__ = metadata + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) + department = ormar.ForeignKey(Department, related_name="my_courses") + +department = Department(name='Science') +course = Course(name='Math', completed=False, department=department) + +print(department.my_courses[0]) +# Will produce: +# Course(id=None, +# name='Math', +# completed=False, +# department=Department(id=None, name='Science')) + diff --git a/docs_src/fields/docs003.py b/docs_src/fields/docs003.py new file mode 100644 index 0000000..32ff68c --- /dev/null +++ b/docs_src/fields/docs003.py @@ -0,0 +1,41 @@ +import ormar +import databases +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Album(ormar.Model): + __tablename__ = "album" + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + + +class Track(ormar.Model): + __tablename__ = "track" + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + album = ormar.ForeignKey(Album, nullable=False) + title = ormar.String(length=100) + position = ormar.Integer() + + +album = await Album.objects.create(name="Brooklyn") +await Track.objects.create(album=album, title="The Bird", position=1) + +# explicit preload of related Album Model +track = await Track.objects.select_related("album").get(title="The Bird") +assert track.album.name == 'Brooklyn' +# Will produce: True + +# even without explicit select_related if ForeignKey is not nullable, +# the Album Model is still preloaded. +track2 = await Track.objects.get(title="The Bird") +assert track2.album.name == 'Brooklyn' +# Will produce: True diff --git a/docs_src/models/docs001.py b/docs_src/models/docs001.py new file mode 100644 index 0000000..9c8f8f1 --- /dev/null +++ b/docs_src/models/docs001.py @@ -0,0 +1,16 @@ +import databases +import sqlalchemy + +import ormar + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Course(ormar.Model): + __database__ = database + __metadata__ = metadata + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) diff --git a/docs_src/models/docs002.py b/docs_src/models/docs002.py new file mode 100644 index 0000000..1d63371 --- /dev/null +++ b/docs_src/models/docs002.py @@ -0,0 +1,19 @@ +import databases +import sqlalchemy + +import ormar + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Course(ormar.Model): + # if you omit this parameter it will be created automatically + # as class.__name__.lower()+'s' -> "courses" in this example + __tablename__ = "my_courses" + __database__ = database + __metadata__ = metadata + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) diff --git a/docs_src/models/docs003.py b/docs_src/models/docs003.py new file mode 100644 index 0000000..754f6d4 --- /dev/null +++ b/docs_src/models/docs003.py @@ -0,0 +1,33 @@ +import databases +import sqlalchemy + +import ormar + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Course(ormar.Model): + __database__ = database + __metadata__ = metadata + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) + +print(Course.__pydantic_model__.__fields__) +""" +Will produce: +{'completed': ModelField(name='completed', + type=bool, + required=False, + default=False), + 'id': ModelField(name='id', + type=Optional[int], + required=False, + default=None), + 'name': ModelField(name='name', + type=Optional[str], + required=False, + default=None)} +""" diff --git a/docs_src/models/docs004.py b/docs_src/models/docs004.py new file mode 100644 index 0000000..36e5da0 --- /dev/null +++ b/docs_src/models/docs004.py @@ -0,0 +1,22 @@ +import databases +import sqlalchemy + +import ormar + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Course(ormar.Model): + __database__ = database + __metadata__ = metadata + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) + +print(Course.__table__.columns) +""" +Will produce: +['courses.id', 'courses.name', 'courses.completed'] +""" diff --git a/docs_src/models/docs005.py b/docs_src/models/docs005.py new file mode 100644 index 0000000..e1a85e8 --- /dev/null +++ b/docs_src/models/docs005.py @@ -0,0 +1,51 @@ +import databases +import sqlalchemy + +import ormar + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Course(ormar.Model): + __database__ = database + __metadata__ = metadata + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) + +print(Course.__model_fields__) +""" +Will produce: +{ +'id': {'name': 'id', + 'primary_key': True, + 'autoincrement': True, + 'nullable': False, + 'default': None, + 'server_default': None, + 'index': None, + 'unique': None, + 'pydantic_only': False}, +'name': {'name': 'name', + 'primary_key': False, + 'autoincrement': False, + 'nullable': True, + 'default': None, + 'server_default': None, + 'index': None, + 'unique': None, + 'pydantic_only': False, + 'length': 100}, +'completed': {'name': 'completed', + 'primary_key': False, + 'autoincrement': False, + 'nullable': True, + 'default': False, + 'server_default': None, + 'index': None, + 'unique': None, + 'pydantic_only': False} +} +""" diff --git a/docs_src/models/docs006.py b/docs_src/models/docs006.py new file mode 100644 index 0000000..b232979 --- /dev/null +++ b/docs_src/models/docs006.py @@ -0,0 +1,41 @@ +import databases +import sqlalchemy + +import ormar + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Department(ormar.Model): + __database__ = database + __metadata__ = metadata + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + + +class Course(ormar.Model): + __database__ = database + __metadata__ = metadata + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) + department = ormar.ForeignKey(Department) + + +department = Department(name="Science") +course = Course(name="Math", completed=False, department=department) + +print('name' in course.__dict__) +# False <- property name is not stored on Course instance +print(course.name) +# Math <- value returned from underlying pydantic model +print('department' in course.__dict__) +# False <- related model is not stored on Course instance +print(course.department) +# Department(id=None, name='Science') <- Department model +# returned from RelationshipManager +print(course.department.name) +# Science \ No newline at end of file diff --git a/docs_src/models/docs007.py b/docs_src/models/docs007.py new file mode 100644 index 0000000..f98e62a --- /dev/null +++ b/docs_src/models/docs007.py @@ -0,0 +1,22 @@ +import databases +import sqlalchemy + +import ormar + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Course(ormar.Model): + __database__ = database + __metadata__ = metadata + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) + + +course = Course(name="Painting for dummies", completed=False) +await course.save() + +await Course.objects.create(name="Painting for dummies", completed=False) diff --git a/docs_src/relations/docs001.py b/docs_src/relations/docs001.py new file mode 100644 index 0000000..53e2f5c --- /dev/null +++ b/docs_src/relations/docs001.py @@ -0,0 +1,26 @@ +import ormar +import databases +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Album(ormar.Model): + __tablename__ = "album" + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + + +class Track(ormar.Model): + __tablename__ = "track" + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + album = ormar.ForeignKey(Album) + title = ormar.String(length=100) + position = ormar.Integer() \ No newline at end of file diff --git a/docs_src/relations/docs002.py b/docs_src/relations/docs002.py new file mode 100644 index 0000000..ef67093 --- /dev/null +++ b/docs_src/relations/docs002.py @@ -0,0 +1,39 @@ +import ormar +import databases +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Album(ormar.Model): + __tablename__ = "album" + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + + +class Track(ormar.Model): + __tablename__ = "track" + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + album = ormar.ForeignKey(Album) + title = ormar.String(length=100) + position = ormar.Integer() + + +print(Track.__table__.columns['album'].__repr__()) +# Will produce: +# Column('album', Integer(), ForeignKey('album.id'), table=) + +print(Track.__pydantic_model__.__fields__['album']) +# Will produce: +# ModelField( +# name='album' +# type=Optional[Album] +# required=False +# default=None) diff --git a/docs_src/relations/docs003.py b/docs_src/relations/docs003.py new file mode 100644 index 0000000..e319e42 --- /dev/null +++ b/docs_src/relations/docs003.py @@ -0,0 +1,44 @@ +import databases +import sqlalchemy +import ormar + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class SchoolClass(ormar.Model): + __tablename__ = "schoolclasses" + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + + +class Category(ormar.Model): + __tablename__ = "categories" + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + + +class Student(ormar.Model): + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + schoolclass = ormar.ForeignKey(SchoolClass) + category = ormar.ForeignKey(Category) + + +class Teacher(ormar.Model): + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + schoolclass = ormar.ForeignKey(SchoolClass) + category = ormar.ForeignKey(Category) diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..d63f952 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,29 @@ +site_name: Async ORM +nav: + - Home: index.md + - Models: models.md + - Fields: fields.md + - Relations: relations.md + - Queries: queries.md + - Pydantic models: pydantic.md + - Use with Fastapi: fastapi.md +theme: + name: material + highlightjs: true + hljs_languages: + - python + palette: + primary: indigo +markdown_extensions: + - admonition + - pymdownx.superfences + - pymdownx.snippets: + base_path: docs + - pymdownx.inlinehilite + - pymdownx.highlight: + linenums: true +extra_javascript: + - https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/highlight.min.js + - javascripts/config.js +extra_css: + - https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/styles/default.min.css \ No newline at end of file diff --git a/ormar/__init__.py b/ormar/__init__.py new file mode 100644 index 0000000..098adf2 --- /dev/null +++ b/ormar/__init__.py @@ -0,0 +1,37 @@ +from ormar.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch +from ormar.fields import ( + BigInteger, + Boolean, + Date, + DateTime, + Decimal, + Float, + ForeignKey, + Integer, + JSON, + String, + Text, + Time, +) +from ormar.models import Model + +__version__ = "0.1.3" +__all__ = [ + "Integer", + "BigInteger", + "Boolean", + "Time", + "Text", + "String", + "JSON", + "DateTime", + "Date", + "Decimal", + "Float", + "Model", + "ModelDefinitionError", + "ModelNotSet", + "MultipleMatches", + "NoMatch", + "ForeignKey", +] diff --git a/ormar/exceptions.py b/ormar/exceptions.py new file mode 100644 index 0000000..40cfd26 --- /dev/null +++ b/ormar/exceptions.py @@ -0,0 +1,26 @@ +class AsyncOrmException(Exception): + pass + + +class ModelDefinitionError(AsyncOrmException): + pass + + +class ModelNotSet(AsyncOrmException): + pass + + +class NoMatch(AsyncOrmException): + pass + + +class MultipleMatches(AsyncOrmException): + pass + + +class QueryDefinitionError(AsyncOrmException): + pass + + +class RelationshipInstanceError(AsyncOrmException): + pass diff --git a/ormar/fields/__init__.py b/ormar/fields/__init__.py new file mode 100644 index 0000000..f6c4dc9 --- /dev/null +++ b/ormar/fields/__init__.py @@ -0,0 +1,31 @@ +from ormar.fields.base import BaseField +from ormar.fields.foreign_key import ForeignKey +from ormar.fields.model_fields import ( + BigInteger, + Boolean, + Date, + DateTime, + Decimal, + Float, + Integer, + JSON, + String, + Text, + Time, +) + +__all__ = [ + "Decimal", + "BigInteger", + "Boolean", + "Date", + "DateTime", + "String", + "JSON", + "Integer", + "Text", + "Float", + "Time", + "ForeignKey", + "BaseField", +] diff --git a/ormar/fields/base.py b/ormar/fields/base.py new file mode 100644 index 0000000..1f55d47 --- /dev/null +++ b/ormar/fields/base.py @@ -0,0 +1,80 @@ +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +import pydantic +import sqlalchemy +from pydantic import Field + +from ormar import ModelDefinitionError # noqa I101 + +if TYPE_CHECKING: # pragma no cover + from ormar.models import Model + + +def prepare_validator(type_): + def validate_model_field(value): + return isinstance(value, type_) + + return validate_model_field + + +class BaseField: + __type__ = None + + column_type: sqlalchemy.Column + constraints: List = [] + + primary_key: bool + autoincrement: bool + nullable: bool + index: bool + unique: bool + pydantic_only: bool + + default: Any + server_default: Any + + @classmethod + def is_required(cls) -> bool: + return ( + not cls.nullable and not cls.has_default() and not cls.is_auto_primary_key() + ) + + @classmethod + def default_value(cls): + if cls.is_auto_primary_key(): + return Field(default=None) + if cls.has_default(): + default = cls.default if cls.default is not None else cls.server_default + if callable(default): + return Field(default_factory=default) + else: + return Field(default=default) + return None + + @classmethod + def has_default(cls): + return cls.default is not None or cls.server_default is not None + + @classmethod + def is_auto_primary_key(cls) -> bool: + if cls.primary_key: + return cls.autoincrement + return False + + @classmethod + def get_column(cls, name: str) -> sqlalchemy.Column: + return sqlalchemy.Column( + name, + cls.column_type, + *cls.constraints, + primary_key=cls.primary_key, + nullable=cls.nullable and not cls.primary_key, + index=cls.index, + unique=cls.unique, + default=cls.default, + server_default=cls.server_default, + ) + + @classmethod + def expand_relationship(cls, value: Any, child: "Model") -> Any: + return value diff --git a/ormar/fields/decorators.py b/ormar/fields/decorators.py new file mode 100644 index 0000000..842e864 --- /dev/null +++ b/ormar/fields/decorators.py @@ -0,0 +1,27 @@ +from typing import Any, TYPE_CHECKING, Type + +from ormar import ModelDefinitionError + +if TYPE_CHECKING: # pragma no cover + from ormar.fields import BaseField + + +class RequiredParams: + def __init__(self, *args: str) -> None: + self._required = list(args) + + def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]: + old_init = model_field_class.__init__ + model_field_class._old_init = old_init + + def __init__(instance: "BaseField", **kwargs: Any) -> None: + super(instance.__class__, instance).__init__(**kwargs) + for arg in self._required: + if arg not in kwargs: + raise ModelDefinitionError( + f"{instance.__class__.__name__} field requires parameter: {arg}" + ) + setattr(instance, arg, kwargs.pop(arg)) + + model_field_class.__init__ = __init__ + return model_field_class diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py new file mode 100644 index 0000000..b457616 --- /dev/null +++ b/ormar/fields/foreign_key.py @@ -0,0 +1,127 @@ +from typing import Any, List, Optional, TYPE_CHECKING, Type, Union, Callable + +import sqlalchemy +from pydantic import BaseModel + +import ormar # noqa I101 +from ormar.exceptions import RelationshipInstanceError +from ormar.fields.base import BaseField + +if TYPE_CHECKING: # pragma no cover + from ormar.models import Model + + +def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": + init_dict = { + **{fk.Meta.pkname: pk or -1, + '__pk_only__': True}, + **{ + k: create_dummy_instance(v.to) + for k, v in fk.Meta.model_fields.items() + if isinstance(v, ForeignKeyField) and not v.nullable and not v.virtual + }, + } + return fk(**init_dict) + + +def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = True, + related_name: str = None, + virtual: bool = False, + ) -> Type[object]: + fk_string = to.Meta.tablename + "." + to.Meta.pkname + to_field = to.__fields__[to.Meta.pkname] + namespace = dict( + to=to, + name=name, + nullable=nullable, + constraints=[sqlalchemy.schema.ForeignKey(fk_string)], + unique=unique, + column_type=to_field.type_.column_type, + related_name=related_name, + virtual=virtual, + primary_key=False, + index=False, + pydantic_only=False, + default=None, + server_default=None + ) + + return type("ForeignKey", (ForeignKeyField, BaseField), namespace) + + +class ForeignKeyField(BaseField): + to: Type["Model"] + related_name: str + virtual: bool + + @classmethod + def __get_validators__(cls) -> Callable: + yield cls.validate + + @classmethod + def validate(cls, v: Any) -> Any: + return v + + @property + def __type__(self) -> Type[BaseModel]: + return self.to.__pydantic_model__ + + @classmethod + def get_column_type(cls) -> sqlalchemy.Column: + to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname] + return to_column.column_type + + @classmethod + def _extract_model_from_sequence( + cls, value: List, child: "Model" + ) -> Union["Model", List["Model"]]: + return [cls.expand_relationship(val, child) for val in value] + + @classmethod + def _register_existing_model(cls, value: "Model", child: "Model") -> "Model": + cls.register_relation(value, child) + return value + + @classmethod + def _construct_model_from_dict(cls, value: dict, child: "Model") -> "Model": + model = cls.to(**value) + cls.register_relation(model, child) + return model + + @classmethod + def _construct_model_from_pk(cls, value: Any, child: "Model") -> "Model": + if not isinstance(value, cls.to.pk_type()): + raise RelationshipInstanceError( + f"Relationship error - ForeignKey {cls.to.__name__} " + f"is of type {cls.to.pk_type()} " + f"while {type(value)} passed as a parameter." + ) + model = create_dummy_instance(fk=cls.to, pk=value) + cls.register_relation(model, child) + return model + + @classmethod + def register_relation(cls, model: "Model", child: "Model") -> None: + child_model_name = cls.related_name or child.get_name() + model.Meta._orm_relationship_manager.add_relation( + model, child, child_model_name, virtual=cls.virtual + ) + + @classmethod + def expand_relationship( + cls, value: Any, child: "Model" + ) -> Optional[Union["Model", List["Model"]]]: + print("expandong relatiknship", value, child) + if value is None: + return None + + constructors = { + f"{cls.to.__name__}": cls._register_existing_model, + "dict": cls._construct_model_from_dict, + "list": cls._extract_model_from_sequence, + } + + model = constructors.get( + value.__class__.__name__, cls._construct_model_from_pk + )(value, child) + return model diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py new file mode 100644 index 0000000..d790d65 --- /dev/null +++ b/ormar/fields/model_fields.py @@ -0,0 +1,373 @@ +import datetime +import decimal +import re +from typing import Type, Any, Optional + +import pydantic +import sqlalchemy +from pydantic import Json + +from ormar import ModelDefinitionError +from ormar.fields.base import BaseField # noqa I101 + + +def is_field_nullable(nullable: Optional[bool], default: Any, server_default: Any) -> bool: + if nullable is None: + return default is not None or server_default is not None + return nullable + + +def String( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + allow_blank: bool = False, + strip_whitespace: bool = False, + min_length: int = None, + max_length: int = None, + curtail_length: int = None, + regex: str = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[str]: + if max_length is None or max_length <= 0: + raise ModelDefinitionError(f'Parameter max_length is required for field String') + + namespace = dict( + __type__=str, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + allow_blank=allow_blank, + strip_whitespace=strip_whitespace, + min_length=min_length, + max_length=max_length, + curtail_length=curtail_length, + regex=regex and re.compile(regex), + column_type=sqlalchemy.String(length=max_length), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + + return type("String", (pydantic.ConstrainedStr, BaseField), namespace) + + +def Integer( + *, + name: str = None, + primary_key: bool = False, + autoincrement: bool = None, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[int]: + namespace = dict( + __type__=int, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.Integer(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=autoincrement if autoincrement is not None else primary_key + ) + return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace) + + +def Text( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + allow_blank: bool = False, + strip_whitespace: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[str]: + namespace = dict( + __type__=str, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + allow_blank=allow_blank, + strip_whitespace=strip_whitespace, + column_type=sqlalchemy.Text(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + + return type("Text", (pydantic.ConstrainedStr, BaseField), namespace) + + +def Float( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[int]: + namespace = dict( + __type__=float, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.Float(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace) + + +def Boolean( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[bool]: + namespace = dict( + __type__=bool, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.Boolean(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Boolean", (int, BaseField), namespace) + + +def DateTime( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[datetime.datetime]: + namespace = dict( + __type__=datetime.datetime, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.DateTime(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("DateTime", (datetime.datetime, BaseField), namespace) + + +def Date( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[datetime.date]: + namespace = dict( + __type__=datetime.date, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.Date(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Date", (datetime.date, BaseField), namespace) + + +def Time( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[datetime.time]: + namespace = dict( + __type__=datetime.time, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.Time(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Time", (datetime.time, BaseField), namespace) + + +def JSON( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[Json]: + namespace = dict( + __type__=pydantic.Json, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.JSON(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + + return type("JSON", (pydantic.Json, BaseField), namespace) + + +def BigInteger( + *, + name: str = None, + primary_key: bool = False, + autoincrement: bool = None, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[int]: + namespace = dict( + __type__=int, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.BigInteger(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=autoincrement if autoincrement is not None else primary_key + ) + return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace) + + +def Decimal( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + precision: int = None, + scale: int = None, + max_digits: int = None, + decimal_places: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +): + if precision is None or precision < 0 or scale is None or scale < 0: + raise ModelDefinitionError(f'Parameters scale and precision are required for field Decimal') + + namespace = dict( + __type__=decimal.Decimal, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.types.DECIMAL(precision=precision, scale=scale), + precision=precision, + scale=scale, + max_digits=max_digits, + decimal_places=decimal_places, + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace) diff --git a/ormar/models/__init__.py b/ormar/models/__init__.py new file mode 100644 index 0000000..b70c515 --- /dev/null +++ b/ormar/models/__init__.py @@ -0,0 +1,4 @@ +from ormar.models.fakepydantic import FakePydantic +from ormar.models.model import Model + +__all__ = ["FakePydantic", "Model"] diff --git a/ormar/models/fakepydantic.py b/ormar/models/fakepydantic.py new file mode 100644 index 0000000..44dd97f --- /dev/null +++ b/ormar/models/fakepydantic.py @@ -0,0 +1,290 @@ +import inspect +import json +import uuid +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Set, + TYPE_CHECKING, + Type, + TypeVar, + Union, AbstractSet, Mapping, +) + +import databases +import pydantic +import sqlalchemy +from pydantic import BaseModel + +import ormar # noqa I100 +from ormar import ForeignKey +from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField +from ormar.models.metaclass import ModelMetaclass, ModelMeta +from ormar.relations import RelationshipManager + +if TYPE_CHECKING: # pragma no cover + from ormar.models.model import Model + + IntStr = Union[int, str] + DictStrAny = Dict[str, Any] + AbstractSetIntStr = AbstractSet[IntStr] + MappingIntStrAny = Mapping[IntStr, Any] + + +class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): + # FakePydantic inherits from list in order to be treated as + # request.Body parameter in fastapi routes, + # inheriting from pydantic.BaseModel causes metaclass conflicts + __slots__ = ('_orm_id', '_orm_saved') + + if TYPE_CHECKING: # pragma no cover + __model_fields__: Dict[str, TypeVar[BaseField]] + __table__: sqlalchemy.Table + __fields__: Dict[str, pydantic.fields.ModelField] + __pydantic_model__: Type[BaseModel] + __pkname__: str + __tablename__: str + __metadata__: sqlalchemy.MetaData + __database__: databases.Database + _orm_relationship_manager: RelationshipManager + Meta: ModelMeta + + # noinspection PyMissingConstructor + def __init__(self, *args: Any, **kwargs: Any) -> None: + + object.__setattr__(self, "_orm_id", uuid.uuid4().hex) + object.__setattr__(self, "_orm_saved", False) + + pk_only = kwargs.pop("__pk_only__", False) + if "pk" in kwargs: + kwargs[self.Meta.pkname] = kwargs.pop("pk") + kwargs = { + k: self.Meta.model_fields[k].expand_relationship(v, self) + for k, v in kwargs.items() + } + + values, fields_set, validation_error = pydantic.validate_model( + self, kwargs + ) + if validation_error and not pk_only: + raise validation_error + + object.__setattr__(self, '__dict__', values) + object.__setattr__(self, '__fields_set__', fields_set) + + # super().__init__(**kwargs) + # self.values = self.__pydantic_model__(**kwargs) + + def __del__(self) -> None: + self.Meta._orm_relationship_manager.deregister(self) + + def __setattr__(self, name, value): + relation_key = self.get_name(title=True) + "_" + name + if name in self.__slots__: + object.__setattr__(self, name, value) + elif name == 'pk': + object.__setattr__(self, self.Meta.pkname, value) + elif self.Meta._orm_relationship_manager.contains(relation_key, self): + self.Meta.model_fields[name].expand_relationship(value, self) + else: + super().__setattr__(name, value) + + def __getattr__(self, item): + relation_key = self.get_name(title=True) + "_" + item + if self.Meta._orm_relationship_manager.contains(relation_key, self): + return self.Meta._orm_relationship_manager.get(relation_key, self) + + # def __setattr__(self, key: str, value: Any) -> None: + # if key in ('_orm_id', '_orm_relationship_manager', '_orm_saved', 'objects', '__model_fields__'): + # return setattr(self, key, value) + # # elif key in self._extract_related_names(): + # # value = self._convert_json(key, value, op="dumps") + # # value = self.Meta.model_fields[key].expand_relationship(value, self) + # # relation_key = self.get_name(title=True) + "_" + key + # # if not self.Meta._orm_relationship_manager.contains(relation_key, self): + # # setattr(self.values, key, value) + # else: + # super().__setattr__(key, value) + + # def __getattribute__(self, key: str) -> Any: + # if key != 'Meta' and key in self.Meta.model_fields: + # relation_key = self.get_name(title=True) + "_" + key + # if self.Meta._orm_relationship_manager.contains(relation_key, self): + # return self.Meta._orm_relationship_manager.get(relation_key, self) + # item = getattr(self.__fields__, key, None) + # item = self._convert_json(key, item, op="loads") + # return item + # return super().__getattribute__(key) + + def __same__(self, other: "Model") -> bool: + if self.__class__ != other.__class__: # pragma no cover + return False + return (self._orm_id == other._orm_id or + self.__dict__ == other.__dict__ or + (self.pk == other.pk and self.pk is not None + )) + + # def __repr__(self) -> str: # pragma no cover + # return self.values.__repr__() + + # @classmethod + # def __get_validators__(cls) -> Callable: # pragma no cover + # yield cls.__pydantic_model__.validate + + @classmethod + def get_name(cls, title: bool = False, lower: bool = True) -> str: + name = cls.__name__ + if lower: + name = name.lower() + if title: + name = name.title() + return name + + @property + def pk(self) -> Any: + return getattr(self, self.Meta.pkname) + + @pk.setter + def pk(self, value: Any) -> None: + setattr(self, self.Meta.pkname, value) + + @property + def pk_column(self) -> sqlalchemy.Column: + return self.Meta.table.primary_key.columns.values()[0] + + @classmethod + def pk_type(cls) -> Any: + return cls.Meta.model_fields[cls.Meta.pkname].__type__ + + def dict( + self, + *, + include: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, + exclude: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + nested: bool = False + ) -> 'DictStrAny': # noqa: A003 + print('callin super', self.__class__) + print('to exclude', self._exclude_related_names_not_required(nested)) + dict_instance = super().dict(include=include, + exclude=self._exclude_related_names_not_required(nested), + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none) + print('after super') + for field in self._extract_related_names(): + print(self.__class__, field, nested) + nested_model = getattr(self, field) + + if self.Meta.model_fields[field].virtual and nested: + continue + if isinstance(nested_model, list) and not isinstance( + nested_model, ormar.Model + ): + print('nested list') + dict_instance[field] = [x.dict(nested=True) for x in nested_model] + else: + print('instance') + if nested_model is not None: + dict_instance[field] = nested_model.dict(nested=True) + + return dict_instance + + def from_dict(self, value_dict: Dict) -> None: + for key, value in value_dict.items(): + setattr(self, key, value) + + def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]: + + if not self._is_conversion_to_json_needed(column_name): + return value + + condition = ( + isinstance(value, str) if op == "loads" else not isinstance(value, str) + ) + operand = json.loads if op == "loads" else json.dumps + + if condition: + try: + return operand(value) + except TypeError: # pragma no cover + pass + return value + + def _is_conversion_to_json_needed(self, column_name: str) -> bool: + return self.Meta.model_fields.get(column_name).__type__ == pydantic.Json + + def _extract_own_model_fields(self) -> Dict: + related_names = self._extract_related_names() + self_fields = {k: v for k, v in self.dict().items() if k not in related_names} + return self_fields + + @classmethod + def _extract_related_names(cls) -> Set: + related_names = set() + for name, field in cls.Meta.model_fields.items(): + if inspect.isclass(field) and issubclass( + field, ForeignKeyField + ): + related_names.add(name) + return related_names + + @classmethod + def _exclude_related_names_not_required(cls, nested:bool=False) -> Set: + if nested: + return cls._extract_related_names() + related_names = set() + for name, field in cls.Meta.model_fields.items(): + if inspect.isclass(field) and issubclass(field, ForeignKeyField) and field.nullable: + related_names.add(name) + return related_names + + def _extract_model_db_fields(self) -> Dict: + self_fields = self._extract_own_model_fields() + self_fields = { + k: v for k, v in self_fields.items() if k in self.Meta.table.columns + } + for field in self._extract_related_names(): + if getattr(self, field) is not None: + self_fields[field] = getattr( + getattr(self, field), self.Meta.model_fields[field].to.Meta.pkname + ) + return self_fields + + @classmethod + def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: + merged_rows = [] + for index, model in enumerate(result_rows): + if index > 0 and model.pk == result_rows[index - 1].pk: + result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) + else: + merged_rows.append(model) + return merged_rows + + @classmethod + def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": + for field in one.Meta.model_fields.keys(): + if isinstance(getattr(one, field), list) and not isinstance( + getattr(one, field), ormar.Model + ): + setattr(other, field, getattr(one, field) + getattr(other, field)) + elif isinstance(getattr(one, field), ormar.Model): + if getattr(one, field).pk == getattr(other, field).pk: + setattr( + other, + field, + cls.merge_two_instances( + getattr(one, field), getattr(other, field) + ), + ) + return other diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py new file mode 100644 index 0000000..1d4044b --- /dev/null +++ b/ormar/models/metaclass.py @@ -0,0 +1,196 @@ +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union + +import databases +import pydantic +import sqlalchemy +from pydantic import BaseConfig, create_model, Extra +from pydantic.fields import ModelField, FieldInfo + +from ormar import ForeignKey, ModelDefinitionError # noqa I100 +from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField +from ormar.queryset import QuerySet +from ormar.relations import RelationshipManager + +if TYPE_CHECKING: # pragma no cover + from ormar import Model + +relationship_manager = RelationshipManager() + + +class ModelMeta: + tablename: str + table: sqlalchemy.Table + metadata: sqlalchemy.MetaData + database: databases.Database + columns: List[sqlalchemy.Column] + pkname: str + model_fields: Dict[str, Union[BaseField, ForeignKey]] + _orm_relationship_manager: RelationshipManager + + +def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: + pydantic_fields = { + field_name: ( + base_field.__type__, + ... if base_field.is_required else base_field.default_value, + ) + for field_name, base_field in object_dict.items() + if isinstance(base_field, BaseField) + } + return pydantic_fields + + +def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: + child_relation_name = ( + field.to.get_name(title=True) + + "_" + + (field.related_name or (name.lower() + "s")) + ) + reverse_name = child_relation_name + relation_name = name.lower().title() + "_" + field.to.get_name() + relationship_manager.add_relation_type( + relation_name, reverse_name, field, table_name + ) + + +def expand_reverse_relationships(model: Type["Model"]) -> None: + for model_field in model.Meta.model_fields.values(): + if issubclass(model_field, ForeignKeyField): + child_model_name = model_field.related_name or model.get_name() + "s" + parent_model = model_field.to + child = model + if ( + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ + ): + register_reverse_model_fields(parent_model, child, child_model_name) + + +def register_reverse_model_fields( + model: Type["Model"], child: Type["Model"], child_model_name: str +) -> None: + # model.__fields__[child_model_name] = ModelField( + # name=child_model_name, + # type_=Optional[Union[List[child], child]], + # model_config=child.__config__, + # class_validators=child.__validators__, + # ) + model.Meta.model_fields[child_model_name] = ForeignKey( + child, name=child_model_name, virtual=True + ) + + +def sqlalchemy_columns_from_model_fields( + name: str, object_dict: Dict, table_name: str +) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: + columns = [] + pkname = None + model_fields = { + field_name: field + for field_name, field in object_dict['__annotations__'].items() + if issubclass(field, BaseField) + } + for field_name, field in model_fields.items(): + if field.primary_key: + if pkname is not None: + raise ModelDefinitionError("Only one primary key column is allowed.") + if field.pydantic_only: + raise ModelDefinitionError('Primary key column cannot be pydantic only') + pkname = field_name + if not field.pydantic_only: + columns.append(field.get_column(field_name)) + if issubclass(field, ForeignKeyField): + register_relation_on_build(table_name, field, name) + + return pkname, columns, model_fields + + +def populate_pydantic_default_values(attrs: Dict) -> Dict: + for field, type_ in attrs["__annotations__"].items(): + if issubclass(type_, BaseField): + if type_.name is None: + type_.name = field + def_value = type_.default_value() + curr_def_value = attrs.get(field, 'NONE') + print(field, curr_def_value, 'def val', type_.nullable) + if curr_def_value == 'NONE' and isinstance(def_value, FieldInfo): + attrs[field] = def_value + elif curr_def_value == 'NONE' and type_.nullable: + print(field, 'defsults tp none') + attrs[field] = FieldInfo(default=None) + return attrs + + +def get_pydantic_base_orm_config() -> Type[BaseConfig]: + class Config(BaseConfig): + orm_mode = True + arbitrary_types_allowed = True + # extra = Extra.allow + + return Config + + +class ModelMetaclass(pydantic.main.ModelMetaclass): + def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: + + attrs['Config'] = get_pydantic_base_orm_config() + new_model = super().__new__( # type: ignore + mcs, name, bases, attrs + ) + + if hasattr(new_model, 'Meta'): + + if attrs.get("__abstract__"): + return new_model + + annotations = attrs.get("__annotations__") or new_model.__annotations__ + attrs["__annotations__"]= annotations + attrs = populate_pydantic_default_values(attrs) + + print(attrs) + + tablename = name.lower() + "s" + new_model.Meta.tablename = new_model.Meta.tablename or tablename + + # sqlalchemy table creation + + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( + name, attrs, new_model.Meta.tablename + ) + + if hasattr(new_model.Meta, "model_fields") and not pkname: + model_fields = new_model.Meta.model_fields + for fieldname, field in new_model.Meta.model_fields.items(): + if field.primary_key: + pkname=fieldname + columns = new_model.Meta.table.columns + + if not hasattr(new_model.Meta, "table"): + new_model.Meta.table = sqlalchemy.Table(new_model.Meta.tablename, new_model.Meta.metadata, *columns) + + new_model.Meta.columns = columns + new_model.Meta.pkname = pkname + + if not pkname: + raise ModelDefinitionError("Table has to have a primary key.") + + # pydantic model creation + new_model.Meta.pydantic_fields = parse_pydantic_field_from_model_fields(attrs) + new_model.Meta.pydantic_model = create_model( + name, __config__=get_pydantic_base_orm_config(), **new_model.Meta.pydantic_fields + ) + + new_model.Meta.model_fields = model_fields + print(attrs, 'before super') + print(new_model.Meta.__dict__) + new_model = super().__new__( # type: ignore + mcs, name, bases, attrs + ) + expand_reverse_relationships(new_model) + + new_model.Meta._orm_relationship_manager = relationship_manager + new_model.objects = QuerySet(new_model) + + # breakpoint() + return new_model diff --git a/ormar/models/model.py b/ormar/models/model.py new file mode 100644 index 0000000..4651aa1 --- /dev/null +++ b/ormar/models/model.py @@ -0,0 +1,85 @@ +from typing import Any, List + +import sqlalchemy + +import ormar.queryset # noqa I100 +from ormar.models import FakePydantic # noqa I100 + + +class Model(FakePydantic): + __abstract__ = False + + # objects = ormar.queryset.QuerySet() + + @classmethod + def from_row( + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + previous_table: str = None, + ) -> "Model": + + item = {} + select_related = select_related or [] + + table_prefix = cls.Meta._orm_relationship_manager.resolve_relation_join( + previous_table, cls.Meta.table.name + ) + previous_table = cls.Meta.table.name + for related in select_related: + if "__" in related: + first_part, remainder = related.split("__", 1) + model_cls = cls.Meta.model_fields[first_part].to + child = model_cls.from_row( + row, select_related=[remainder], previous_table=previous_table + ) + item[first_part] = child + else: + model_cls = cls.Meta.model_fields[related].to + child = model_cls.from_row(row, previous_table=previous_table) + item[related] = child + + for column in cls.Meta.table.columns: + if column.name not in item: + item[column.name] = row[ + f'{table_prefix + "_" if table_prefix else ""}{column.name}' + ] + + return cls(**item) + + async def save(self) -> "Model": + self_fields = self._extract_model_db_fields() + if self.Meta.model_fields.get(self.Meta.pkname).autoincrement: + self_fields.pop(self.Meta.pkname, None) + expr = self.Meta.table.insert() + expr = expr.values(**self_fields) + item_id = await self.Meta.database.execute(expr) + setattr(self, self.Meta.pkname, item_id) + return self + + async def update(self, **kwargs: Any) -> int: + if kwargs: + new_values = {**self.dict(), **kwargs} + self.from_dict(new_values) + + self_fields = self._extract_model_db_fields() + self_fields.pop(self.Meta.pkname) + expr = ( + self.Meta.table.update() + .values(**self_fields) + .where(self.pk_column == getattr(self, self.Meta.pkname)) + ) + result = await self.Meta.database.execute(expr) + return result + + async def delete(self) -> int: + expr = self.Meta.table.delete() + expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname))) + result = await self.Meta.database.execute(expr) + return result + + async def load(self) -> "Model": + expr = self.Meta.table.select().where(self.pk_column == self.pk) + row = await self.Meta.database.fetch_one(expr) + self.from_dict(dict(row)) + return self diff --git a/ormar/queryset/__init__.py b/ormar/queryset/__init__.py new file mode 100644 index 0000000..7bf6fc6 --- /dev/null +++ b/ormar/queryset/__init__.py @@ -0,0 +1,3 @@ +from ormar.queryset.queryset import QuerySet + +__all__ = ["QuerySet"] diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py new file mode 100644 index 0000000..dc94e6f --- /dev/null +++ b/ormar/queryset/clause.py @@ -0,0 +1,178 @@ +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union + +import sqlalchemy +from sqlalchemy import text + +import ormar # noqa I100 +from ormar.exceptions import QueryDefinitionError + +if TYPE_CHECKING: # pragma no cover + from ormar import Model + +FILTER_OPERATORS = { + "exact": "__eq__", + "iexact": "ilike", + "contains": "like", + "icontains": "ilike", + "in": "in_", + "gt": "__gt__", + "gte": "__ge__", + "lt": "__lt__", + "lte": "__le__", +} +ESCAPE_CHARACTERS = ["%", "_"] + + +class QueryClause: + def __init__( + self, model_cls: Type["Model"], filter_clauses: List, select_related: List, + ) -> None: + + self._select_related = select_related + self.filter_clauses = filter_clauses + + self.model_cls = model_cls + self.table = self.model_cls.Meta.table + + def filter( # noqa: A003 + self, **kwargs: Any + ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: + filter_clauses = self.filter_clauses + select_related = list(self._select_related) + + if kwargs.get("pk"): + pk_name = self.model_cls.Meta.pkname + kwargs[pk_name] = kwargs.pop("pk") + + for key, value in kwargs.items(): + table_prefix = "" + if "__" in key: + parts = key.split("__") + + ( + op, + field_name, + related_parts, + ) = self._extract_operator_field_and_related(parts) + + model_cls = self.model_cls + if related_parts: + ( + select_related, + table_prefix, + model_cls, + ) = self._determine_filter_target_table( + related_parts, select_related + ) + + table = model_cls.Meta.table + column = model_cls.Meta.table.columns[field_name] + + else: + op = "exact" + column = self.table.columns[key] + table = self.table + + value, has_escaped_character = self._escape_characters_in_clause(op, value) + + if isinstance(value, ormar.Model): + value = value.pk + + op_attr = FILTER_OPERATORS[op] + clause = getattr(column, op_attr)(value) + clause = self._compile_clause( + clause, + column, + table, + table_prefix, + modifiers={"escape": "\\" if has_escaped_character else None}, + ) + filter_clauses.append(clause) + + return filter_clauses, select_related + + def _determine_filter_target_table( + self, related_parts: List[str], select_related: List[str] + ) -> Tuple[List[str], str, "Model"]: + + table_prefix = "" + model_cls = self.model_cls + select_related = [relation for relation in select_related] + + # Add any implied select_related + related_str = "__".join(related_parts) + if related_str not in select_related: + select_related.append(related_str) + + # Walk the relationships to the actual model class + # against which the comparison is being made. + previous_table = model_cls.Meta.tablename + for part in related_parts: + current_table = model_cls.Meta.model_fields[part].to.Meta.tablename + manager = model_cls.Meta._orm_relationship_manager + table_prefix = manager.resolve_relation_join(previous_table, current_table) + model_cls = model_cls.Meta.model_fields[part].to + previous_table = current_table + return select_related, table_prefix, model_cls + + def _compile_clause( + self, + clause: sqlalchemy.sql.expression.BinaryExpression, + column: sqlalchemy.Column, + table: sqlalchemy.Table, + table_prefix: str, + modifiers: Dict, + ) -> sqlalchemy.sql.expression.TextClause: + for modifier, modifier_value in modifiers.items(): + clause.modifiers[modifier] = modifier_value + + clause_text = str( + clause.compile( + dialect=self.model_cls.Meta.database._backend._dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + alias = f"{table_prefix}_" if table_prefix else "" + aliased_name = f"{alias}{table.name}.{column.name}" + clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name) + clause = text(clause_text) + return clause + + @staticmethod + def _escape_characters_in_clause( + op: str, value: Union[str, "Model"] + ) -> Tuple[str, bool]: + has_escaped_character = False + + if op not in ["contains", "icontains"]: + return value, has_escaped_character + + if isinstance(value, ormar.Model): + raise QueryDefinitionError( + "You cannot use contains and icontains with instance of the Model" + ) + + has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value) + + if has_escaped_character: + # enable escape modifier + for char in ESCAPE_CHARACTERS: + value = value.replace(char, f"\\{char}") + value = f"%{value}%" + + return value, has_escaped_character + + @staticmethod + def _extract_operator_field_and_related( + parts: List[str], + ) -> Tuple[str, str, Optional[List]]: + if parts[-1] in FILTER_OPERATORS: + op = parts[-1] + field_name = parts[-2] + related_parts = parts[:-2] + else: + op = "exact" + field_name = parts[-1] + related_parts = parts[:-1] + + return op, field_name, related_parts diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py new file mode 100644 index 0000000..798502a --- /dev/null +++ b/ormar/queryset/query.py @@ -0,0 +1,236 @@ +from typing import List, NamedTuple, TYPE_CHECKING, Tuple, Type + +import sqlalchemy +from sqlalchemy import text + +import ormar # noqa I100 +from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField + +if TYPE_CHECKING: # pragma no cover + from ormar import Model + + +class JoinParameters(NamedTuple): + prev_model: Type["Model"] + previous_alias: str + from_table: str + model_cls: Type["Model"] + + +class Query: + def __init__( + self, + model_cls: Type["Model"], + filter_clauses: List, + select_related: List, + limit_count: int, + offset: int, + ) -> None: + + self.query_offset = offset + self.limit_count = limit_count + self._select_related = select_related + self.filter_clauses = filter_clauses + + self.model_cls = model_cls + self.table = self.model_cls.Meta.table + + self.auto_related = [] + self.used_aliases = [] + self.already_checked = [] + + self.select_from = None + self.columns = None + self.order_bys = None + + def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: + self.columns = list(self.table.columns) + self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")] + self.select_from = self.table + + for key in self.model_cls.Meta.model_fields: + if ( + not self.model_cls.Meta.model_fields[key].nullable + and isinstance( + self.model_cls.Meta.model_fields[key], ForeignKeyField, + ) + and key not in self._select_related + ): + self._select_related = [key] + self._select_related + + start_params = JoinParameters( + self.model_cls, "", self.table.name, self.model_cls + ) + self._extract_auto_required_relations(prev_model=start_params.prev_model) + self._include_auto_related_models() + self._select_related.sort(key=lambda item: (-len(item), item)) + + for item in self._select_related: + join_parameters = JoinParameters( + self.model_cls, "", self.table.name, self.model_cls + ) + + for part in item.split("__"): + join_parameters = self._build_join_parameters(part, join_parameters) + + expr = sqlalchemy.sql.select(self.columns) + expr = expr.select_from(self.select_from) + + expr = self._apply_expression_modifiers(expr) + + print(expr.compile(compile_kwargs={"literal_binds": True})) + self._reset_query_parameters() + + return expr, self._select_related + + @staticmethod + def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: + return [ + text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}") + for column in table.columns + ] + + @staticmethod + def prefixed_table_name(alias: str, name: str) -> text: + return text(f"{name} {alias}_{name}") + + @staticmethod + def _field_is_a_foreign_key_and_no_circular_reference( + field: BaseField, field_name: str, rel_part: str + ) -> bool: + return issubclass(field, ForeignKeyField) and field_name not in rel_part + + def _field_qualifies_to_deeper_search( + self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str + ) -> bool: + prev_part_of_related = "__".join(rel_part.split("__")[:-1]) + partial_match = any( + [x.startswith(prev_part_of_related) for x in self._select_related] + ) + already_checked = any( + [x.startswith(rel_part) for x in (self.auto_related + self.already_checked)] + ) + return ( + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested + + def on_clause( + self, previous_alias: str, alias: str, from_clause: str, to_clause: str, + ) -> text: + left_part = f"{alias}_{to_clause}" + right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" + return text(f"{left_part}={right_part}") + + def _build_join_parameters( + self, part: str, join_params: JoinParameters + ) -> JoinParameters: + model_cls = join_params.model_cls.Meta.model_fields[part].to + to_table = model_cls.Meta.table.name + + alias = model_cls.Meta._orm_relationship_manager.resolve_relation_join( + join_params.from_table, to_table + ) + if alias not in self.used_aliases: + if join_params.prev_model.Meta.model_fields[part].virtual: + to_key = next( + ( + v + for k, v in model_cls.Meta.model_fields.items() + if issubclass(v, ForeignKeyField) and v.to == join_params.prev_model + ), + None, + ).name + from_key = model_cls.Meta.pkname + else: + to_key = model_cls.Meta.pkname + from_key = part + + on_clause = self.on_clause( + previous_alias=join_params.previous_alias, + alias=alias, + from_clause=f"{join_params.from_table}.{from_key}", + to_clause=f"{to_table}.{to_key}", + ) + target_table = self.prefixed_table_name(alias, to_table) + self.select_from = sqlalchemy.sql.outerjoin( + self.select_from, target_table, on_clause + ) + self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}")) + self.columns.extend(self.prefixed_columns(alias, model_cls.Meta.table)) + self.used_aliases.append(alias) + + previous_alias = alias + from_table = to_table + prev_model = model_cls + return JoinParameters(prev_model, previous_alias, from_table, model_cls) + + def _extract_auto_required_relations( + self, + prev_model: Type["Model"], + rel_part: str = "", + nested: bool = False, + parent_virtual: bool = False, + ) -> None: + for field_name, field in prev_model.Meta.model_fields.items(): + if self._field_is_a_foreign_key_and_no_circular_reference( + field, field_name, rel_part + ): + rel_part = field_name if not rel_part else rel_part + "__" + field_name + if not field.nullable: + print('add', rel_part, field) + if rel_part not in self._select_related: + new_related = "__".join(rel_part.split("__")[:-1]) if len( + rel_part.split("__")) > 1 else rel_part + self.auto_related.append(new_related) + rel_part = "" + elif self._field_qualifies_to_deeper_search( + field, parent_virtual, nested, rel_part + ): + print('deeper', rel_part, field, field.to) + self._extract_auto_required_relations( + prev_model=field.to, + rel_part=rel_part, + nested=True, + parent_virtual=field.virtual, + ) + else: + self.already_checked.append(rel_part) + rel_part = "" + + def _include_auto_related_models(self) -> None: + if self.auto_related: + new_joins = [] + for join in self._select_related: + if not any([x.startswith(join) for x in self.auto_related]): + new_joins.append(join) + self._select_related = new_joins + self.auto_related + + def _apply_expression_modifiers( + self, expr: sqlalchemy.sql.select + ) -> sqlalchemy.sql.select: + if self.filter_clauses: + if len(self.filter_clauses) == 1: + clause = self.filter_clauses[0] + else: + clause = sqlalchemy.sql.and_(*self.filter_clauses) + expr = expr.where(clause) + + if self.limit_count: + expr = expr.limit(self.limit_count) + + if self.query_offset: + expr = expr.offset(self.query_offset) + + for order in self.order_bys: + expr = expr.order_by(order) + return expr + + def _reset_query_parameters(self) -> None: + self.select_from = None + self.columns = None + self.order_bys = None + self.auto_related = [] + self.used_aliases = [] + self.already_checked = [] diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py new file mode 100644 index 0000000..4e5d85e --- /dev/null +++ b/ormar/queryset/queryset.py @@ -0,0 +1,181 @@ +from typing import Any, List, TYPE_CHECKING, Tuple, Type, Union + +import databases +import sqlalchemy + +import ormar # noqa I100 +from ormar import MultipleMatches, NoMatch +from ormar.queryset.clause import QueryClause +from ormar.queryset.query import Query + +if TYPE_CHECKING: # pragma no cover + from ormar import Model + + +class QuerySet: + def __init__( + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, + ) -> None: + self.model_cls = model_cls + self.filter_clauses = [] if filter_clauses is None else filter_clauses + self._select_related = [] if select_related is None else select_related + self.limit_count = limit_count + self.query_offset = offset + self.order_bys = None + + def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": + return self.__class__(model_cls=owner) + + @property + def database(self) -> databases.Database: + return self.model_cls.Meta.database + + @property + def table(self) -> sqlalchemy.Table: + return self.model_cls.Meta.table + + def build_select_expression(self) -> sqlalchemy.sql.select: + qry = Query( + model_cls=self.model_cls, + select_related=self._select_related, + filter_clauses=self.filter_clauses, + offset=self.query_offset, + limit_count=self.limit_count, + ) + exp, self._select_related = qry.build_select_expression() + return exp + + def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 + qryclause = QueryClause( + model_cls=self.model_cls, + select_related=self._select_related, + filter_clauses=self.filter_clauses, + ) + filter_clauses, select_related = qryclause.filter(**kwargs) + + return self.__class__( + model_cls=self.model_cls, + filter_clauses=filter_clauses, + select_related=select_related, + limit_count=self.limit_count, + offset=self.query_offset, + ) + + def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet": + if not isinstance(related, (list, tuple)): + related = [related] + + related = list(self._select_related) + related + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=related, + limit_count=self.limit_count, + offset=self.query_offset, + ) + + async def exists(self) -> bool: + expr = self.build_select_expression() + expr = sqlalchemy.exists(expr).select() + return await self.database.fetch_val(expr) + + async def count(self) -> int: + expr = self.build_select_expression().alias("subquery_for_count") + expr = sqlalchemy.func.count().select().select_from(expr) + return await self.database.fetch_val(expr) + + def limit(self, limit_count: int) -> "QuerySet": + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + limit_count=limit_count, + offset=self.query_offset, + ) + + def offset(self, offset: int) -> "QuerySet": + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + limit_count=self.limit_count, + offset=offset, + ) + + async def first(self, **kwargs: Any) -> "Model": + if kwargs: + return await self.filter(**kwargs).first() + + rows = await self.limit(1).all() + if rows: + return rows[0] + + async def get(self, **kwargs: Any) -> "Model": + if kwargs: + return await self.filter(**kwargs).get() + + expr = self.build_select_expression().limit(2) + rows = await self.database.fetch_all(expr) + + if not rows: + raise NoMatch() + if len(rows) > 1: + raise MultipleMatches() + return self.model_cls.from_row(rows[0], select_related=self._select_related) + + async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 + if kwargs: + return await self.filter(**kwargs).all() + + expr = self.build_select_expression() + rows = await self.database.fetch_all(expr) + result_rows = [ + self.model_cls.from_row(row, select_related=self._select_related) + for row in rows + ] + + result_rows = self.model_cls.merge_instances_list(result_rows) + + return result_rows + + async def create(self, **kwargs: Any) -> "Model": + + new_kwargs = dict(**kwargs) + + # Remove primary key when None to prevent not null constraint in postgresql. + pkname = self.model_cls.Meta.pkname + pk = self.model_cls.Meta.model_fields[pkname] + if ( + pkname in new_kwargs + and new_kwargs.get(pkname) is None + and (pk.nullable or pk.autoincrement) + ): + del new_kwargs[pkname] + + # substitute related models with their pk + for field in self.model_cls._extract_related_names(): + if field in new_kwargs and new_kwargs.get(field) is not None: + if isinstance(new_kwargs.get(field), ormar.Model): + new_kwargs[field] = getattr( + new_kwargs.get(field), + self.model_cls.Meta.model_fields[field].to.Meta.pkname, + ) + else: + new_kwargs[field] = new_kwargs.get(field).get( + self.model_cls.Meta.model_fields[field].to.Meta.pkname + ) + + # Build the insert expression. + expr = self.table.insert() + expr = expr.values(**new_kwargs) + + # Execute the insert, and return a new model instance. + instance = self.model_cls(**kwargs) + pk = await self.database.execute(expr) + setattr(instance, self.model_cls.Meta.pkname, pk) + return instance diff --git a/ormar/relations.py b/ormar/relations.py new file mode 100644 index 0000000..2177ac0 --- /dev/null +++ b/ormar/relations.py @@ -0,0 +1,104 @@ +import pprint +import string +import uuid +from random import choices +from typing import List, TYPE_CHECKING, Union +from weakref import proxy + +from ormar import ForeignKey +from ormar.fields.foreign_key import ForeignKeyField + +if TYPE_CHECKING: # pragma no cover + from ormar.models import FakePydantic, Model + + +def get_table_alias() -> str: + return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] + + +class RelationshipManager: + def __init__(self) -> None: + self._relations = dict() + self._aliases = dict() + + def add_relation_type( + self, relations_key: str, reverse_key: str, field: ForeignKeyField, table_name: str + ) -> None: + if relations_key not in self._relations: + self._relations[relations_key] = {"type": "primary"} + self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias() + if reverse_key not in self._relations: + self._relations[reverse_key] = {"type": "reverse"} + self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias() + + def deregister(self, model: "FakePydantic") -> None: + for rel_type in self._relations.keys(): + if model.get_name() in rel_type.lower(): + if model._orm_id in self._relations[rel_type]: + del self._relations[rel_type][model._orm_id] + + def add_relation( + self, + parent: "FakePydantic", + child: "FakePydantic", + child_model_name: str, + virtual: bool = False, + ) -> None: + parent_id, child_id = parent._orm_id, child._orm_id + parent_name = parent.get_name(title=True) + child_name = ( + child_model_name + if child.get_name() != child_model_name + else child.get_name() + "s" + ) + if virtual: + child_name, parent_name = parent_name, child.get_name() + child_id, parent_id = parent_id, child_id + child, parent = parent, proxy(child) + child_name = child_name.lower() + "s" + else: + child = proxy(child) + + parent_relation_name = parent_name.title() + "_" + child_name + parents_list = self._relations[parent_relation_name].setdefault(parent_id, []) + self.append_related_model(parents_list, child) + + child_relation_name = child.get_name(title=True) + "_" + parent_name.lower() + children_list = self._relations[child_relation_name].setdefault(child_id, []) + self.append_related_model(children_list, parent) + + @staticmethod + def append_related_model(relations_list: List["Model"], model: "Model") -> None: + print("appending", relations_list, model) + for relation_child in relations_list: + try: + print(relation_child.__same__(model), "same") + if relation_child.__same__(model): + return + except ReferenceError: + continue + + relations_list.append(model) + + def contains(self, relations_key: str, instance: "FakePydantic") -> bool: + if relations_key in self._relations: + return instance._orm_id in self._relations[relations_key] + return False + + def get( + self, relations_key: str, instance: "FakePydantic" + ) -> Union["Model", List["Model"]]: + if relations_key in self._relations: + if instance._orm_id in self._relations[relations_key]: + if self._relations[relations_key]["type"] == "primary": + return self._relations[relations_key][instance._orm_id][0] + return self._relations[relations_key][instance._orm_id] + + def resolve_relation_join(self, from_table: str, to_table: str) -> str: + return self._aliases.get(f"{from_table}_{to_table}", "") + + def __str__(self) -> str: # pragma no cover + return pprint.pformat(self._relations, indent=4, width=1) + + def __repr__(self) -> str: # pragma no cover + return self.__str__() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..807e704 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,21 @@ +databases[sqlite] +pydantic +sqlalchemy + +# Testing +pytest +pytest-cov +codecov +pytest-asyncio +fastapi +flake8 +flake8-black +flake8-bugbear +flake8-import-order +flake8-bandit +flake8-annotations +flake8-builtins +flake8-variables-names +flake8-cognitive-complexity +flake8-functions +flake8-expression-complexity \ No newline at end of file diff --git a/scripts/clean.sh b/scripts/clean.sh new file mode 100644 index 0000000..9b30e14 --- /dev/null +++ b/scripts/clean.sh @@ -0,0 +1,19 @@ +#!/bin/sh -e +PACKAGE="ormar" +if [ -d 'dist' ] ; then + rm -r dist +fi +if [ -d 'site' ] ; then + rm -r site +fi +if [ -d 'htmlcov' ] ; then + rm -r htmlcov +fi +if [ -d "${PACKAGE}.egg-info" ] ; then + rm -r "${PACKAGE}.egg-info" +fi +find ${PACKAGE} -type f -name "*.py[co]" -delete +find ${PACKAGE} -type d -name __pycache__ -delete + +find tests -type f -name "*.py[co]" -delete +find tests -type d -name __pycache__ -delete \ No newline at end of file diff --git a/scripts/publish.sh b/scripts/publish.sh new file mode 100644 index 0000000..419fa30 --- /dev/null +++ b/scripts/publish.sh @@ -0,0 +1,23 @@ +#!/bin/sh -e + +PACKAGE="ormar" + +PREFIX="" +if [ -d 'venv' ] ; then + PREFIX="venv/bin/" +fi + +VERSION=`cat ${PACKAGE}/__init__.py | grep __version__ | sed "s/__version__ = //" | sed "s/'//g"` + +set -x + +scripts/clean.sh + +${PREFIX}python setup.py sdist +${PREFIX}twine upload dist/* + +echo "You probably want to also tag the version now:" +echo "git tag -a ${VERSION} -m 'version ${VERSION}'" +echo "git push --tags" + +scripts/clean.sh \ No newline at end of file diff --git a/scripts/test.sh b/scripts/test.sh new file mode 100644 index 0000000..c911ffb --- /dev/null +++ b/scripts/test.sh @@ -0,0 +1,12 @@ +#!/bin/sh -e + +PACKAGE="ormar" + +PREFIX="" +if [ -d 'venv' ] ; then + PREFIX="venv/bin/" +fi + +set -x + +PYTHONPATH=. ${PREFIX}pytest --ignore venv --cov=${PACKAGE} --cov=tests --cov-fail-under=100 --cov-report=term-missing "${@}" diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..224a779 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[metadata] +description-file = README.md \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..0f768da --- /dev/null +++ b/setup.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import re + +from setuptools import setup + +PACKAGE = "ormar" +URL = "https://github.com/collerek/ormar" + + +def get_version(package): + """ + Return package version as listed in `__version__` in `init.py`. + """ + with open(os.path.join(package, "__init__.py")) as f: + return re.search("__version__ = ['\"]([^'\"]+)['\"]", f.read()).group(1) + + +def get_long_description(): + """ + Return the README. + """ + with open("README.md", encoding="utf8") as f: + return f.read() + + +def get_packages(package): + """ + Return root package and all sub-packages. + """ + return [ + dirpath + for dirpath, dirnames, filenames in os.walk(package) + if os.path.exists(os.path.join(dirpath, "__init__.py")) + ] + + +setup( + name=PACKAGE, + version=get_version(PACKAGE), + url=URL, + license="MIT", + description="An simple async ORM with Fastapi in mind.", + long_description=get_long_description(), + long_description_content_type="text/markdown", + keywords=['ORM', 'sqlalchemy', 'fastapi', 'pydantic', 'databases'], + author="collerek", + author_email="collerek@gmail.com", + packages=get_packages(PACKAGE), + package_data={PACKAGE: ["py.typed"]}, + data_files=[("", ["LICENSE.md"])], + install_requires=["databases", "pydantic", "sqlalchemy"], + classifiers=[ + "Development Status :: 3 - Alpha", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Topic :: Internet :: WWW/HTTP", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + ], +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/settings.py b/tests/settings.py new file mode 100644 index 0000000..697acb0 --- /dev/null +++ b/tests/settings.py @@ -0,0 +1,3 @@ +import os + +DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///test.db") diff --git a/tests/test_columns.py b/tests/test_columns.py new file mode 100644 index 0000000..c8c9d3b --- /dev/null +++ b/tests/test_columns.py @@ -0,0 +1,60 @@ +import datetime + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +def time(): + return datetime.datetime.now().time() + + +class Example(ormar.Model): + class Meta: + tablename = "example" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=200, default='aaa') + created: ormar.DateTime(default=datetime.datetime.now) + created_day: ormar.Date(default=datetime.date.today) + created_time: ormar.Time(default=time) + description: ormar.Text(nullable=True) + value: ormar.Float(nullable=True) + data: ormar.JSON(default={}) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_model_crud(): + async with database: + example = Example() + await example.save() + + await example.load() + assert example.created.year == datetime.datetime.now().year + assert example.created_day == datetime.date.today() + assert example.description is None + assert example.value is None + assert example.data == {} + + await example.update(data={"foo": 123}, value=123.456) + await example.load() + assert example.value == 123.456 + assert example.data == {"foo": 123} + + await example.delete() diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py new file mode 100644 index 0000000..f7f2625 --- /dev/null +++ b/tests/test_fastapi_usage.py @@ -0,0 +1,55 @@ +import databases +import sqlalchemy +from fastapi import FastAPI +from fastapi.testclient import TestClient + +import ormar +from tests.settings import DATABASE_URL + +app = FastAPI() + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Category(ormar.Model): + class Meta: + tablename = "categories" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + + +class Item(ormar.Model): + class Meta: + tablename = "items" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + category: ormar.ForeignKey(Category, nullable=True) + + +@app.post("/items/", response_model=Item) +async def create_item(item: Item): + return item + + +client = TestClient(app) + + +def test_read_main(): + response = client.post( + "/items/", json={"name": "test", "id": 1, "category": {"name": "test cat"}} + ) + assert response.status_code == 200 + assert response.json() == { + "category": {"id": None, "name": "test cat"}, + "id": 1, + "name": "test", + } + item = Item(**response.json()) + assert item.id == 1 diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py new file mode 100644 index 0000000..e37a3fd --- /dev/null +++ b/tests/test_foreign_keys.py @@ -0,0 +1,296 @@ +import databases +import pytest +import sqlalchemy +from pydantic import ValidationError + +import ormar +from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Album(ormar.Model): + class Meta: + tablename = "albums" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + + +class Track(ormar.Model): + class Meta: + tablename = "tracks" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + album: ormar.ForeignKey(Album) + title: ormar.String(max_length=100) + position: ormar.Integer() + + +class Cover(ormar.Model): + class Meta: + tablename = "covers" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + album: ormar.ForeignKey(Album, related_name="cover_pictures") + title: ormar.String(max_length=100) + + +class Organisation(ormar.Model): + class Meta: + tablename = "org" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + ident: ormar.String(max_length=100) + + +class Team(ormar.Model): + class Meta: + tablename = "teams" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + org: ormar.ForeignKey(Organisation) + name: ormar.String(max_length=100) + + +class Member(ormar.Model): + class Meta: + tablename = "members" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + team: ormar.ForeignKey(Team) + email: ormar.String(max_length=100) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_wrong_query_foreign_key_type(): + with pytest.raises(RelationshipInstanceError): + Track(title="The Error", album="wrong_pk_type") + + +@pytest.mark.asyncio +async def test_setting_explicitly_empty_relation(): + async with database: + track = Track(album=None, title="The Bird", position=1) + assert track.album is None + + +@pytest.mark.asyncio +async def test_related_name(): + async with database: + album = await Album.objects.create(name="Vanilla") + await Cover.objects.create(album=album, title="The cover file") + + assert len(album.cover_pictures) == 1 + + +@pytest.mark.asyncio +async def test_model_crud(): + async with database: + album = Album(name="Malibu") + await album.save() + track1 = Track(album=album, title="The Bird", position=1) + track2 = Track(album=album, title="Heart don't stand a chance", position=2) + track3 = Track(album=album, title="The Waters", position=3) + await track1.save() + await track2.save() + await track3.save() + + assert len(album.tracks) == 3 + assert album.tracks[1].title == "Heart don't stand a chance" + + track = await Track.objects.get(title="The Bird") + assert track.album.pk == album.pk + assert track.album.name is None + await track.album.load() + assert track.album.name == "Malibu" + + + album1 = await Album.objects.get(name="Malibu") + assert album1.pk == 1 + assert album1.tracks is None + + +@pytest.mark.asyncio +async def test_select_related(): + async with database: + album = Album(name="Malibu") + await album.save() + track1 = Track(album=album, title="The Bird", position=1) + track2 = Track(album=album, title="Heart don't stand a chance", position=2) + track3 = Track(album=album, title="The Waters", position=3) + await track1.save() + await track2.save() + await track3.save() + + fantasies = Album(name="Fantasies") + await fantasies.save() + track4 = Track(album=fantasies, title="Help I'm Alive", position=1) + track5 = Track(album=fantasies, title="Sick Muse", position=2) + track6 = Track(album=fantasies, title="Satellite Mind", position=3) + await track4.save() + await track5.save() + await track6.save() + + track = await Track.objects.select_related("album").get(title="The Bird") + assert track.album.name == "Malibu" + + tracks = await Track.objects.select_related("album").all() + assert len(tracks) == 6 + + +@pytest.mark.asyncio +async def test_fk_filter(): + async with database: + malibu = Album(name="Malibu%") + await malibu.save() + await Track.objects.create(album=malibu, title="The Bird", position=1) + await Track.objects.create( + album=malibu, title="Heart don't stand a chance", position=2 + ) + await Track.objects.create(album=malibu, title="The Waters", position=3) + + fantasies = await Album.objects.create(name="Fantasies") + await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) + + tracks = ( + await Track.objects.select_related("album") + .filter(album__name="Fantasies") + .all() + ) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = ( + await Track.objects.select_related("album") + .filter(album__name__icontains="fan") + .all() + ) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.objects.filter(album__name__contains="fan").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.objects.filter(album__name__contains="Malibu%").all() + assert len(tracks) == 3 + + tracks = await Track.objects.filter(album=malibu).select_related("album").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Malibu%" + + tracks = await Track.objects.select_related("album").all(album=malibu) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Malibu%" + + +@pytest.mark.asyncio +async def test_multiple_fk(): + async with database: + acme = await Organisation.objects.create(ident="ACME Ltd") + red_team = await Team.objects.create(org=acme, name="Red Team") + blue_team = await Team.objects.create(org=acme, name="Blue Team") + await Member.objects.create(team=red_team, email="a@example.org") + await Member.objects.create(team=red_team, email="b@example.org") + await Member.objects.create(team=blue_team, email="c@example.org") + await Member.objects.create(team=blue_team, email="d@example.org") + + other = await Organisation.objects.create(ident="Other ltd") + team = await Team.objects.create(org=other, name="Green Team") + await Member.objects.create(team=team, email="e@example.org") + + members = ( + await Member.objects.select_related("team__org") + .filter(team__org__ident="ACME Ltd") + .all() + ) + assert len(members) == 4 + for member in members: + assert member.team.org.ident == "ACME Ltd" + + +@pytest.mark.asyncio +async def test_pk_filter(): + async with database: + fantasies = await Album.objects.create(name="Test") + await Track.objects.create(album=fantasies, title="Test1", position=1) + await Track.objects.create(album=fantasies, title="Test2", position=2) + await Track.objects.create(album=fantasies, title="Test3", position=3) + tracks = await Track.objects.select_related("album").filter(pk=1).all() + assert len(tracks) == 1 + + tracks = ( + await Track.objects.select_related("album") + .filter(position=2, album__name="Test") + .all() + ) + assert len(tracks) == 1 + + +@pytest.mark.asyncio +async def test_limit_and_offset(): + async with database: + fantasies = await Album.objects.create(name="Limitless") + await Track.objects.create(id=None, album=fantasies, title="Sample", position=1) + await Track.objects.create(album=fantasies, title="Sample2", position=2) + await Track.objects.create(album=fantasies, title="Sample3", position=3) + + tracks = await Track.objects.limit(1).all() + assert len(tracks) == 1 + assert tracks[0].title == "Sample" + + tracks = await Track.objects.limit(1).offset(1).all() + assert len(tracks) == 1 + assert tracks[0].title == "Sample2" + + +@pytest.mark.asyncio +async def test_get_exceptions(): + async with database: + fantasies = await Album.objects.create(name="Test") + + with pytest.raises(NoMatch): + await Album.objects.get(name="Test2") + + await Track.objects.create(album=fantasies, title="Test1", position=1) + await Track.objects.create(album=fantasies, title="Test2", position=2) + await Track.objects.create(album=fantasies, title="Test3", position=3) + with pytest.raises(MultipleMatches): + await Track.objects.select_related("album").get(album=fantasies) + + +@pytest.mark.asyncio +async def test_wrong_model_passed_as_fk(): + with pytest.raises(RelationshipInstanceError): + org = await Organisation.objects.create(ident="ACME Ltd") + await Track.objects.create(album=org, title="Test1", position=1) diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py new file mode 100644 index 0000000..3cfd56c --- /dev/null +++ b/tests/test_model_definition.py @@ -0,0 +1,160 @@ +import datetime +import decimal + +import pydantic +import pytest +import sqlalchemy + +import ormar.fields as fields +from ormar.exceptions import ModelDefinitionError +from ormar.models import Model + +metadata = sqlalchemy.MetaData() + + +class ExampleModel(Model): + class Meta: + tablename = "example" + metadata = metadata + + test: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250) + test_text: fields.Text(default="") + test_bool: fields.Boolean(nullable=False) + test_float: fields.Float() = None + test_datetime: fields.DateTime(default=datetime.datetime.now) + test_date: fields.Date(default=datetime.date.today) + test_time: fields.Time(default=datetime.time) + test_json: fields.JSON(default={}) + test_bigint: fields.BigInteger(default=0) + test_decimal: fields.Decimal(scale=10, precision=2) + + +fields_to_check = [ + "test", + "test_text", + "test_string", + "test_datetime", + "test_date", + "test_text", + "test_float", + "test_bigint", + "test_json", +] + + +class ExampleModel2(Model): + class Meta: + tablename = "example2" + metadata = metadata + + test: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250) + + +@pytest.fixture() +def example(): + return ExampleModel(pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5)) + + +def test_not_nullable_field_is_required(): + with pytest.raises(pydantic.error_wrappers.ValidationError): + ExampleModel(test=1, test_string="test") + + +def test_model_attribute_access(example): + assert example.test == 1 + assert example.test_string == "test" + assert example.test_datetime.year == datetime.datetime.now().year + assert example.test_date == datetime.date.today() + assert example.test_text == "" + assert example.test_float is None + assert example.test_bigint == 0 + assert example.test_json == {} + + example.test = 12 + assert example.test == 12 + +def test_non_existing_attr(example): + with pytest.raises(ValueError): + example.new_attr=12 + + +def test_primary_key_access_and_setting(example): + assert example.pk == 1 + example.pk = 2 + + assert example.pk == 2 + assert example.test == 2 + + +def test_pydantic_model_is_created(example): + assert issubclass(example.__class__, pydantic.BaseModel) + assert all([field in example.__fields__ for field in fields_to_check]) + assert example.test == 1 + + +def test_sqlalchemy_table_is_created(example): + assert issubclass(example.Meta.table.__class__, sqlalchemy.Table) + assert all([field in example.Meta.table.columns for field in fields_to_check]) + + +def test_no_pk_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + class Meta: + tablename = "example3" + metadata = metadata + + test_string: fields.String(max_length=250) + + +def test_two_pks_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + class Meta: + tablename = "example3" + metadata = metadata + + id: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250, primary_key=True) + + +def test_setting_pk_column_as_pydantic_only_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + class Meta: + tablename = "example4" + metadata = metadata + + test: fields.Integer(primary_key=True, pydantic_only=True) + + +def test_decimal_error_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + class Meta: + tablename = "example5" + metadata = metadata + + test: fields.Decimal(primary_key=True) + + +def test_string_error_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + class Meta: + tablename = "example6" + metadata = metadata + + test: fields.String(primary_key=True) + + +def test_json_conversion_in_model(): + with pytest.raises(pydantic.ValidationError): + ExampleModel( + test_json=datetime.datetime.now(), + test=1, + test_string="test", + test_bool=True, + ) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..42c6b54 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,211 @@ +import databases +import pydantic +import pytest +import sqlalchemy + +import ormar +from ormar.exceptions import QueryDefinitionError +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class User(ormar.Model): + class Meta: + tablename = "users" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100, default='') + + +class Product(ormar.Model): + class Meta: + tablename = "product" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + rating: ormar.Integer(minimum=1, maximum=5) + in_stock: ormar.Boolean(default=False) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_model_class(): + assert list(User.Meta.model_fields.keys()) == ["id", "name"] + assert issubclass(User.Meta.model_fields["id"], pydantic.ConstrainedInt) + assert User.Meta.model_fields["id"].primary_key is True + assert issubclass(User.Meta.model_fields["name"], pydantic.ConstrainedStr) + assert User.Meta.model_fields["name"].max_length == 100 + assert isinstance(User.Meta.table, sqlalchemy.Table) + + +def test_model_pk(): + user = User(pk=1) + assert user.pk == 1 + assert user.id == 1 + + +@pytest.mark.asyncio +async def test_model_crud(): + async with database: + users = await User.objects.all() + assert users == [] + + user = await User.objects.create(name="Tom") + users = await User.objects.all() + assert user.name == "Tom" + assert user.pk is not None + assert users == [user] + + lookup = await User.objects.get() + assert lookup == user + + await user.update(name="Jane") + users = await User.objects.all() + assert user.name == "Jane" + assert user.pk is not None + assert users == [user] + + await user.delete() + users = await User.objects.all() + assert users == [] + + +@pytest.mark.asyncio +async def test_model_get(): + async with database: + with pytest.raises(ormar.NoMatch): + await User.objects.get() + + user = await User.objects.create(name="Tom") + lookup = await User.objects.get() + assert lookup == user + + user = await User.objects.create(name="Jane") + with pytest.raises(ormar.MultipleMatches): + await User.objects.get() + + same_user = await User.objects.get(pk=user.id) + assert same_user.id == user.id + assert same_user.pk == user.pk + + +@pytest.mark.asyncio +async def test_model_filter(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") + + user = await User.objects.get(name="Lucy") + assert user.name == "Lucy" + + with pytest.raises(ormar.NoMatch): + await User.objects.get(name="Jim") + + await Product.objects.create(name="T-Shirt", rating=5, in_stock=True) + await Product.objects.create(name="Dress", rating=4) + await Product.objects.create(name="Coat", rating=3, in_stock=True) + + product = await Product.objects.get(name__iexact="t-shirt", rating=5) + assert product.pk is not None + assert product.name == "T-Shirt" + assert product.rating == 5 + + products = await Product.objects.all(rating__gte=2, in_stock=True) + assert len(products) == 2 + + products = await Product.objects.all(name__icontains="T") + assert len(products) == 2 + + # Test escaping % character from icontains, contains, and iexact + await Product.objects.create(name="100%-Cotton", rating=3) + await Product.objects.create(name="Cotton-100%-Egyptian", rating=3) + await Product.objects.create(name="Cotton-100%", rating=3) + products = Product.objects.filter(name__iexact="100%-cotton") + assert await products.count() == 1 + + products = Product.objects.filter(name__contains="%") + assert await products.count() == 3 + + products = Product.objects.filter(name__icontains="%") + assert await products.count() == 3 + + +@pytest.mark.asyncio +async def test_wrong_query_contains_model(): + with pytest.raises(QueryDefinitionError): + product = Product(name="90%-Cotton", rating=2) + await Product.objects.filter(name__contains=product).count() + + +@pytest.mark.asyncio +async def test_model_exists(): + async with database: + await User.objects.create(name="Tom") + assert await User.objects.filter(name="Tom").exists() is True + assert await User.objects.filter(name="Jane").exists() is False + + +@pytest.mark.asyncio +async def test_model_count(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") + + assert await User.objects.count() == 3 + assert await User.objects.filter(name__icontains="T").count() == 1 + + +@pytest.mark.asyncio +async def test_model_limit(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") + + assert len(await User.objects.limit(2).all()) == 2 + + +@pytest.mark.asyncio +async def test_model_limit_with_filter(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Tom") + await User.objects.create(name="Tom") + + assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2 + + +@pytest.mark.asyncio +async def test_offset(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + + users = await User.objects.offset(1).limit(1).all() + assert users[0].name == "Jane" + + +@pytest.mark.asyncio +async def test_model_first(): + async with database: + tom = await User.objects.create(name="Tom") + jane = await User.objects.create(name="Jane") + + assert await User.objects.first() == tom + assert await User.objects.first(name="Jane") == jane + assert await User.objects.filter(name="Jane").first() == jane + assert await User.objects.filter(name="Lucy").first() is None diff --git a/tests/test_more_reallife_fastapi.py b/tests/test_more_reallife_fastapi.py new file mode 100644 index 0000000..02233e6 --- /dev/null +++ b/tests/test_more_reallife_fastapi.py @@ -0,0 +1,118 @@ +from typing import List + +import databases +import pytest +import sqlalchemy +from fastapi import FastAPI +from starlette.testclient import TestClient + +import ormar +from tests.settings import DATABASE_URL + +app = FastAPI() +metadata = sqlalchemy.MetaData() +database = databases.Database(DATABASE_URL, force_rollback=True) +app.state.database = database + + +@app.on_event("startup") +async def startup() -> None: + database_ = app.state.database + if not database_.is_connected: + await database_.connect() + + +@app.on_event("shutdown") +async def shutdown() -> None: + database_ = app.state.database + if database_.is_connected: + await database_.disconnect() + + +class Category(ormar.Model): + class Meta: + tablename = "categories" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + + +class Item(ormar.Model): + class Meta: + tablename = "items" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + category: ormar.ForeignKey(Category, nullable=True) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@app.get("/items/", response_model=List[Item]) +async def get_items(): + items = await Item.objects.select_related("category").all() + return [item.dict() for item in items] + + +@app.post("/items/", response_model=Item) +async def create_item(item: Item): + item = await Item.objects.create(**item.dict()) + return item.dict() + + +@app.post("/categories/", response_model=Category) +async def create_category(category: Category): + await category.save() + return category + + +@app.put("/items/{item_id}") +async def get_item(item_id: int, item: Item): + item_db = await Item.objects.get(pk=item_id) + return {"updated_rows": await item_db.update(**item.dict())} + + +@app.delete("/items/{item_id}") +async def delete_item(item_id: int, item: Item): + item_db = await Item.objects.get(pk=item_id) + return {"deleted_rows": await item_db.delete()} + + +def test_all_endpoints(): + client = TestClient(app) + with client as client: + response = client.post("/categories/", json={"name": "test cat"}) + category = response.json() + response = client.post( + "/items/", json={"name": "test", "id": 1, "category": category} + ) + item = Item(**response.json()) + assert item.pk is not None + + response = client.get("/items/") + items = [Item(**item) for item in response.json()] + assert items[0] == item + + item.name = "New name" + response = client.put(f"/items/{item.pk}", json=item.dict()) + assert response.json().get("updated_rows") == 1 + + response = client.get("/items/") + items = [Item(**item) for item in response.json()] + assert items[0].name == "New name" + + response = client.delete(f"/items/{item.pk}", json=item.dict()) + assert response.json().get("deleted_rows") == 1 + response = client.get("/items/") + items = response.json() + assert len(items) == 0 diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py new file mode 100644 index 0000000..166c943 --- /dev/null +++ b/tests/test_same_table_joins.py @@ -0,0 +1,133 @@ +import asyncio + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Department(ormar.Model): + class Meta: + tablename = "departments" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True, autoincrement=False) + name: ormar.String(max_length=100) + + +class SchoolClass(ormar.Model): + class Meta: + tablename = "schoolclasses" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + department: ormar.ForeignKey(Department, nullable=False) + + +class Category(ormar.Model): + class Meta: + tablename = "categories" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + + +class Student(ormar.Model): + class Meta: + tablename = "students" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + schoolclass: ormar.ForeignKey(SchoolClass) + category: ormar.ForeignKey(Category, nullable=True) + + +class Teacher(ormar.Model): + class Meta: + tablename = "teachers" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + schoolclass: ormar.ForeignKey(SchoolClass) + category: ormar.ForeignKey(Category, nullable=True) + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@pytest.fixture(autouse=True, scope="module") +async def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + department = await Department.objects.create(id=1, name="Math Department") + class1 = await SchoolClass.objects.create(name="Math", department=department) + category = await Category.objects.create(name="Foreign") + category2 = await Category.objects.create(name="Domestic") + await Student.objects.create(name="Jane", category=category, schoolclass=class1) + await Student.objects.create(name="Jack", category=category2, schoolclass=class1) + await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_model_multiple_instances_of_same_table_in_schema(): + async with database: + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students"] + ).all() + assert classes[0].name == "Math" + assert classes[0].students[0].name == "Jane" + + assert len(classes[0].dict().get("students")) == 2 + + # related fields of main model are only populated by pk + # unless there is a required foreign key somewhere along the way + # since department is required for schoolclass it was pre loaded (again) + # but you can load them anytime + assert classes[0].students[0].schoolclass.name == "Math" + assert classes[0].students[0].schoolclass.department.name is None + await classes[0].students[0].schoolclass.department.load() + assert classes[0].students[0].schoolclass.department.name == "Math Department" + + +@pytest.mark.asyncio +async def test_right_tables_join(): + async with database: + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students"] + ).all() + assert classes[0].teachers[0].category.name == "Domestic" + + assert classes[0].students[0].category.name is None + await classes[0].students[0].category.load() + assert classes[0].students[0].category.name == "Foreign" + + +@pytest.mark.asyncio +async def test_multiple_reverse_related_objects(): + async with database: + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students__category"] + ).all() + assert classes[0].name == "Math" + assert classes[0].students[1].name == "Jack" + assert classes[0].teachers[0].category.name == "Domestic" From 53384879a9011593f6add237aa0fa55d47dd3b2c Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 23 Aug 2020 12:54:58 +0200 Subject: [PATCH 3/5] some cleanup and tests --- .coverage | Bin 0 -> 53248 bytes ormar/fields/base.py | 20 +-- ormar/fields/decorators.py | 27 ---- ormar/fields/foreign_key.py | 43 ++--- ormar/fields/model_fields.py | 288 +++++++++++++++++---------------- ormar/models/fakepydantic.py | 153 +++++++++--------- ormar/models/metaclass.py | 87 ++++------ ormar/models/model.py | 12 +- ormar/queryset/query.py | 74 +++++---- ormar/queryset/queryset.py | 18 +-- ormar/relations.py | 7 +- scripts/clean.sh | 0 scripts/publish.sh | 0 scripts/test.sh | 0 tests/test_foreign_keys.py | 2 + tests/test_model_definition.py | 13 +- tests/test_models.py | 22 +++ 17 files changed, 370 insertions(+), 396 deletions(-) create mode 100644 .coverage delete mode 100644 ormar/fields/decorators.py mode change 100644 => 100755 scripts/clean.sh mode change 100644 => 100755 scripts/publish.sh mode change 100644 => 100755 scripts/test.sh diff --git a/.coverage b/.coverage new file mode 100644 index 0000000000000000000000000000000000000000..f3e5fafb45eb490917ec4b248da6544fb15b365f GIT binary patch literal 53248 zcmeI4du$xV8NhdM&v(Apj^AhEgsh5M=i&1+A!#0jw2ef7Dutp75vR-gyyt9kcYC?r zb7IA@WuT=})PIUv)beVjQi>{71&Y#=5Lco=TSTHlDJW^|5)kI)xaAScB=+{3*<0UT z;IF9qdHw0g2 z^TLkXIe?hY+b_23L2vk~tN6PPm`Yb%HPcE|P1|AU>$h&%eA^c3w#{F=d5dJHNz42aY>|kx zSxW1BrIebMVp?34v}CsyRZK0}Bbhr@CwInByXh4V(%_m$Gd44Mx+iPhAZ4mOP(&)N zC6siZv|H_KVi~fvnd&~%PJlvGttUxyNXxR~%cZm$Q`2fPsv5Qy%e3z0{^fTM`n}=m zYW`l8wMj`w;S(rm4J726vNZ+;MfG$y)xImOB%?c3qe)U`TNu?r&wZxd+O(c%F?A`f zC0R3*hGuGdQd0M+(Tu5f&)L8x3eDL7UVdw!tUGCg?zAPdb>-2?3C|&!&7KvGfpSUR z!I~4KRf90oIhjdx0rFj$M2dC}NZ1vddNIjjvZkJGY{`8yM_=aJThN^ja@wP|R3@qI z$*8lP-z4QEn^-fFdXq#O0T%G(fG7ObYMukDRQDNs;_$;uRx+l}j%Db4xebEgJ(cC& z@Y1FHy?dCUP&{&%W*W0Qz5=6}O)g?O^n7uFZ-EIpJvqlT>YPfb zijhe>6Uls&q$81=C7m&nJp;{&cs=3G4aJ(1%_p~IcglynZg04@mLIZ51f7@Yx2&KM zG$l6#F1Qesu6{}}cS;xQPLo0L#xhTMeQhztsS0u%4Lx_?>S8M2>t*y4FzM5;P@pnR z$*Fvk+860n;xL<}6fF%y7Up343Lw#p-3nw@y7Y|6?3W%g^Q70FoW6DexnqYtpBv$ps7o$zNGK@Z2Mo(`eQ^mQa z(AoVq7;Db3f}KP{Np@#1Q}pU1sS(Da0VDH*GMz%RXX!-cl;4~J z6tvWY(ru4&+822x?%k}juJNYS(kC*cyhyiu61_h~1| zZ8W4_ayk7WGqjZVhF7lShcekQXiqP)0oU?`;;5ZHbQRdf1t)W6vJK{0_>q%u9+X`k z4ti+UK<6~O2fJ(vPOj`v9RG6a;Pd1h2Or#!01`j~NB{{S0VIF~kN^@u0!RP}Ac4z| zfWW(WFZKW3MC8a;vX(4{9o&!r5!)*vRKl1&jKmZz9?8HnX6cP>t*yEAw3^b> z=0;OZCn5>WFyOhu#vOj>Q#W~qBd?H8$h&0pvhQMSAQC_VNB{{S0VIF~kN^@u0!RP} zAc4z*Ku}o1JAYxg%Y|B={(~U|ga$tQcZDk;Eb%%2Q_%bW^?~&q`3LzW*+*_9Vewt@ zNpZh;gIF0l75a6^2wfi%gKq|(2!1cPEx0J~X5h(y2?Dqw0VIF~kN^@u0!RP}Ab~53 zz?yo2%fMZ}*_(X1dwlGUTqg}>1K#_V+1tPJE?DWofE#)n!U8wApkM-)9gyMG<9J+) zsd7v)OeLj(bfiw;loF)7)i|weO>HS!CUnRycf)J-Nw#nZ2zJy6Tw=j`vTKIO@P0Y1 zZB@0v=?e?ljgd13U27rLx1vhmw$CSJs3u*+h1SqqDR9xb^RnBbovbN_9YOw;6#}Sytt5$PT!DpWK%^s76s+BX!b+ev zDikbLNS0|IvOCA^0fE~+pK=Z>cxN5dGbn=9eu3Llf=o==t)}|Am87XfDF@#I9I{pA zX}7w3ftq5P8t=~bq--BRO{=^`P0YrZ70a_skcxPUj+z3gE`{|`TiF85$ySoroF=!x z$@76z_eIr|V`(ddq9$+alUH)HtEkE6O*BzGo=Gq(UcnbF+s>a?HogB}AK1W=K5`SO z7vB?~68puiVpV7)_+;qu(B=>cnZeWKwcz*3qrq^TPX#g{fEyA(0!RP}AOR$R z1dsp{xWWjm6*BzXKlt;0me1k;*KZUC7Zxn$|A(6er4;F`|6f;?E4|A*e-$P{^>jF4B! zU&tTG&xuA9vW7H}YVlL?byyGJA@QJix0r!d0al8h(0@PxHza@rkN^@u0!RP}AOR$R z1dsp{_(BkFC zb>NyhNWGzM1akKuKmN$r{-gTn)R|-Y;X_l`KJv)u`Hx;&Gb%LHLZ-E~z4uQJ&b;`5 zF>~@k9q2jym1TZ7cdP$u?nuk^ zP2BJ0&-{T^);R zYgtrN!=mbH7FAWTsIrnp6%{NZghe9r_W4kV9R-6d3Itf>_p`|7V^MiIi@aVIc|0sC zD`Sz{%_2cyk;~h;%NB{{S0VIF~kN^@u0!RP}AOR%s1t7rQ|0h47&;Oq#|0N%iF>;!mf_DI3 zCCA8L$)CwEc?RAC_#OET`4#yEd63)>?*jY~1aLzFNB{{S0VIF~kN^@u0!RP}AOR$R z1paRVLOm?TFNA4OM}t}#)X<=s230huq(KD@2n|FUglG_?L4XE+8u(~XP6ICuJTxey Pftv;b4O}#!>;L}`MCfTC literal 0 HcmV?d00001 diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 1f55d47..67bf190 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -1,6 +1,5 @@ -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, List, Optional, TYPE_CHECKING -import pydantic import sqlalchemy from pydantic import Field @@ -10,13 +9,6 @@ if TYPE_CHECKING: # pragma no cover from ormar.models import Model -def prepare_validator(type_): - def validate_model_field(value): - return isinstance(value, type_) - - return validate_model_field - - class BaseField: __type__ = None @@ -34,13 +26,7 @@ class BaseField: server_default: Any @classmethod - def is_required(cls) -> bool: - return ( - not cls.nullable and not cls.has_default() and not cls.is_auto_primary_key() - ) - - @classmethod - def default_value(cls): + def default_value(cls) -> Optional[Field]: if cls.is_auto_primary_key(): return Field(default=None) if cls.has_default(): @@ -52,7 +38,7 @@ class BaseField: return None @classmethod - def has_default(cls): + def has_default(cls) -> bool: return cls.default is not None or cls.server_default is not None @classmethod diff --git a/ormar/fields/decorators.py b/ormar/fields/decorators.py deleted file mode 100644 index 842e864..0000000 --- a/ormar/fields/decorators.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Any, TYPE_CHECKING, Type - -from ormar import ModelDefinitionError - -if TYPE_CHECKING: # pragma no cover - from ormar.fields import BaseField - - -class RequiredParams: - def __init__(self, *args: str) -> None: - self._required = list(args) - - def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]: - old_init = model_field_class.__init__ - model_field_class._old_init = old_init - - def __init__(instance: "BaseField", **kwargs: Any) -> None: - super(instance.__class__, instance).__init__(**kwargs) - for arg in self._required: - if arg not in kwargs: - raise ModelDefinitionError( - f"{instance.__class__.__name__} field requires parameter: {arg}" - ) - setattr(instance, arg, kwargs.pop(arg)) - - model_field_class.__init__ = __init__ - return model_field_class diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index ae6745c..679830a 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -1,7 +1,6 @@ -from typing import Any, List, Optional, TYPE_CHECKING, Type, Union, Callable +from typing import Any, Callable, List, Optional, TYPE_CHECKING, Type, Union import sqlalchemy -from pydantic import BaseModel import ormar # noqa I101 from ormar.exceptions import RelationshipInstanceError @@ -13,8 +12,7 @@ if TYPE_CHECKING: # pragma no cover def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": init_dict = { - **{fk.Meta.pkname: pk or -1, - '__pk_only__': True}, + **{fk.Meta.pkname: pk or -1, "__pk_only__": True}, **{ k: create_dummy_instance(v.to) for k, v in fk.Meta.model_fields.items() @@ -24,10 +22,15 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": return fk(**init_dict) -def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = True, - related_name: str = None, - virtual: bool = False, - ) -> Type[object]: +def ForeignKey( + to: "Model", + *, + name: str = None, + unique: bool = False, + nullable: bool = True, + related_name: str = None, + virtual: bool = False, +) -> Type[object]: fk_string = to.Meta.tablename + "." + to.Meta.pkname to_field = to.__fields__[to.Meta.pkname] namespace = dict( @@ -43,7 +46,7 @@ def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = T index=False, pydantic_only=False, default=None, - server_default=None + server_default=None, ) return type("ForeignKey", (ForeignKeyField, BaseField), namespace) @@ -59,21 +62,21 @@ class ForeignKeyField(BaseField): yield cls.validate @classmethod - def validate(cls, v: Any) -> Any: - return v + def validate(cls, value: Any) -> Any: + return value - @property - def __type__(self) -> Type[BaseModel]: - return self.to.__pydantic_model__ + # @property + # def __type__(self) -> Type[BaseModel]: + # return self.to.__pydantic_model__ - @classmethod - def get_column_type(cls) -> sqlalchemy.Column: - to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname] - return to_column.column_type + # @classmethod + # def get_column_type(cls) -> sqlalchemy.Column: + # to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname] + # return to_column.column_type @classmethod def _extract_model_from_sequence( - cls, value: List, child: "Model" + cls, value: List, child: "Model" ) -> Union["Model", List["Model"]]: return [cls.expand_relationship(val, child) for val in value] @@ -109,7 +112,7 @@ class ForeignKeyField(BaseField): @classmethod def expand_relationship( - cls, value: Any, child: "Model" + cls, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: if value is None: return None diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index d790d65..d3f41cf 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -1,41 +1,43 @@ import datetime import decimal import re -from typing import Type, Any, Optional +from typing import Any, Optional, Type import pydantic import sqlalchemy from pydantic import Json -from ormar import ModelDefinitionError +from ormar import ModelDefinitionError # noqa I101 from ormar.fields.base import BaseField # noqa I101 -def is_field_nullable(nullable: Optional[bool], default: Any, server_default: Any) -> bool: +def is_field_nullable( + nullable: Optional[bool], default: Any, server_default: Any +) -> bool: if nullable is None: return default is not None or server_default is not None return nullable def String( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - allow_blank: bool = False, - strip_whitespace: bool = False, - min_length: int = None, - max_length: int = None, - curtail_length: int = None, - regex: str = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + allow_blank: bool = False, + strip_whitespace: bool = False, + min_length: int = None, + max_length: int = None, + curtail_length: int = None, + regex: str = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[str]: if max_length is None or max_length <= 0: - raise ModelDefinitionError(f'Parameter max_length is required for field String') + raise ModelDefinitionError("Parameter max_length is required for field String") namespace = dict( __type__=str, @@ -54,26 +56,26 @@ def String( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("String", (pydantic.ConstrainedStr, BaseField), namespace) def Integer( - *, - name: str = None, - primary_key: bool = False, - autoincrement: bool = None, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: int = None, - maximum: int = None, - multiple_of: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + autoincrement: bool = None, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[int]: namespace = dict( __type__=int, @@ -89,23 +91,23 @@ def Integer( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=autoincrement if autoincrement is not None else primary_key + autoincrement=autoincrement if autoincrement is not None else primary_key, ) return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace) def Text( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - allow_blank: bool = False, - strip_whitespace: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + allow_blank: bool = False, + strip_whitespace: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[str]: namespace = dict( __type__=str, @@ -120,25 +122,25 @@ def Text( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("Text", (pydantic.ConstrainedStr, BaseField), namespace) def Float( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: float = None, - maximum: float = None, - multiple_of: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[int]: namespace = dict( __type__=float, @@ -154,21 +156,21 @@ def Float( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace) def Boolean( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[bool]: namespace = dict( __type__=bool, @@ -181,21 +183,21 @@ def Boolean( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("Boolean", (int, BaseField), namespace) def DateTime( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[datetime.datetime]: namespace = dict( __type__=datetime.datetime, @@ -208,21 +210,21 @@ def DateTime( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("DateTime", (datetime.datetime, BaseField), namespace) def Date( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[datetime.date]: namespace = dict( __type__=datetime.date, @@ -235,21 +237,21 @@ def Date( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("Date", (datetime.date, BaseField), namespace) def Time( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[datetime.time]: namespace = dict( __type__=datetime.time, @@ -262,21 +264,21 @@ def Time( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("Time", (datetime.time, BaseField), namespace) def JSON( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[Json]: namespace = dict( __type__=pydantic.Json, @@ -289,26 +291,26 @@ def JSON( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("JSON", (pydantic.Json, BaseField), namespace) def BigInteger( - *, - name: str = None, - primary_key: bool = False, - autoincrement: bool = None, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: int = None, - maximum: int = None, - multiple_of: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None + *, + name: str = None, + primary_key: bool = False, + autoincrement: bool = None, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, ) -> Type[int]: namespace = dict( __type__=int, @@ -324,31 +326,33 @@ def BigInteger( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=autoincrement if autoincrement is not None else primary_key + autoincrement=autoincrement if autoincrement is not None else primary_key, ) return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace) def Decimal( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: float = None, - maximum: float = None, - multiple_of: int = None, - precision: int = None, - scale: int = None, - max_digits: int = None, - decimal_places: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None -): + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + precision: int = None, + scale: int = None, + max_digits: int = None, + decimal_places: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None, +) -> Type[decimal.Decimal]: if precision is None or precision < 0 or scale is None or scale < 0: - raise ModelDefinitionError(f'Parameters scale and precision are required for field Decimal') + raise ModelDefinitionError( + "Parameters scale and precision are required for field Decimal" + ) namespace = dict( __type__=decimal.Decimal, @@ -368,6 +372,6 @@ def Decimal( pydantic_only=pydantic_only, default=default, server_default=server_default, - autoincrement=False + autoincrement=False, ) return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace) diff --git a/ormar/models/fakepydantic.py b/ormar/models/fakepydantic.py index c44309d..d6958fd 100644 --- a/ormar/models/fakepydantic.py +++ b/ormar/models/fakepydantic.py @@ -2,16 +2,17 @@ import inspect import json import uuid from typing import ( + AbstractSet, Any, - Callable, Dict, List, + Mapping, Optional, Set, TYPE_CHECKING, Type, TypeVar, - Union, AbstractSet, Mapping, + Union, ) import databases @@ -20,10 +21,9 @@ import sqlalchemy from pydantic import BaseModel import ormar # noqa I100 -from ormar import ForeignKey from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField -from ormar.models.metaclass import ModelMetaclass, ModelMeta +from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.relations import RelationshipManager if TYPE_CHECKING: # pragma no cover @@ -39,7 +39,8 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): # FakePydantic inherits from list in order to be treated as # request.Body parameter in fastapi routes, # inheriting from pydantic.BaseModel causes metaclass conflicts - __slots__ = ('_orm_id', '_orm_saved') + __slots__ = ("_orm_id", "_orm_saved") + __abstract__ = True if TYPE_CHECKING: # pragma no cover __model_fields__: Dict[str, TypeVar[BaseField]] @@ -63,18 +64,18 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): if "pk" in kwargs: kwargs[self.Meta.pkname] = kwargs.pop("pk") kwargs = { - k: self.Meta.model_fields[k].expand_relationship(v, self) + k: self._convert_json( + k, self.Meta.model_fields[k].expand_relationship(v, self), "dumps" + ) for k, v in kwargs.items() } - values, fields_set, validation_error = pydantic.validate_model( - self, kwargs - ) + values, fields_set, validation_error = pydantic.validate_model(self, kwargs) if validation_error and not pk_only: raise validation_error - object.__setattr__(self, '__dict__', values) - object.__setattr__(self, '__fields_set__', fields_set) + object.__setattr__(self, "__dict__", values) + object.__setattr__(self, "__fields_set__", fields_set) # super().__init__(**kwargs) # self.values = self.__pydantic_model__(**kwargs) @@ -82,58 +83,50 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): def __del__(self) -> None: self.Meta._orm_relationship_manager.deregister(self) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: relation_key = self.get_name(title=True) + "_" + name if name in self.__slots__: object.__setattr__(self, name, value) - elif name == 'pk': - object.__setattr__(self, self.Meta.pkname, value) + elif name == "pk": + object.__setattr__(self, self.Meta.pkname, value) elif self.Meta._orm_relationship_manager.contains(relation_key, self): self.Meta.model_fields[name].expand_relationship(value, self) else: - super().__setattr__(name, value) + value = ( + self._convert_json(name, value, "dumps") + if name in self.__fields__ + else value + ) + super().__setattr__(name, value) - def __getattr__(self, item): + def __getattribute__(self, item: str) -> Any: + if item != "__fields__" and item in self.__fields__: + related = self._extract_related_model_instead_of_field(item) + if related: + return related + value = object.__getattribute__(self, item) + value = self._convert_json(item, value, "loads") + return value + return super().__getattribute__(item) + + def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]: + return self._extract_related_model_instead_of_field(item) + + def _extract_related_model_instead_of_field( + self, item: str + ) -> Optional[Union["Model", List["Model"]]]: relation_key = self.get_name(title=True) + "_" + item if self.Meta._orm_relationship_manager.contains(relation_key, self): return self.Meta._orm_relationship_manager.get(relation_key, self) - # def __setattr__(self, key: str, value: Any) -> None: - # if key in ('_orm_id', '_orm_relationship_manager', '_orm_saved', 'objects', '__model_fields__'): - # return setattr(self, key, value) - # # elif key in self._extract_related_names(): - # # value = self._convert_json(key, value, op="dumps") - # # value = self.Meta.model_fields[key].expand_relationship(value, self) - # # relation_key = self.get_name(title=True) + "_" + key - # # if not self.Meta._orm_relationship_manager.contains(relation_key, self): - # # setattr(self.values, key, value) - # else: - # super().__setattr__(key, value) - - # def __getattribute__(self, key: str) -> Any: - # if key != 'Meta' and key in self.Meta.model_fields: - # relation_key = self.get_name(title=True) + "_" + key - # if self.Meta._orm_relationship_manager.contains(relation_key, self): - # return self.Meta._orm_relationship_manager.get(relation_key, self) - # item = getattr(self.__fields__, key, None) - # item = self._convert_json(key, item, op="loads") - # return item - # return super().__getattribute__(key) - def __same__(self, other: "Model") -> bool: if self.__class__ != other.__class__: # pragma no cover return False - return (self._orm_id == other._orm_id or - self.__dict__ == other.__dict__ or - (self.pk == other.pk and self.pk is not None - )) - - # def __repr__(self) -> str: # pragma no cover - # return self.values.__repr__() - - # @classmethod - # def __get_validators__(cls) -> Callable: # pragma no cover - # yield cls.__pydantic_model__.validate + return ( + self._orm_id == other._orm_id + or self.__dict__ == other.__dict__ + or (self.pk == other.pk and self.pk is not None) + ) @classmethod def get_name(cls, title: bool = False, lower: bool = True) -> str: @@ -148,10 +141,6 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): def pk(self) -> Any: return getattr(self, self.Meta.pkname) - @pk.setter - def pk(self, value: Any) -> None: - setattr(self, self.Meta.pkname, value) - @property def pk_column(self) -> sqlalchemy.Column: return self.Meta.table.primary_key.columns.values()[0] @@ -160,36 +149,38 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): def pk_type(cls) -> Any: return cls.Meta.model_fields[cls.Meta.pkname].__type__ - def dict( - self, - *, - include: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, - exclude: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, - by_alias: bool = False, - skip_defaults: bool = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - nested: bool = False - ) -> 'DictStrAny': # noqa: A003' - dict_instance = super().dict(include=include, - exclude=self._exclude_related_names_not_required(nested), - by_alias=by_alias, - skip_defaults=skip_defaults, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none) + def dict( # noqa A003 + self, + *, + include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + nested: bool = False + ) -> "DictStrAny": # noqa: A003' + dict_instance = super().dict( + include=include, + exclude=self._exclude_related_names_not_required(nested), + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) for field in self._extract_related_names(): nested_model = getattr(self, field) if self.Meta.model_fields[field].virtual and nested: continue if isinstance(nested_model, list) and not isinstance( - nested_model, ormar.Model + nested_model, ormar.Model ): dict_instance[field] = [x.dict(nested=True) for x in nested_model] elif nested_model is not None: - dict_instance[field] = nested_model.dict(nested=True) + dict_instance[field] = nested_model.dict(nested=True) return dict_instance def from_dict(self, value_dict: Dict) -> None: @@ -225,19 +216,21 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): def _extract_related_names(cls) -> Set: related_names = set() for name, field in cls.Meta.model_fields.items(): - if inspect.isclass(field) and issubclass( - field, ForeignKeyField - ): + if inspect.isclass(field) and issubclass(field, ForeignKeyField): related_names.add(name) return related_names @classmethod - def _exclude_related_names_not_required(cls, nested:bool=False) -> Set: + def _exclude_related_names_not_required(cls, nested: bool = False) -> Set: if nested: return cls._extract_related_names() related_names = set() for name, field in cls.Meta.model_fields.items(): - if inspect.isclass(field) and issubclass(field, ForeignKeyField) and field.nullable: + if ( + inspect.isclass(field) + and issubclass(field, ForeignKeyField) + and field.nullable + ): related_names.add(name) return related_names @@ -267,7 +260,7 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": for field in one.Meta.model_fields.keys(): if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), ormar.Model + getattr(one, field), ormar.Model ): setattr(other, field, getattr(one, field) + getattr(other, field)) elif isinstance(getattr(one, field), ormar.Model): diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 56d66d9..0d621bd 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union import databases import pydantic import sqlalchemy -from pydantic import BaseConfig, create_model, Extra -from pydantic.fields import ModelField, FieldInfo +from pydantic import BaseConfig +from pydantic.fields import FieldInfo from ormar import ForeignKey, ModelDefinitionError # noqa I100 from ormar.fields import BaseField @@ -29,23 +29,11 @@ class ModelMeta: _orm_relationship_manager: RelationshipManager -def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: - pydantic_fields = { - field_name: ( - base_field.__type__, - ... if base_field.is_required else base_field.default_value, - ) - for field_name, base_field in object_dict.items() - if isinstance(base_field, BaseField) - } - return pydantic_fields - - def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: child_relation_name = ( - field.to.get_name(title=True) - + "_" - + (field.related_name or (name.lower() + "s")) + field.to.get_name(title=True) + + "_" + + (field.related_name or (name.lower() + "s")) ) reverse_name = child_relation_name relation_name = name.lower().title() + "_" + field.to.get_name() @@ -61,34 +49,28 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ): register_reverse_model_fields(parent_model, child, child_model_name) def register_reverse_model_fields( - model: Type["Model"], child: Type["Model"], child_model_name: str + model: Type["Model"], child: Type["Model"], child_model_name: str ) -> None: - # model.__fields__[child_model_name] = ModelField( - # name=child_model_name, - # type_=Optional[Union[List[child], child]], - # model_config=child.__config__, - # class_validators=child.__validators__, - # ) model.Meta.model_fields[child_model_name] = ForeignKey( child, name=child_model_name, virtual=True ) def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str + name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: columns = [] pkname = None model_fields = { field_name: field - for field_name, field in object_dict['__annotations__'].items() + for field_name, field in object_dict["__annotations__"].items() if issubclass(field, BaseField) } for field_name, field in model_fields.items(): @@ -96,7 +78,7 @@ def sqlalchemy_columns_from_model_fields( if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") if field.pydantic_only: - raise ModelDefinitionError('Primary key column cannot be pydantic only') + raise ModelDefinitionError("Primary key column cannot be pydantic only") pkname = field_name if not field.pydantic_only: columns.append(field.get_column(field_name)) @@ -112,10 +94,10 @@ def populate_pydantic_default_values(attrs: Dict) -> Dict: if type_.name is None: type_.name = field def_value = type_.default_value() - curr_def_value = attrs.get(field, 'NONE') - if curr_def_value == 'NONE' and isinstance(def_value, FieldInfo): + curr_def_value = attrs.get(field, "NONE") + if curr_def_value == "NONE" and isinstance(def_value, FieldInfo): attrs[field] = def_value - elif curr_def_value == 'NONE' and type_.nullable: + elif curr_def_value == "NONE" and type_.nullable: attrs[field] = FieldInfo(default=None) return attrs @@ -132,51 +114,44 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]: class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: - attrs['Config'] = get_pydantic_base_orm_config() + attrs["Config"] = get_pydantic_base_orm_config() new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) - if hasattr(new_model, 'Meta'): + if hasattr(new_model, "Meta"): - if attrs.get("__abstract__"): - return new_model - annotations = attrs.get("__annotations__") or new_model.__annotations__ - attrs["__annotations__"]= annotations + attrs["__annotations__"] = annotations attrs = populate_pydantic_default_values(attrs) tablename = name.lower() + "s" new_model.Meta.tablename = new_model.Meta.tablename or tablename - + # sqlalchemy table creation - + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( - name, attrs, new_model.Meta.tablename + name, attrs, new_model.Meta.tablename ) - + if hasattr(new_model.Meta, "model_fields") and not pkname: - model_fields = new_model.Meta.model_fields - for fieldname, field in new_model.Meta.model_fields.items(): - if field.primary_key: - pkname=fieldname - columns = new_model.Meta.table.columns - + model_fields = new_model.Meta.model_fields + for fieldname, field in new_model.Meta.model_fields.items(): + if field.primary_key: + pkname = fieldname + columns = new_model.Meta.table.columns + if not hasattr(new_model.Meta, "table"): - new_model.Meta.table = sqlalchemy.Table(new_model.Meta.tablename, new_model.Meta.metadata, *columns) - + new_model.Meta.table = sqlalchemy.Table( + new_model.Meta.tablename, new_model.Meta.metadata, *columns + ) + new_model.Meta.columns = columns new_model.Meta.pkname = pkname if not pkname: raise ModelDefinitionError("Table has to have a primary key.") - # pydantic model creation - new_model.Meta.pydantic_fields = parse_pydantic_field_from_model_fields(attrs) - new_model.Meta.pydantic_model = create_model( - name, __config__=get_pydantic_base_orm_config(), **new_model.Meta.pydantic_fields - ) - new_model.Meta.model_fields = model_fields new_model = super().__new__( # type: ignore mcs, name, bases, attrs diff --git a/ormar/models/model.py b/ormar/models/model.py index 4651aa1..5cb9ed0 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -13,10 +13,10 @@ class Model(FakePydantic): @classmethod def from_row( - cls, - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - previous_table: str = None, + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + previous_table: str = None, ) -> "Model": item = {} @@ -66,8 +66,8 @@ class Model(FakePydantic): self_fields.pop(self.Meta.pkname) expr = ( self.Meta.table.update() - .values(**self_fields) - .where(self.pk_column == getattr(self, self.Meta.pkname)) + .values(**self_fields) + .where(self.pk_column == getattr(self, self.Meta.pkname)) ) result = await self.Meta.database.execute(expr) return result diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 30b739e..55e936a 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -20,12 +20,12 @@ class JoinParameters(NamedTuple): class Query: def __init__( - self, - model_cls: Type["Model"], - filter_clauses: List, - select_related: List, - limit_count: int, - offset: int, + self, + model_cls: Type["Model"], + filter_clauses: List, + select_related: List, + limit_count: int, + offset: int, ) -> None: self.query_offset = offset @@ -49,15 +49,15 @@ class Query: self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")] self.select_from = self.table - for key in self.model_cls.Meta.model_fields: - if ( - not self.model_cls.Meta.model_fields[key].nullable - and isinstance( - self.model_cls.Meta.model_fields[key], ForeignKeyField, - ) - and key not in self._select_related - ): - self._select_related = [key] + self._select_related + # for key in self.model_cls.Meta.model_fields: + # if ( + # not self.model_cls.Meta.model_fields[key].nullable + # and isinstance( + # self.model_cls.Meta.model_fields[key], ForeignKeyField, + # ) + # and key not in self._select_related + # ): + # self._select_related = [key] + self._select_related start_params = JoinParameters( self.model_cls, "", self.table.name, self.model_cls @@ -79,7 +79,7 @@ class Query: expr = self._apply_expression_modifiers(expr) - #print(expr.compile(compile_kwargs={"literal_binds": True})) + # print(expr.compile(compile_kwargs={"literal_binds": True})) self._reset_query_parameters() return expr, self._select_related @@ -97,12 +97,12 @@ class Query: @staticmethod def _field_is_a_foreign_key_and_no_circular_reference( - field: BaseField, field_name: str, rel_part: str + field: Type[BaseField], field_name: str, rel_part: str ) -> bool: return issubclass(field, ForeignKeyField) and field_name not in rel_part def _field_qualifies_to_deeper_search( - self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str + self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str ) -> bool: prev_part_of_related = "__".join(rel_part.split("__")[:-1]) partial_match = any( @@ -112,19 +112,19 @@ class Query: [x.startswith(rel_part) for x in (self.auto_related + self.already_checked)] ) return ( - (field.virtual and parent_virtual) - or (partial_match and not already_checked) - ) or not nested + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested def on_clause( - self, previous_alias: str, alias: str, from_clause: str, to_clause: str, + self, previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: left_part = f"{alias}_{to_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" return text(f"{left_part}={right_part}") def _build_join_parameters( - self, part: str, join_params: JoinParameters + self, part: str, join_params: JoinParameters ) -> JoinParameters: model_cls = join_params.model_cls.Meta.model_fields[part].to to_table = model_cls.Meta.table.name @@ -138,7 +138,8 @@ class Query: ( v for k, v in model_cls.Meta.model_fields.items() - if issubclass(v, ForeignKeyField) and v.to == join_params.prev_model + if issubclass(v, ForeignKeyField) + and v.to == join_params.prev_model ), None, ).name @@ -167,27 +168,30 @@ class Query: return JoinParameters(prev_model, previous_alias, from_table, model_cls) def _extract_auto_required_relations( - self, - prev_model: Type["Model"], - rel_part: str = "", - nested: bool = False, - parent_virtual: bool = False, + self, + prev_model: Type["Model"], + rel_part: str = "", + nested: bool = False, + parent_virtual: bool = False, ) -> None: for field_name, field in prev_model.Meta.model_fields.items(): if self._field_is_a_foreign_key_and_no_circular_reference( - field, field_name, rel_part + field, field_name, rel_part ): rel_part = field_name if not rel_part else rel_part + "__" + field_name if not field.nullable: if rel_part not in self._select_related: - new_related = "__".join(rel_part.split("__")[:-1]) if len( - rel_part.split("__")) > 1 else rel_part + new_related = ( + "__".join(rel_part.split("__")[:-1]) + if len(rel_part.split("__")) > 1 + else rel_part + ) self.auto_related.append(new_related) rel_part = "" elif self._field_qualifies_to_deeper_search( - field, parent_virtual, nested, rel_part + field, parent_virtual, nested, rel_part ): - + self._extract_auto_required_relations( prev_model=field.to, rel_part=rel_part, @@ -207,7 +211,7 @@ class Query: self._select_related = new_joins + self.auto_related def _apply_expression_modifiers( - self, expr: sqlalchemy.sql.select + self, expr: sqlalchemy.sql.select ) -> sqlalchemy.sql.select: if self.filter_clauses: if len(self.filter_clauses) == 1: diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 4e5d85e..56d1ee2 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -14,12 +14,12 @@ if TYPE_CHECKING: # pragma no cover class QuerySet: def __init__( - self, - model_cls: Type["Model"] = None, - filter_clauses: List = None, - select_related: List = None, - limit_count: int = None, - offset: int = None, + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses @@ -151,9 +151,9 @@ class QuerySet: pkname = self.model_cls.Meta.pkname pk = self.model_cls.Meta.model_fields[pkname] if ( - pkname in new_kwargs - and new_kwargs.get(pkname) is None - and (pk.nullable or pk.autoincrement) + pkname in new_kwargs + and new_kwargs.get(pkname) is None + and (pk.nullable or pk.autoincrement) ): del new_kwargs[pkname] diff --git a/ormar/relations.py b/ormar/relations.py index 8a422d9..008e566 100644 --- a/ormar/relations.py +++ b/ormar/relations.py @@ -5,7 +5,6 @@ from random import choices from typing import List, TYPE_CHECKING, Union from weakref import proxy -from ormar import ForeignKey from ormar.fields.foreign_key import ForeignKeyField if TYPE_CHECKING: # pragma no cover @@ -22,7 +21,11 @@ class RelationshipManager: self._aliases = dict() def add_relation_type( - self, relations_key: str, reverse_key: str, field: ForeignKeyField, table_name: str + self, + relations_key: str, + reverse_key: str, + field: ForeignKeyField, + table_name: str, ) -> None: if relations_key not in self._relations: self._relations[relations_key] = {"type": "primary"} diff --git a/scripts/clean.sh b/scripts/clean.sh old mode 100644 new mode 100755 diff --git a/scripts/publish.sh b/scripts/publish.sh old mode 100644 new mode 100755 diff --git a/scripts/test.sh b/scripts/test.sh old mode 100644 new mode 100755 diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 17e4a52..f1bd7fa 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -5,6 +5,7 @@ from pydantic import ValidationError import ormar from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError +from ormar.fields.foreign_key import ForeignKeyField from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -120,6 +121,7 @@ async def test_model_crud(): track = await Track.objects.get(title="The Bird") assert track.album.pk == album.pk + assert isinstance(track.album, ormar.Model) assert track.album.name is None await track.album.load() assert track.album.name == "Malibu" diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index 3cfd56c..ebb0619 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -75,9 +75,18 @@ def test_model_attribute_access(example): example.test = 12 assert example.test == 12 + example._orm_saved = True + assert example._orm_saved + + +def test_model_attribute_json_access(example): + example.test_json = dict(aa=12) + assert example.test_json == dict(aa=12) + + def test_non_existing_attr(example): - with pytest.raises(ValueError): - example.new_attr=12 + with pytest.raises(ValueError): + example.new_attr = 12 def test_primary_key_access_and_setting(example): diff --git a/tests/test_models.py b/tests/test_models.py index 42c6b54..f21e70c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -11,6 +11,16 @@ database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() +class JsonSample(ormar.Model): + class Meta: + tablename = "jsons" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + test_json: ormar.JSON(nullable=True) + + class User(ormar.Model): class Meta: tablename = "users" @@ -56,6 +66,18 @@ def test_model_pk(): assert user.id == 1 +@pytest.mark.asyncio +async def test_json_column(): + async with database: + await JsonSample.objects.create(test_json=dict(aa=12)) + await JsonSample.objects.create(test_json='{"aa": 12}') + + items = await JsonSample.objects.all() + assert len(items) == 2 + assert items[0].test_json == dict(aa=12) + assert items[1].test_json == dict(aa=12) + + @pytest.mark.asyncio async def test_model_crud(): async with database: From 806fe9b63ef3a53861d4cf5ac96d7092befce081 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 23 Aug 2020 13:15:04 +0200 Subject: [PATCH 4/5] fix setting new related model just from dict with pkname --- .coverage | Bin 53248 -> 53248 bytes ormar/fields/foreign_key.py | 2 ++ tests/test_foreign_keys.py | 2 ++ tests/test_more_reallife_fastapi.py | 6 +++--- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.coverage b/.coverage index f3e5fafb45eb490917ec4b248da6544fb15b365f..ed0ddb4c4d3c00f69f8658523e319e8874bb24e6 100644 GIT binary patch delta 60 zcmV-C0K@-)paX!Q1F$hK1~WP_G&(Y~H!mq)75Tq^n}5A?=ic7)zdPr@lfOG}pAWNj Sjy(Yee;@bxJrA?Ek4QiUE*l5{ delta 60 zcmV-C0K@-)paX!Q1F$hK1~NJ|G&(c0H!mq)75?wv=3npJxwrTH@6P$} "Model": + if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname: + value["__pk_only__"] = True model = cls.to(**value) cls.register_relation(model, child) return model diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index f1bd7fa..6463eb6 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -133,6 +133,8 @@ async def test_model_crud(): assert album1.pk == 1 assert album1.tracks is None + await Track.objects.create(album={"id": track.album.pk}, title="The Bird2", position=4) + @pytest.mark.asyncio async def test_select_related(): diff --git a/tests/test_more_reallife_fastapi.py b/tests/test_more_reallife_fastapi.py index 02233e6..31e31b9 100644 --- a/tests/test_more_reallife_fastapi.py +++ b/tests/test_more_reallife_fastapi.py @@ -61,13 +61,13 @@ def create_test_database(): @app.get("/items/", response_model=List[Item]) async def get_items(): items = await Item.objects.select_related("category").all() - return [item.dict() for item in items] + return items @app.post("/items/", response_model=Item) async def create_item(item: Item): - item = await Item.objects.create(**item.dict()) - return item.dict() + await item.save() + return item @app.post("/categories/", response_model=Category) From 348a3d90dc0a63cdf245824f127117e91368624d Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 23 Aug 2020 16:14:04 +0200 Subject: [PATCH 5/5] refactor fields into classes --- .coverage | Bin 53248 -> 53248 bytes ormar/fields/foreign_key.py | 2 +- ormar/fields/model_fields.py | 598 ++++++++++++++------------------- ormar/models/fakepydantic.py | 30 +- tests/test_same_table_joins.py | 1 + 5 files changed, 262 insertions(+), 369 deletions(-) diff --git a/.coverage b/.coverage index ed0ddb4c4d3c00f69f8658523e319e8874bb24e6..a7dc6e2b55b75b05ef7249b6be929670c1ea005e 100644 GIT binary patch delta 154 zcmV;L0A>GxpaX!Q1F$hK2R1q}Fgi0Zvp6r#P%`EK5BU%358e;c56BO=514ml1K4gL-0vk?%k4U^c75(e`vc)reev+#{<0XBWV z{oV@(2?PNN4ha?~2m}EMPzY}GbG!eqIhX&R?#JbCKb`we=kk9#ckbM|bLY|NrOz{{pk1k6=ImRZn{W diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 2448fde..9d052a7 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -23,7 +23,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": def ForeignKey( - to: "Model", + to: Type["Model"], *, name: str = None, unique: bool = False, diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index d3f41cf..dd81838 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -1,11 +1,9 @@ import datetime import decimal -import re from typing import Any, Optional, Type import pydantic import sqlalchemy -from pydantic import Json from ormar import ModelDefinitionError # noqa I101 from ormar.fields.base import BaseField # noqa I101 @@ -19,359 +17,253 @@ def is_field_nullable( return nullable -def String( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - allow_blank: bool = False, - strip_whitespace: bool = False, - min_length: int = None, - max_length: int = None, - curtail_length: int = None, - regex: str = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[str]: - if max_length is None or max_length <= 0: - raise ModelDefinitionError("Parameter max_length is required for field String") +class ModelFieldFactory: + _bases = None + _type = None - namespace = dict( - __type__=str, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - allow_blank=allow_blank, - strip_whitespace=strip_whitespace, - min_length=min_length, - max_length=max_length, - curtail_length=curtail_length, - regex=regex and re.compile(regex), - column_type=sqlalchemy.String(length=max_length), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) + def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: + cls.validate(**kwargs) - return type("String", (pydantic.ConstrainedStr, BaseField), namespace) + default = kwargs.pop("default", None) + server_default = kwargs.pop("server_default", None) + nullable = kwargs.pop("nullable", None) - -def Integer( - *, - name: str = None, - primary_key: bool = False, - autoincrement: bool = None, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: int = None, - maximum: int = None, - multiple_of: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[int]: - namespace = dict( - __type__=int, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - ge=minimum, - le=maximum, - multiple_of=multiple_of, - column_type=sqlalchemy.Integer(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=autoincrement if autoincrement is not None else primary_key, - ) - return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace) - - -def Text( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - allow_blank: bool = False, - strip_whitespace: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[str]: - namespace = dict( - __type__=str, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - allow_blank=allow_blank, - strip_whitespace=strip_whitespace, - column_type=sqlalchemy.Text(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - - return type("Text", (pydantic.ConstrainedStr, BaseField), namespace) - - -def Float( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: float = None, - maximum: float = None, - multiple_of: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[int]: - namespace = dict( - __type__=float, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - ge=minimum, - le=maximum, - multiple_of=multiple_of, - column_type=sqlalchemy.Float(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace) - - -def Boolean( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[bool]: - namespace = dict( - __type__=bool, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - column_type=sqlalchemy.Boolean(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - return type("Boolean", (int, BaseField), namespace) - - -def DateTime( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[datetime.datetime]: - namespace = dict( - __type__=datetime.datetime, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - column_type=sqlalchemy.DateTime(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - return type("DateTime", (datetime.datetime, BaseField), namespace) - - -def Date( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[datetime.date]: - namespace = dict( - __type__=datetime.date, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - column_type=sqlalchemy.Date(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - return type("Date", (datetime.date, BaseField), namespace) - - -def Time( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[datetime.time]: - namespace = dict( - __type__=datetime.time, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - column_type=sqlalchemy.Time(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - return type("Time", (datetime.time, BaseField), namespace) - - -def JSON( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[Json]: - namespace = dict( - __type__=pydantic.Json, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - column_type=sqlalchemy.JSON(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - - return type("JSON", (pydantic.Json, BaseField), namespace) - - -def BigInteger( - *, - name: str = None, - primary_key: bool = False, - autoincrement: bool = None, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: int = None, - maximum: int = None, - multiple_of: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[int]: - namespace = dict( - __type__=int, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - ge=minimum, - le=maximum, - multiple_of=multiple_of, - column_type=sqlalchemy.BigInteger(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=autoincrement if autoincrement is not None else primary_key, - ) - return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace) - - -def Decimal( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: float = None, - maximum: float = None, - multiple_of: int = None, - precision: int = None, - scale: int = None, - max_digits: int = None, - decimal_places: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[decimal.Decimal]: - if precision is None or precision < 0 or scale is None or scale < 0: - raise ModelDefinitionError( - "Parameters scale and precision are required for field Decimal" + namespace = dict( + __type__=cls._type, + name=kwargs.pop("name", None), + primary_key=kwargs.pop("primary_key", False), + default=default, + server_default=server_default, + nullable=is_field_nullable(nullable, default, server_default), + index=kwargs.pop("index", False), + unique=kwargs.pop("unique", False), + pydantic_only=kwargs.pop("pydantic_only", False), + autoincrement=kwargs.pop("autoincrement", False), + column_type=cls.get_column_type(**kwargs), + **kwargs ) + return type(cls.__name__, cls._bases, namespace) - namespace = dict( - __type__=decimal.Decimal, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - ge=minimum, - le=maximum, - multiple_of=multiple_of, - column_type=sqlalchemy.types.DECIMAL(precision=precision, scale=scale), - precision=precision, - scale=scale, - max_digits=max_digits, - decimal_places=decimal_places, - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace) + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: # pragma no cover + return None + + @classmethod + def validate(cls, **kwargs: Any) -> None: # pragma no cover + pass + + +class String(ModelFieldFactory): + _bases = (pydantic.ConstrainedStr, BaseField) + _type = str + + def __new__( + cls, + *, + allow_blank: bool = False, + strip_whitespace: bool = False, + min_length: int = None, + max_length: int = None, + curtail_length: int = None, + regex: str = None, + **kwargs: Any + ) -> Type[str]: + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.String(length=kwargs.get("max_length")) + + @classmethod + def validate(cls, **kwargs: Any) -> None: + max_length = kwargs.get("max_length", None) + if max_length is None or max_length <= 0: + raise ModelDefinitionError( + "Parameter max_length is required for field String" + ) + + +class Integer(ModelFieldFactory): + _bases = (pydantic.ConstrainedInt, BaseField) + _type = int + + def __new__( + cls, + *, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + **kwargs: Any + ) -> Type[int]: + autoincrement = kwargs.pop("autoincrement", None) + autoincrement = ( + autoincrement + if autoincrement is not None + else kwargs.get("primary_key", False) + ) + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.Integer() + + +class Text(ModelFieldFactory): + _bases = (pydantic.ConstrainedStr, BaseField) + _type = str + + def __new__( + cls, *, allow_blank: bool = False, strip_whitespace: bool = False, **kwargs: Any + ) -> Type[str]: + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.Text() + + +class Float(ModelFieldFactory): + _bases = (pydantic.ConstrainedFloat, BaseField) + _type = float + + def __new__( + cls, + *, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + **kwargs: Any + ) -> Type[int]: + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.Float() + + +class Boolean(ModelFieldFactory): + _bases = (int, BaseField) + _type = bool + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.Boolean() + + +class DateTime(ModelFieldFactory): + _bases = (datetime.datetime, BaseField) + _type = datetime.datetime + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.DateTime() + + +class Date(ModelFieldFactory): + _bases = (datetime.date, BaseField) + _type = datetime.date + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.Date() + + +class Time(ModelFieldFactory): + _bases = (datetime.time, BaseField) + _type = datetime.time + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.Time() + + +class JSON(ModelFieldFactory): + _bases = (pydantic.Json, BaseField) + _type = pydantic.Json + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.JSON() + + +class BigInteger(Integer): + _bases = (pydantic.ConstrainedInt, BaseField) + _type = int + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.BigInteger() + + +class Decimal(ModelFieldFactory): + _bases = (pydantic.ConstrainedDecimal, BaseField) + _type = decimal.Decimal + + def __new__( + cls, + *, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + precision: int = None, + scale: int = None, + max_digits: int = None, + decimal_places: int = None, + **kwargs: Any + ) -> Type[decimal.Decimal]: + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + precision = kwargs.get("precision") + scale = kwargs.get("scale") + return sqlalchemy.DECIMAL(precision=precision, scale=scale) + + @classmethod + def validate(cls, **kwargs: Any) -> None: + precision = kwargs.get("precision") + scale = kwargs.get("scale") + if precision is None or precision < 0 or scale is None or scale < 0: + raise ModelDefinitionError( + "Parameters scale and precision are required for field Decimal" + ) diff --git a/ormar/models/fakepydantic.py b/ormar/models/fakepydantic.py index d6958fd..3b74c9d 100644 --- a/ormar/models/fakepydantic.py +++ b/ormar/models/fakepydantic.py @@ -240,10 +240,9 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): k: v for k, v in self_fields.items() if k in self.Meta.table.columns } for field in self._extract_related_names(): + target_pk_name = self.Meta.model_fields[field].to.Meta.pkname if getattr(self, field) is not None: - self_fields[field] = getattr( - getattr(self, field), self.Meta.model_fields[field].to.Meta.pkname - ) + self_fields[field] = getattr(getattr(self, field), target_pk_name) return self_fields @classmethod @@ -259,17 +258,18 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): @classmethod def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": for field in one.Meta.model_fields.keys(): - if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), ormar.Model + current_field = getattr(one, field) + if isinstance(current_field, list) and not isinstance( + current_field, ormar.Model ): - setattr(other, field, getattr(one, field) + getattr(other, field)) - elif isinstance(getattr(one, field), ormar.Model): - if getattr(one, field).pk == getattr(other, field).pk: - setattr( - other, - field, - cls.merge_two_instances( - getattr(one, field), getattr(other, field) - ), - ) + setattr(other, field, current_field + getattr(other, field)) + elif ( + isinstance(current_field, ormar.Model) + and current_field.pk == getattr(other, field).pk + ): + setattr( + other, + field, + cls.merge_two_instances(current_field, getattr(other, field)), + ) return other diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 166c943..13e2185 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -76,6 +76,7 @@ def event_loop(): @pytest.fixture(autouse=True, scope="module") async def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) metadata.create_all(engine) department = await Department.objects.create(id=1, name="Math Department") class1 = await SchoolClass.objects.create(name="Math", department=department)