From a39179bc645567fa447983ead4f0fae0ff867a5b Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 16 Aug 2020 22:27:39 +0200 Subject: [PATCH] mostly working --- .coverage | Bin 53248 -> 0 bytes ormar/fields/base.py | 108 ++++---- ormar/fields/foreign_key.py | 120 +++++---- ormar/fields/model_fields.py | 394 ++++++++++++++++++++++++---- ormar/models/fakepydantic.py | 182 +++++++++---- ormar/models/metaclass.py | 143 ++++++---- ormar/models/model.py | 62 ++--- ormar/queryset/clause.py | 18 +- ormar/queryset/query.py | 92 +++---- ormar/queryset/queryset.py | 33 +-- ormar/relations.py | 7 +- tests/test_columns.py | 22 +- tests/test_fastapi_usage.py | 24 +- tests/test_foreign_keys.py | 93 ++++--- tests/test_model_definition.py | 98 +++---- tests/test_models.py | 39 +-- tests/test_more_reallife_fastapi.py | 24 +- tests/test_same_table_joins.py | 65 ++--- 18 files changed, 988 insertions(+), 536 deletions(-) delete mode 100644 .coverage diff --git a/.coverage b/.coverage deleted file mode 100644 index bf048747f459554fda48d6da57377636b1799559..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 53248 zcmeI4dvF`Y8Nly!vSht({g$1O=B5u!63emdBtXEJ(55Yq4&^aSr!hUw(#cUG=@i|` zjwgQ9JciCNo$0jGGHocdol=@Upe+L#Lco8tblNh(Qz(x(iQUp6rX{wggu#RmyWj5J zSyE!z`j3hEZlvA2z5RB-{q47V`gC`D+buW64ONP1nWSP!ZCnM%^W3$P#BrPlz5)2! zn;UkV_5s9v-hQ!N57)d&3XnTEzv~Aa`9|PtB;@9})k_n3Q^CjlpLLjp(u z2_OL^zz7WA;&%rd8u+^h45cff8W|;`X3Q|~`CGSdyM4QK`?l+D+Af)C(q&!=wopjg zCS|lt*GU;Qre@SsMAc0#E{jLkdDq=D^T=TtV?1;}@0lWE#HAYoQ)XvHLp$r_r~*pmBZk-p5ex1c+1ty_Mze z;M%qP{d<_9P&{&1+|UR%@fmf(}wwLhCBDHD)oo3Ihb@0y|qQn~Kn(qSH1+v>r91bgRwQ zS#w(WxD5$RC+u+rx34W`ItwNshZlCe!woKXaP?~bZVxjftGBGkU0_0XPcAZz+NYAL zqGvPqM6%Q*X-nj0NqdY~XP`MDw=1}9b+P7H`Q)&*Q$Fl=x`VZ~{IEG9=)6R~Wd)6( zDY+qV-i4ra_0x*6Q@T)hnhc6Jmbrpg)fQ8nsvw7H=(^`t2UGceH>00~NuPd&0+ne> zPURcbzEF>nfY~Ii#4|8tVGcI0020j@RUos{rDY9fzx0rqC%yLM^fe2}9oy_7*E*IR zhBs8CZ?zAk#+`~zCm+30>e94?s-(V28g(<`AsaMgZNJfQ|a0y-4C@j_cw3Fm8 z4XKx0PJhS@ujSpr4IB93tThJB=|$GzTAoyFwXH)(fo+_3GG`_?!#oQ=a`MfCvdhCk z56v3roM!f5mr23SW&OnQuA>eIE2+f>n8f<=GNx$a69i@bGK$c|FEO+hOM2Q z>5R5ZjTm}or=s_zA|ZIV9unKG=uuNWot=qzS7#?YpQnb>N@O=YZ3Ow!eTJ$Vot?6r zRN^UFh9?;so215&-qXY45IhdhPM^`wsO9Nj8G=t(&0GIlIDEnykNt-Qz5n0fb8+M!az77z<4GACtB!C2v01`j~ zNB{|3qy#Q+5V$Pdxm&n#m%DGr?y_~#P}boUZkfHvEAN6u8+5oycXhqM4J|8}gcSs2 zc#AlZh{seprs#%}j)Qb4C~!&%(or=*E4#d|6fKh)WS67xhJ1>x0Rn<;wE~x1wjOKw z4;fw+r?qXW5jbsmA+s@ZR;SA*r2E!a3*3&Sq;%Dw>#@)pTB-!Dd-1&NR%a)ytdZ#h z0dof=!j%FSFG(VjP+*xC5NWF@+V;gNf!e4LDOE_8DQ0%VwpkRo$WqF&p3ys51eI(E z2;7}Z<NsOu2Y%KK7rd)f=o==t)}~;O3H{wKrHNC z2C`LEUent=%kYZD)kIY9jM4>d6hq4}^~#q~&k|zQ$>ss5x5-`9F3k?L#PZA;q(ZKu zL#;rnOJP=djG${=jKQsxslY1{}I0=J}LH#HGva>1A&_YLH`r}0sotUe)6jS7V@Zn zweO_wx4uE&wIG5U5Eav~$ zHw#KB(w6@pEJ;h#T`KVZ>oy6=) z=9+^}Rjoq2G>HZOzp{AS7puhd|0_xs%5$a2Mj^7Ka<+?3|Db`!h)u$sOXkn_{{wBp z?j0uQe@8a|9#6rwyMf&x_3EVmb+Zw|9h5G&k`~X+B^Vk<;Cr?z?082 zXOMCiA8G|smjCZsj=_cge_2S7m!w_b|2vD16H}i&xpX zPaK@-@0~n&;#mLw(NS&i%v<|s_+TC6XsJ8&;{53R2M?W|IC=8Qkq+2~*4M&`J8G%u zc;D%luId<_d~tr@s`**%^b@_)$B*rs>z(|l_ubyfsnfi_2J)_|8UNAcGY=haaaF^S zSnd5x6~tez8hfF8n0of+@5_U&ZWo~y z5pSgk@n^+duZ_R|di#lUvx8@^sR}^SrGdd04$nV7H}%f6JTp;M?}xNjf6s&cFV21t zn*a0A{9EsyI(72-ukex&&RpmF6gU1r{WoWJ&wMz(<(Z-0`PZJxj-8!)Zt(1!$K{0_ zqPNcDfw;!AetIxaQ4YJGE`R!&qlqp1kHnrldc4XFDVMs(UU=lxQ19$?`}8yIq{2mW zxqkIV`@wzV1M?k69iB2;M48CV4SAe&*BOh&TwVco>jW2fZhmdR0sEIY4jnn%(J?Vo zR>nI$9?`+x|0lO`fCP{L5Yz5kE?|H|)Y zY#|ar0!RP}AOR$R1dsp{Kmter34F2%u=oGTJ@oni_sRdrDKbgkBq!h None: - self.name = None - self._populate_from_kwargs(kwargs) + column_type: sqlalchemy.Column + constraints: List = [] - def _populate_from_kwargs(self, kwargs: Dict) -> None: - self.primary_key = kwargs.pop("primary_key", False) - self.autoincrement = kwargs.pop( - "autoincrement", self.primary_key and self.__type__ == int - ) + primary_key: bool + autoincrement: bool + nullable: bool + index: bool + unique: bool + pydantic_only: bool - self.nullable = kwargs.pop("nullable", not self.primary_key) - self.default = kwargs.pop("default", None) - self.server_default = kwargs.pop("server_default", None) + default: Any + server_default: Any - self.index = kwargs.pop("index", None) - self.unique = kwargs.pop("unique", None) - - self.pydantic_only = kwargs.pop("pydantic_only", False) - if self.pydantic_only and self.primary_key: - raise ModelDefinitionError("Primary key column cannot be pydantic only.") - - @property - def is_required(self) -> bool: + @classmethod + def is_required(cls) -> bool: return ( - not self.nullable and not self.has_default and not self.is_auto_primary_key + not cls.nullable and not cls.has_default() and not cls.is_auto_primary_key() ) - @property - def default_value(self) -> Any: - default = self.default - return default() if callable(default) else default + @classmethod + def default_value(cls): + if cls.is_auto_primary_key(): + return Field(default=None) + if cls.has_default(): + default = cls.default if cls.default is not None else cls.server_default + if callable(default): + return Field(default_factory=default) + else: + return Field(default=default) + return None - @property - def has_default(self) -> bool: - return self.default is not None or self.server_default is not None + @classmethod + def has_default(cls): + return cls.default is not None or cls.server_default is not None - @property - def is_auto_primary_key(self) -> bool: - if self.primary_key: - return self.autoincrement + @classmethod + def is_auto_primary_key(cls) -> bool: + if cls.primary_key: + return cls.autoincrement return False - def get_column(self, name: str = None) -> sqlalchemy.Column: - self.name = name - constraints = self.get_constraints() + @classmethod + def get_column(cls, name: str) -> sqlalchemy.Column: return sqlalchemy.Column( - self.name, - self.get_column_type(), - *constraints, - primary_key=self.primary_key, - autoincrement=self.autoincrement, - nullable=self.nullable, - index=self.index, - unique=self.unique, - default=self.default, - server_default=self.server_default, + name, + cls.column_type, + *cls.constraints, + primary_key=cls.primary_key, + nullable=cls.nullable and not cls.primary_key, + index=cls.index, + unique=cls.unique, + default=cls.default, + server_default=cls.server_default, ) - def get_column_type(self) -> sqlalchemy.types.TypeEngine: - raise NotImplementedError() # pragma: no cover - - def get_constraints(self) -> Optional[List]: - return [] - - def expand_relationship(self, value: Any, child: "Model") -> Any: + @classmethod + def expand_relationship(cls, value: Any, child: "Model") -> Any: return value - - def __repr__(self): # pragma no cover - return str(self.__dict__) diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 77c2da7..87a4f4d 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, TYPE_CHECKING, Type, Union +from typing import Any, List, Optional, TYPE_CHECKING, Type, Union, Callable import sqlalchemy from pydantic import BaseModel @@ -13,87 +13,115 @@ if TYPE_CHECKING: # pragma no cover def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": init_dict = { - **{fk.__pkname__: pk or -1}, + **{fk.Meta.pkname: pk or -1, + '__pk_only__': True}, **{ k: create_dummy_instance(v.to) - for k, v in fk.__model_fields__.items() - if isinstance(v, ForeignKey) and not v.nullable and not v.virtual + for k, v in fk.Meta.model_fields.items() + if isinstance(v, ForeignKeyField) and not v.nullable and not v.virtual }, } return fk(**init_dict) -class ForeignKey(BaseField): - def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, - ) -> None: - super().__init__(nullable=nullable, name=name) - self.virtual = virtual - self.related_name = related_name - self.to = to +def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = True, + related_name: str = None, + virtual: bool = False, + ) -> Type[object]: + fk_string = to.Meta.tablename + "." + to.Meta.pkname + to_field = to.__fields__[to.Meta.pkname] + namespace = dict( + to=to, + name=name, + nullable=nullable, + constraints=[sqlalchemy.schema.ForeignKey(fk_string)], + unique=unique, + column_type=to_field.type_.column_type, + related_name=related_name, + virtual=virtual, + primary_key=False, + index=False, + pydantic_only=False, + default=None, + server_default=None + ) + + return type("ForeignKey", (ForeignKeyField, BaseField), namespace) + + +class ForeignKeyField(BaseField): + to: Type["Model"] + related_name: str + virtual: bool + + @classmethod + def __get_validators__(cls) -> Callable: + yield cls.validate + + @classmethod + def validate(cls, v: Any) -> Any: + return v @property def __type__(self) -> Type[BaseModel]: return self.to.__pydantic_model__ - def get_constraints(self) -> List[sqlalchemy.schema.ForeignKey]: - fk_string = self.to.__tablename__ + "." + self.to.__pkname__ - return [sqlalchemy.schema.ForeignKey(fk_string)] - - def get_column_type(self) -> sqlalchemy.Column: - to_column = self.to.__model_fields__[self.to.__pkname__] - return to_column.get_column_type() + @classmethod + def get_column_type(cls) -> sqlalchemy.Column: + to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname] + return to_column.column_type + @classmethod def _extract_model_from_sequence( - self, value: List, child: "Model" + cls, value: List, child: "Model" ) -> Union["Model", List["Model"]]: - return [self.expand_relationship(val, child) for val in value] + return [cls.expand_relationship(val, child) for val in value] - def _register_existing_model(self, value: "Model", child: "Model") -> "Model": - self.register_relation(value, child) + @classmethod + def _register_existing_model(cls, value: "Model", child: "Model") -> "Model": + cls.register_relation(value, child) return value - def _construct_model_from_dict(self, value: dict, child: "Model") -> "Model": - model = self.to(**value) - self.register_relation(model, child) + @classmethod + def _construct_model_from_dict(cls, value: dict, child: "Model") -> "Model": + model = cls.to(**value) + cls.register_relation(model, child) return model - def _construct_model_from_pk(self, value: Any, child: "Model") -> "Model": - if not isinstance(value, self.to.pk_type()): + @classmethod + def _construct_model_from_pk(cls, value: Any, child: "Model") -> "Model": + if not isinstance(value, cls.to.pk_type()): raise RelationshipInstanceError( - f"Relationship error - ForeignKey {self.to.__name__} " - f"is of type {self.to.pk_type()} " + f"Relationship error - ForeignKey {cls.to.__name__} " + f"is of type {cls.to.pk_type()} " f"while {type(value)} passed as a parameter." ) - model = create_dummy_instance(fk=self.to, pk=value) - self.register_relation(model, child) + model = create_dummy_instance(fk=cls.to, pk=value) + cls.register_relation(model, child) return model - def register_relation(self, model: "Model", child: "Model") -> None: - child_model_name = self.related_name or child.get_name() - model._orm_relationship_manager.add_relation( - model, child, child_model_name, virtual=self.virtual + @classmethod + def register_relation(cls, model: "Model", child: "Model") -> None: + child_model_name = cls.related_name or child.get_name() + model.Meta._orm_relationship_manager.add_relation( + model, child, child_model_name, virtual=cls.virtual ) + @classmethod def expand_relationship( - self, value: Any, child: "Model" + cls, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: if value is None: return None constructors = { - f"{self.to.__name__}": self._register_existing_model, - "dict": self._construct_model_from_dict, - "list": self._extract_model_from_sequence, + f"{cls.to.__name__}": cls._register_existing_model, + "dict": cls._construct_model_from_dict, + "list": cls._extract_model_from_sequence, } model = constructors.get( - value.__class__.__name__, self._construct_model_from_pk + value.__class__.__name__, cls._construct_model_from_pk )(value, child) return model diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index 4f9be11..30946bf 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -1,87 +1,373 @@ import datetime import decimal +import re +from typing import Type, Any, Optional +import pydantic import sqlalchemy from pydantic import Json +from ormar import ModelDefinitionError from ormar.fields.base import BaseField # noqa I101 -from ormar.fields.decorators import RequiredParams -@RequiredParams("length") -class String(BaseField): - __type__ = str - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.String(self.length) +def is_field_nullable(nullable: Optional[bool], default: Any, server_default: Any) -> bool: + if nullable is None: + return default is not None or server_default is not None + return False -class Integer(BaseField): - __type__ = int +def String( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + allow_blank: bool = False, + strip_whitespace: bool = False, + min_length: int = None, + max_length: int = None, + curtail_length: int = None, + regex: str = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[str]: + if max_length is None or max_length <= 0: + raise ModelDefinitionError(f'Parameter max_length is required for field String') - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Integer() + namespace = dict( + __type__=str, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + allow_blank=allow_blank, + strip_whitespace=strip_whitespace, + min_length=min_length, + max_length=max_length, + curtail_length=curtail_length, + regex=regex and re.compile(regex), + column_type=sqlalchemy.String(length=max_length), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + + return type("String", (pydantic.ConstrainedStr, BaseField), namespace) -class Text(BaseField): - __type__ = str - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Text() +def Integer( + *, + name: str = None, + primary_key: bool = False, + autoincrement: bool = None, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[int]: + namespace = dict( + __type__=int, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.Integer(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=autoincrement if autoincrement is not None else primary_key + ) + return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace) -class Float(BaseField): - __type__ = float +def Text( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + allow_blank: bool = False, + strip_whitespace: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[str]: + namespace = dict( + __type__=str, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + allow_blank=allow_blank, + strip_whitespace=strip_whitespace, + column_type=sqlalchemy.Text(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Float() + return type("Text", (pydantic.ConstrainedStr, BaseField), namespace) -class Boolean(BaseField): - __type__ = bool - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Boolean() +def Float( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[int]: + namespace = dict( + __type__=float, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.Float(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace) -class DateTime(BaseField): - __type__ = datetime.datetime - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.DateTime() +def Boolean( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[bool]: + namespace = dict( + __type__=bool, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.Boolean(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Boolean", (int, BaseField), namespace) -class Date(BaseField): - __type__ = datetime.date - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Date() +def DateTime( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[datetime.datetime]: + namespace = dict( + __type__=datetime.datetime, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.DateTime(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("DateTime", (datetime.datetime, BaseField), namespace) -class Time(BaseField): - __type__ = datetime.time - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Time() +def Date( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[datetime.date]: + namespace = dict( + __type__=datetime.date, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.Date(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Date", (datetime.date, BaseField), namespace) -class JSON(BaseField): - __type__ = Json - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.JSON() +def Time( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[datetime.time]: + namespace = dict( + __type__=datetime.time, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.Time(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Time", (datetime.time, BaseField), namespace) -class BigInteger(BaseField): - __type__ = int +def JSON( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[Json]: + namespace = dict( + __type__=pydantic.Json, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.JSON(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.BigInteger() + return type("JSON", (pydantic.Json, BaseField), namespace) -@RequiredParams("length", "precision") -class Decimal(BaseField): - __type__ = decimal.Decimal +def BigInteger( + *, + name: str = None, + primary_key: bool = False, + autoincrement: bool = None, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[int]: + namespace = dict( + __type__=int, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.BigInteger(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=autoincrement if autoincrement is not None else primary_key + ) + return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace) - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.DECIMAL(self.length, self.precision) + +def Decimal( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + precision: int = None, + scale: int = None, + max_digits: int = None, + decimal_places: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +): + if precision is None or precision < 0 or scale is None or scale < 0: + raise ModelDefinitionError(f'Parameters scale and precision are required for field Decimal') + + namespace = dict( + __type__=decimal.Decimal, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.types.DECIMAL(precision=precision, scale=scale), + precision=precision, + scale=scale, + max_digits=max_digits, + decimal_places=decimal_places, + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace) diff --git a/ormar/models/fakepydantic.py b/ormar/models/fakepydantic.py index d5109b4..e31f3f7 100644 --- a/ormar/models/fakepydantic.py +++ b/ormar/models/fakepydantic.py @@ -11,7 +11,7 @@ from typing import ( TYPE_CHECKING, Type, TypeVar, - Union, + Union, AbstractSet, Mapping, ) import databases @@ -20,19 +20,27 @@ import sqlalchemy from pydantic import BaseModel import ormar # noqa I100 +from ormar import ForeignKey from ormar.fields import BaseField -from ormar.models.metaclass import ModelMetaclass +from ormar.fields.foreign_key import ForeignKeyField +from ormar.models.metaclass import ModelMetaclass, ModelMeta from ormar.relations import RelationshipManager if TYPE_CHECKING: # pragma no cover from ormar.models.model import Model + IntStr = Union[int, str] + DictStrAny = Dict[str, Any] + AbstractSetIntStr = AbstractSet[IntStr] + MappingIntStrAny = Mapping[IntStr, Any] -class FakePydantic(list, metaclass=ModelMetaclass): + +class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): # FakePydantic inherits from list in order to be treated as # request.Body parameter in fastapi routes, # inheriting from pydantic.BaseModel causes metaclass conflicts - __abstract__ = True + __slots__ = ('_orm_id', '_orm_saved') + if TYPE_CHECKING: # pragma no cover __model_fields__: Dict[str, TypeVar[BaseField]] __table__: sqlalchemy.Table @@ -43,62 +51,88 @@ class FakePydantic(list, metaclass=ModelMetaclass): __metadata__: sqlalchemy.MetaData __database__: databases.Database _orm_relationship_manager: RelationshipManager + Meta: ModelMeta + # noinspection PyMissingConstructor def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__() - self._orm_id: str = uuid.uuid4().hex - self._orm_saved: bool = False - self.values: Optional[BaseModel] = None + object.__setattr__(self, "_orm_id", uuid.uuid4().hex) + object.__setattr__(self, "_orm_saved", False) + + pk_only = kwargs.pop("__pk_only__", False) if "pk" in kwargs: - kwargs[self.__pkname__] = kwargs.pop("pk") + kwargs[self.Meta.pkname] = kwargs.pop("pk") kwargs = { - k: self.__model_fields__[k].expand_relationship(v, self) + k: self.Meta.model_fields[k].expand_relationship(v, self) for k, v in kwargs.items() } - self.values = self.__pydantic_model__(**kwargs) + + values, fields_set, validation_error = pydantic.validate_model( + self, kwargs + ) + if validation_error and not pk_only: + raise validation_error + + object.__setattr__(self, '__dict__', values) + object.__setattr__(self, '__fields_set__', fields_set) + + # super().__init__(**kwargs) + # self.values = self.__pydantic_model__(**kwargs) def __del__(self) -> None: - self._orm_relationship_manager.deregister(self) + self.Meta._orm_relationship_manager.deregister(self) - def __setattr__(self, key: str, value: Any) -> None: - if key in self.__fields__: - value = self._convert_json(key, value, op="dumps") - value = self.__model_fields__[key].expand_relationship(value, self) + def __setattr__(self, name, value): + if name in self.__slots__: + object.__setattr__(self, name, value) + elif name == 'pk': + object.__setattr__(self, self.Meta.pkname, value) + relation_key = self.get_name(title=True) + "_" + name + if self.Meta._orm_relationship_manager.contains(relation_key, self): + self.Meta.model_fields[name].expand_relationship(value, self) + return + super().__setattr__(name, value) - relation_key = self.get_name(title=True) + "_" + key - if not self._orm_relationship_manager.contains(relation_key, self): - setattr(self.values, key, value) - else: - super().__setattr__(key, value) + def __getattr__(self, item): + relation_key = self.get_name(title=True) + "_" + item + if self.Meta._orm_relationship_manager.contains(relation_key, self): + return self.Meta._orm_relationship_manager.get(relation_key, self) - def __getattribute__(self, key: str) -> Any: - if key != "__fields__" and key in self.__fields__: - relation_key = self.get_name(title=True) + "_" + key - if self._orm_relationship_manager.contains(relation_key, self): - return self._orm_relationship_manager.get(relation_key, self) + # def __setattr__(self, key: str, value: Any) -> None: + # if key in ('_orm_id', '_orm_relationship_manager', '_orm_saved', 'objects', '__model_fields__'): + # return setattr(self, key, value) + # # elif key in self._extract_related_names(): + # # value = self._convert_json(key, value, op="dumps") + # # value = self.Meta.model_fields[key].expand_relationship(value, self) + # # relation_key = self.get_name(title=True) + "_" + key + # # if not self.Meta._orm_relationship_manager.contains(relation_key, self): + # # setattr(self.values, key, value) + # else: + # super().__setattr__(key, value) - item = getattr(self.values, key, None) - item = self._convert_json(key, item, op="loads") - return item - return super().__getattribute__(key) - - def __eq__(self, other: "Model") -> bool: - return self.values.dict() == other.values.dict() + # def __getattribute__(self, key: str) -> Any: + # if key != 'Meta' and key in self.Meta.model_fields: + # relation_key = self.get_name(title=True) + "_" + key + # if self.Meta._orm_relationship_manager.contains(relation_key, self): + # return self.Meta._orm_relationship_manager.get(relation_key, self) + # item = getattr(self.__fields__, key, None) + # item = self._convert_json(key, item, op="loads") + # return item + # return super().__getattribute__(key) def __same__(self, other: "Model") -> bool: if self.__class__ != other.__class__: # pragma no cover return False return self._orm_id == other._orm_id or ( - self.values is not None and other.values is not None and self.pk == other.pk + self.__dict__ is not None and other.__dict__ is not None and self.pk == other.pk ) - def __repr__(self) -> str: # pragma no cover - return self.values.__repr__() + # def __repr__(self) -> str: # pragma no cover + # return self.values.__repr__() - @classmethod - def __get_validators__(cls) -> Callable: # pragma no cover - yield cls.__pydantic_model__.validate + # @classmethod + # def __get_validators__(cls) -> Callable: # pragma no cover + # yield cls.__pydantic_model__.validate @classmethod def get_name(cls, title: bool = False, lower: bool = True) -> str: @@ -109,25 +143,57 @@ class FakePydantic(list, metaclass=ModelMetaclass): name = name.title() return name + @property + def pk(self) -> Any: + return getattr(self, self.Meta.pkname) + + @pk.setter + def pk(self, value: Any) -> None: + setattr(self, self.Meta.pkname, value) + @property def pk_column(self) -> sqlalchemy.Column: - return self.__table__.primary_key.columns.values()[0] + return self.Meta.table.primary_key.columns.values()[0] @classmethod def pk_type(cls) -> Any: - return cls.__model_fields__[cls.__pkname__].__type__ + return cls.Meta.model_fields[cls.Meta.pkname].__type__ - def dict(self, nested=False) -> Dict: # noqa: A003 - dict_instance = self.values.dict() + def dict( + self, + *, + include: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, + exclude: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + nested: bool = False + ) -> 'DictStrAny': # noqa: A003 + print('callin super', self.__class__) + print('to exclude', self._exclude_related_names_not_required(nested)) + dict_instance = super().dict(include=include, + exclude=self._exclude_related_names_not_required(nested), + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none) + print('after super') for field in self._extract_related_names(): + print(self.__class__, field, nested) nested_model = getattr(self, field) - if self.__model_fields__[field].virtual and nested: + + if self.Meta.model_fields[field].virtual and nested: continue if isinstance(nested_model, list) and not isinstance( - nested_model, ormar.Model + nested_model, ormar.Model ): + print('nested list') dict_instance[field] = [x.dict(nested=True) for x in nested_model] else: + print('instance') dict_instance[field] = ( nested_model.dict(nested=True) if nested_model is not None else {} ) @@ -155,7 +221,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): return value def _is_conversion_to_json_needed(self, column_name: str) -> bool: - return self.__model_fields__.get(column_name).__type__ == pydantic.Json + return self.Meta.model_fields.get(column_name).__type__ == pydantic.Json def _extract_own_model_fields(self) -> Dict: related_names = self._extract_related_names() @@ -165,22 +231,32 @@ class FakePydantic(list, metaclass=ModelMetaclass): @classmethod def _extract_related_names(cls) -> Set: related_names = set() - for name, field in cls.__fields__.items(): - if inspect.isclass(field.type_) and issubclass( - field.type_, pydantic.BaseModel + for name, field in cls.Meta.model_fields.items(): + if inspect.isclass(field) and issubclass( + field, ForeignKeyField ): related_names.add(name) return related_names + @classmethod + def _exclude_related_names_not_required(cls, nested:bool=False) -> Set: + if nested: + return cls._extract_related_names() + related_names = set() + for name, field in cls.Meta.model_fields.items(): + if inspect.isclass(field) and issubclass(field, ForeignKeyField) and field.nullable: + related_names.add(name) + return related_names + def _extract_model_db_fields(self) -> Dict: self_fields = self._extract_own_model_fields() self_fields = { - k: v for k, v in self_fields.items() if k in self.__table__.columns + k: v for k, v in self_fields.items() if k in self.Meta.table.columns } for field in self._extract_related_names(): if getattr(self, field) is not None: self_fields[field] = getattr( - getattr(self, field), self.__model_fields__[field].to.__pkname__ + getattr(self, field), self.Meta.model_fields[field].to.Meta.pkname ) return self_fields @@ -196,9 +272,9 @@ class FakePydantic(list, metaclass=ModelMetaclass): @classmethod def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": - for field in one.__model_fields__.keys(): + for field in one.Meta.model_fields.keys(): if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), ormar.Model + getattr(one, field), ormar.Model ): setattr(other, field, getattr(one, field) + getattr(other, field)) elif isinstance(getattr(one, field), ormar.Model): diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 7ea2fe7..76088cc 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -1,12 +1,15 @@ -import copy -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union +import databases +import pydantic import sqlalchemy -from pydantic import BaseConfig, create_model -from pydantic.fields import ModelField +from pydantic import BaseConfig, create_model, Extra +from pydantic.fields import ModelField, FieldInfo from ormar import ForeignKey, ModelDefinitionError # noqa I100 from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField +from ormar.queryset import QuerySet from ormar.relations import RelationshipManager if TYPE_CHECKING: # pragma no cover @@ -15,6 +18,17 @@ if TYPE_CHECKING: # pragma no cover relationship_manager = RelationshipManager() +class ModelMeta: + tablename: str + table: sqlalchemy.Table + metadata: sqlalchemy.MetaData + database: databases.Database + columns: List[sqlalchemy.Column] + pkname: str + model_fields: Dict[str, Union[BaseField, ForeignKey]] + _orm_relationship_manager: RelationshipManager + + def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: pydantic_fields = { field_name: ( @@ -29,9 +43,9 @@ def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: child_relation_name = ( - field.to.get_name(title=True) - + "_" - + (field.related_name or (name.lower() + "s")) + field.to.get_name(title=True) + + "_" + + (field.related_name or (name.lower() + "s")) ) reverse_name = child_relation_name relation_name = name.lower().title() + "_" + field.to.get_name() @@ -41,104 +55,125 @@ def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> def expand_reverse_relationships(model: Type["Model"]) -> None: - for model_field in model.__model_fields__.values(): - if isinstance(model_field, ForeignKey): + for model_field in model.Meta.model_fields.values(): + if issubclass(model_field, ForeignKeyField): child_model_name = model_field.related_name or model.get_name() + "s" parent_model = model_field.to child = model if ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ): register_reverse_model_fields(parent_model, child, child_model_name) def register_reverse_model_fields( - model: Type["Model"], child: Type["Model"], child_model_name: str + model: Type["Model"], child: Type["Model"], child_model_name: str ) -> None: - model.__fields__[child_model_name] = ModelField( - name=child_model_name, - type_=Optional[child.__pydantic_model__], - model_config=child.__pydantic_model__.__config__, - class_validators=child.__pydantic_model__.__validators__, - ) - model.__model_fields__[child_model_name] = ForeignKey( + # model.__fields__[child_model_name] = ModelField( + # name=child_model_name, + # type_=Optional[Union[List[child], child]], + # model_config=child.__config__, + # class_validators=child.__validators__, + # ) + model.Meta.model_fields[child_model_name] = ForeignKey( child, name=child_model_name, virtual=True ) def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str + name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: columns = [] pkname = None model_fields = { field_name: field - for field_name, field in object_dict.items() - if isinstance(field, BaseField) + for field_name, field in object_dict['__annotations__'].items() + if issubclass(field, BaseField) } for field_name, field in model_fields.items(): if field.primary_key: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") + if field.pydantic_only: + raise ModelDefinitionError('Primary key column cannot be pydantic only') pkname = field_name if not field.pydantic_only: columns.append(field.get_column(field_name)) - if isinstance(field, ForeignKey): + if issubclass(field, ForeignKeyField): register_relation_on_build(table_name, field, name) return pkname, columns, model_fields +def populate_pydantic_default_values(attrs: Dict) -> Dict: + for field, type_ in attrs['__annotations__'].items(): + if issubclass(type_, BaseField): + if type_.name is None: + type_.name = field + def_value = type_.default_value() + curr_def_value = attrs.get(field, 'NONE') + if curr_def_value == 'NONE' and isinstance(def_value, FieldInfo): + attrs[field] = def_value + elif curr_def_value == 'NONE' and type_.nullable: + attrs[field] = FieldInfo(default=None) + return attrs + + def get_pydantic_base_orm_config() -> Type[BaseConfig]: class Config(BaseConfig): orm_mode = True + arbitrary_types_allowed = True + # extra = Extra.allow return Config -class ModelMetaclass(type): +class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: + + attrs['Config'] = get_pydantic_base_orm_config() new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) - if attrs.get("__abstract__"): - return new_model + if hasattr(new_model, 'Meta'): - tablename = attrs.get("__tablename__", name.lower() + "s") - attrs["__tablename__"] = tablename - metadata = attrs["__metadata__"] + if attrs.get("__abstract__"): + return new_model - # sqlalchemy table creation - pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( - name, attrs, tablename - ) - attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns) - attrs["__columns__"] = columns - attrs["__pkname__"] = pkname + attrs = populate_pydantic_default_values(attrs) - if not pkname: - raise ModelDefinitionError("Table has to have a primary key.") + tablename = name.lower() + "s" + new_model.Meta.tablename = new_model.Meta.tablename or tablename - # pydantic model creation - pydantic_fields = parse_pydantic_field_from_model_fields(attrs) - pydantic_model = create_model( - name, __config__=get_pydantic_base_orm_config(), **pydantic_fields - ) - attrs["__pydantic_fields__"] = pydantic_fields - attrs["__pydantic_model__"] = pydantic_model - attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__) - attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__) - attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__) + # sqlalchemy table creation + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( + name, attrs, new_model.Meta.tablename + ) + new_model.Meta.table = sqlalchemy.Table(new_model.Meta.tablename, new_model.Meta.metadata, *columns) + new_model.Meta.columns = columns + new_model.Meta.pkname = pkname - attrs["__model_fields__"] = model_fields - attrs["_orm_relationship_manager"] = relationship_manager + if not pkname: + breakpoint() + raise ModelDefinitionError("Table has to have a primary key.") - new_model = super().__new__( # type: ignore - mcs, name, bases, attrs - ) + # pydantic model creation + new_model.Meta.pydantic_fields = parse_pydantic_field_from_model_fields(attrs) + new_model.Meta.pydantic_model = create_model( + name, __config__=get_pydantic_base_orm_config(), **new_model.Meta.pydantic_fields + ) - expand_reverse_relationships(new_model) + new_model.Meta.model_fields = model_fields + new_model = super().__new__( # type: ignore + mcs, name, bases, attrs + ) + expand_reverse_relationships(new_model) + + new_model.Meta._orm_relationship_manager = relationship_manager + new_model.objects = QuerySet(new_model) + + # breakpoint() return new_model diff --git a/ormar/models/model.py b/ormar/models/model.py index b16d3e7..4651aa1 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -7,39 +7,39 @@ from ormar.models import FakePydantic # noqa I100 class Model(FakePydantic): - __abstract__ = True + __abstract__ = False - objects = ormar.queryset.QuerySet() + # objects = ormar.queryset.QuerySet() @classmethod def from_row( - cls, - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - previous_table: str = None, + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + previous_table: str = None, ) -> "Model": item = {} select_related = select_related or [] - table_prefix = cls._orm_relationship_manager.resolve_relation_join( - previous_table, cls.__table__.name + table_prefix = cls.Meta._orm_relationship_manager.resolve_relation_join( + previous_table, cls.Meta.table.name ) - previous_table = cls.__table__.name + previous_table = cls.Meta.table.name for related in select_related: if "__" in related: first_part, remainder = related.split("__", 1) - model_cls = cls.__model_fields__[first_part].to + model_cls = cls.Meta.model_fields[first_part].to child = model_cls.from_row( row, select_related=[remainder], previous_table=previous_table ) item[first_part] = child else: - model_cls = cls.__model_fields__[related].to + model_cls = cls.Meta.model_fields[related].to child = model_cls.from_row(row, previous_table=previous_table) item[related] = child - for column in cls.__table__.columns: + for column in cls.Meta.table.columns: if column.name not in item: item[column.name] = row[ f'{table_prefix + "_" if table_prefix else ""}{column.name}' @@ -47,22 +47,14 @@ class Model(FakePydantic): return cls(**item) - @property - def pk(self) -> str: - return getattr(self.values, self.__pkname__) - - @pk.setter - def pk(self, value: Any) -> None: - setattr(self.values, self.__pkname__, value) - async def save(self) -> "Model": self_fields = self._extract_model_db_fields() - if self.__model_fields__.get(self.__pkname__).autoincrement: - self_fields.pop(self.__pkname__, None) - expr = self.__table__.insert() + if self.Meta.model_fields.get(self.Meta.pkname).autoincrement: + self_fields.pop(self.Meta.pkname, None) + expr = self.Meta.table.insert() expr = expr.values(**self_fields) - item_id = await self.__database__.execute(expr) - self.pk = item_id + item_id = await self.Meta.database.execute(expr) + setattr(self, self.Meta.pkname, item_id) return self async def update(self, **kwargs: Any) -> int: @@ -71,23 +63,23 @@ class Model(FakePydantic): self.from_dict(new_values) self_fields = self._extract_model_db_fields() - self_fields.pop(self.__pkname__) + self_fields.pop(self.Meta.pkname) expr = ( - self.__table__.update() - .values(**self_fields) - .where(self.pk_column == getattr(self, self.__pkname__)) + self.Meta.table.update() + .values(**self_fields) + .where(self.pk_column == getattr(self, self.Meta.pkname)) ) - result = await self.__database__.execute(expr) + result = await self.Meta.database.execute(expr) return result async def delete(self) -> int: - expr = self.__table__.delete() - expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) - result = await self.__database__.execute(expr) + expr = self.Meta.table.delete() + expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname))) + result = await self.Meta.database.execute(expr) return result async def load(self) -> "Model": - expr = self.__table__.select().where(self.pk_column == self.pk) - row = await self.__database__.fetch_one(expr) + expr = self.Meta.table.select().where(self.pk_column == self.pk) + row = await self.Meta.database.fetch_one(expr) self.from_dict(dict(row)) return self diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index 3d5d14b..dc94e6f 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -32,7 +32,7 @@ class QueryClause: self.filter_clauses = filter_clauses self.model_cls = model_cls - self.table = self.model_cls.__table__ + self.table = self.model_cls.Meta.table def filter( # noqa: A003 self, **kwargs: Any @@ -41,7 +41,7 @@ class QueryClause: select_related = list(self._select_related) if kwargs.get("pk"): - pk_name = self.model_cls.__pkname__ + pk_name = self.model_cls.Meta.pkname kwargs[pk_name] = kwargs.pop("pk") for key, value in kwargs.items(): @@ -65,8 +65,8 @@ class QueryClause: related_parts, select_related ) - table = model_cls.__table__ - column = model_cls.__table__.columns[field_name] + table = model_cls.Meta.table + column = model_cls.Meta.table.columns[field_name] else: op = "exact" @@ -106,12 +106,12 @@ class QueryClause: # Walk the relationships to the actual model class # against which the comparison is being made. - previous_table = model_cls.__tablename__ + previous_table = model_cls.Meta.tablename for part in related_parts: - current_table = model_cls.__model_fields__[part].to.__tablename__ - manager = model_cls._orm_relationship_manager + current_table = model_cls.Meta.model_fields[part].to.Meta.tablename + manager = model_cls.Meta._orm_relationship_manager table_prefix = manager.resolve_relation_join(previous_table, current_table) - model_cls = model_cls.__model_fields__[part].to + model_cls = model_cls.Meta.model_fields[part].to previous_table = current_table return select_related, table_prefix, model_cls @@ -128,7 +128,7 @@ class QueryClause: clause_text = str( clause.compile( - dialect=self.model_cls.__database__._backend._dialect, + dialect=self.model_cls.Meta.database._backend._dialect, compile_kwargs={"literal_binds": True}, ) ) diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 202b249..798502a 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -4,8 +4,8 @@ import sqlalchemy from sqlalchemy import text import ormar # noqa I100 -from ormar import ForeignKey from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -20,12 +20,12 @@ class JoinParameters(NamedTuple): class Query: def __init__( - self, - model_cls: Type["Model"], - filter_clauses: List, - select_related: List, - limit_count: int, - offset: int, + self, + model_cls: Type["Model"], + filter_clauses: List, + select_related: List, + limit_count: int, + offset: int, ) -> None: self.query_offset = offset @@ -34,7 +34,7 @@ class Query: self.filter_clauses = filter_clauses self.model_cls = model_cls - self.table = self.model_cls.__table__ + self.table = self.model_cls.Meta.table self.auto_related = [] self.used_aliases = [] @@ -46,16 +46,16 @@ class Query: def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: self.columns = list(self.table.columns) - self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] + self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")] self.select_from = self.table - for key in self.model_cls.__model_fields__: + for key in self.model_cls.Meta.model_fields: if ( - not self.model_cls.__model_fields__[key].nullable - and isinstance( - self.model_cls.__model_fields__[key], ormar.fields.ForeignKey, - ) - and key not in self._select_related + not self.model_cls.Meta.model_fields[key].nullable + and isinstance( + self.model_cls.Meta.model_fields[key], ForeignKeyField, + ) + and key not in self._select_related ): self._select_related = [key] + self._select_related @@ -79,7 +79,7 @@ class Query: expr = self._apply_expression_modifiers(expr) - # print(expr.compile(compile_kwargs={"literal_binds": True})) + print(expr.compile(compile_kwargs={"literal_binds": True})) self._reset_query_parameters() return expr, self._select_related @@ -97,12 +97,12 @@ class Query: @staticmethod def _field_is_a_foreign_key_and_no_circular_reference( - field: BaseField, field_name: str, rel_part: str + field: BaseField, field_name: str, rel_part: str ) -> bool: - return isinstance(field, ForeignKey) and field_name not in rel_part + return issubclass(field, ForeignKeyField) and field_name not in rel_part def _field_qualifies_to_deeper_search( - self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str + self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str ) -> bool: prev_part_of_related = "__".join(rel_part.split("__")[:-1]) partial_match = any( @@ -112,39 +112,39 @@ class Query: [x.startswith(rel_part) for x in (self.auto_related + self.already_checked)] ) return ( - (field.virtual and parent_virtual) - or (partial_match and not already_checked) - ) or not nested + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested def on_clause( - self, previous_alias: str, alias: str, from_clause: str, to_clause: str, + self, previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: left_part = f"{alias}_{to_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" return text(f"{left_part}={right_part}") def _build_join_parameters( - self, part: str, join_params: JoinParameters + self, part: str, join_params: JoinParameters ) -> JoinParameters: - model_cls = join_params.model_cls.__model_fields__[part].to - to_table = model_cls.__table__.name + model_cls = join_params.model_cls.Meta.model_fields[part].to + to_table = model_cls.Meta.table.name - alias = model_cls._orm_relationship_manager.resolve_relation_join( + alias = model_cls.Meta._orm_relationship_manager.resolve_relation_join( join_params.from_table, to_table ) if alias not in self.used_aliases: - if join_params.prev_model.__model_fields__[part].virtual: + if join_params.prev_model.Meta.model_fields[part].virtual: to_key = next( ( v - for k, v in model_cls.__model_fields__.items() - if isinstance(v, ForeignKey) and v.to == join_params.prev_model + for k, v in model_cls.Meta.model_fields.items() + if issubclass(v, ForeignKeyField) and v.to == join_params.prev_model ), None, ).name - from_key = model_cls.__pkname__ + from_key = model_cls.Meta.pkname else: - to_key = model_cls.__pkname__ + to_key = model_cls.Meta.pkname from_key = part on_clause = self.on_clause( @@ -157,8 +157,8 @@ class Query: self.select_from = sqlalchemy.sql.outerjoin( self.select_from, target_table, on_clause ) - self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}")) - self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) + self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}")) + self.columns.extend(self.prefixed_columns(alias, model_cls.Meta.table)) self.used_aliases.append(alias) previous_alias = alias @@ -167,24 +167,28 @@ class Query: return JoinParameters(prev_model, previous_alias, from_table, model_cls) def _extract_auto_required_relations( - self, - prev_model: Type["Model"], - rel_part: str = "", - nested: bool = False, - parent_virtual: bool = False, + self, + prev_model: Type["Model"], + rel_part: str = "", + nested: bool = False, + parent_virtual: bool = False, ) -> None: - for field_name, field in prev_model.__model_fields__.items(): + for field_name, field in prev_model.Meta.model_fields.items(): if self._field_is_a_foreign_key_and_no_circular_reference( - field, field_name, rel_part + field, field_name, rel_part ): rel_part = field_name if not rel_part else rel_part + "__" + field_name if not field.nullable: + print('add', rel_part, field) if rel_part not in self._select_related: - self.auto_related.append("__".join(rel_part.split("__")[:-1])) + new_related = "__".join(rel_part.split("__")[:-1]) if len( + rel_part.split("__")) > 1 else rel_part + self.auto_related.append(new_related) rel_part = "" elif self._field_qualifies_to_deeper_search( - field, parent_virtual, nested, rel_part + field, parent_virtual, nested, rel_part ): + print('deeper', rel_part, field, field.to) self._extract_auto_required_relations( prev_model=field.to, rel_part=rel_part, @@ -204,7 +208,7 @@ class Query: self._select_related = new_joins + self.auto_related def _apply_expression_modifiers( - self, expr: sqlalchemy.sql.select + self, expr: sqlalchemy.sql.select ) -> sqlalchemy.sql.select: if self.filter_clauses: if len(self.filter_clauses) == 1: diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 65192b8..4e5d85e 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -14,12 +14,12 @@ if TYPE_CHECKING: # pragma no cover class QuerySet: def __init__( - self, - model_cls: Type["Model"] = None, - filter_clauses: List = None, - select_related: List = None, - limit_count: int = None, - offset: int = None, + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses @@ -33,11 +33,11 @@ class QuerySet: @property def database(self) -> databases.Database: - return self.model_cls.__database__ + return self.model_cls.Meta.database @property def table(self) -> sqlalchemy.Table: - return self.model_cls.__table__ + return self.model_cls.Meta.table def build_select_expression(self) -> sqlalchemy.sql.select: qry = Query( @@ -148,12 +148,12 @@ class QuerySet: new_kwargs = dict(**kwargs) # Remove primary key when None to prevent not null constraint in postgresql. - pkname = self.model_cls.__pkname__ - pk = self.model_cls.__model_fields__[pkname] + pkname = self.model_cls.Meta.pkname + pk = self.model_cls.Meta.model_fields[pkname] if ( - pkname in new_kwargs - and new_kwargs.get(pkname) is None - and (pk.nullable or pk.autoincrement) + pkname in new_kwargs + and new_kwargs.get(pkname) is None + and (pk.nullable or pk.autoincrement) ): del new_kwargs[pkname] @@ -163,11 +163,11 @@ class QuerySet: if isinstance(new_kwargs.get(field), ormar.Model): new_kwargs[field] = getattr( new_kwargs.get(field), - self.model_cls.__model_fields__[field].to.__pkname__, + self.model_cls.Meta.model_fields[field].to.Meta.pkname, ) else: new_kwargs[field] = new_kwargs.get(field).get( - self.model_cls.__model_fields__[field].to.__pkname__ + self.model_cls.Meta.model_fields[field].to.Meta.pkname ) # Build the insert expression. @@ -176,5 +176,6 @@ class QuerySet: # Execute the insert, and return a new model instance. instance = self.model_cls(**kwargs) - instance.pk = await self.database.execute(expr) + pk = await self.database.execute(expr) + setattr(instance, self.model_cls.Meta.pkname, pk) return instance diff --git a/ormar/relations.py b/ormar/relations.py index 45c3ddd..8a422d9 100644 --- a/ormar/relations.py +++ b/ormar/relations.py @@ -6,6 +6,7 @@ from typing import List, TYPE_CHECKING, Union from weakref import proxy from ormar import ForeignKey +from ormar.fields.foreign_key import ForeignKeyField if TYPE_CHECKING: # pragma no cover from ormar.models import FakePydantic, Model @@ -21,14 +22,14 @@ class RelationshipManager: self._aliases = dict() def add_relation_type( - self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str + self, relations_key: str, reverse_key: str, field: ForeignKeyField, table_name: str ) -> None: if relations_key not in self._relations: self._relations[relations_key] = {"type": "primary"} - self._aliases[f"{table_name}_{field.to.__tablename__}"] = get_table_alias() + self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias() if reverse_key not in self._relations: self._relations[reverse_key] = {"type": "reverse"} - self._aliases[f"{field.to.__tablename__}_{table_name}"] = get_table_alias() + self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias() def deregister(self, model: "FakePydantic") -> None: for rel_type in self._relations.keys(): diff --git a/tests/test_columns.py b/tests/test_columns.py index e75cb0a..c8c9d3b 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -16,17 +16,19 @@ def time(): class Example(ormar.Model): - __tablename__ = "example" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "example" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - created = ormar.DateTime(default=datetime.datetime.now) - created_day = ormar.Date(default=datetime.date.today) - created_time = ormar.Time(default=time) - description = ormar.Text(nullable=True) - value = ormar.Float(nullable=True) - data = ormar.JSON(default={}) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=200, default='aaa') + created: ormar.DateTime(default=datetime.datetime.now) + created_day: ormar.Date(default=datetime.date.today) + created_time: ormar.Time(default=time) + description: ormar.Text(nullable=True) + value: ormar.Float(nullable=True) + data: ormar.JSON(default={}) @pytest.fixture(autouse=True, scope="module") diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index 25b1310..f7f2625 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -13,22 +13,24 @@ metadata = sqlalchemy.MetaData() class Category(ormar.Model): - __tablename__ = "categories" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "categories" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Item(ormar.Model): - __tablename__ = "items" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "items" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + category: ormar.ForeignKey(Category, nullable=True) @app.post("/items/", response_model=Item) diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index bffa783..17e4a52 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -1,6 +1,7 @@ import databases import pytest import sqlalchemy +from pydantic import ValidationError import ormar from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError @@ -11,62 +12,68 @@ metadata = sqlalchemy.MetaData() class Album(ormar.Model): - __tablename__ = "album" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "albums" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Track(ormar.Model): - __tablename__ = "track" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "tracks" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - album = ormar.ForeignKey(Album) - title = ormar.String(length=100) - position = ormar.Integer() + id: ormar.Integer(primary_key=True) + album: ormar.ForeignKey(Album) + title: ormar.String(max_length=100) + position: ormar.Integer() class Cover(ormar.Model): - __tablename__ = "covers" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "covers" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - album = ormar.ForeignKey(Album, related_name="cover_pictures") - title = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + album: ormar.ForeignKey(Album, related_name="cover_pictures") + title: ormar.String(max_length=100) class Organisation(ormar.Model): - __tablename__ = "org" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "org" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - ident = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + ident: ormar.String(max_length=100) class Team(ormar.Model): - __tablename__ = "team" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "teams" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - org = ormar.ForeignKey(Organisation) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + org: ormar.ForeignKey(Organisation) + name: ormar.String(max_length=100) class Member(ormar.Model): - __tablename__ = "member" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "members" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - team = ormar.ForeignKey(Team) - email = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + team: ormar.ForeignKey(Team) + email: ormar.String(max_length=100) @pytest.fixture(autouse=True, scope="module") @@ -171,8 +178,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -180,8 +187,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -223,8 +230,8 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: @@ -243,8 +250,8 @@ async def test_pk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1 diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index f6dc722..adf20ad 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -1,4 +1,5 @@ import datetime +import decimal import pydantic import pytest @@ -12,19 +13,21 @@ metadata = sqlalchemy.MetaData() class ExampleModel(Model): - __tablename__ = "example" - __metadata__ = metadata - test = fields.Integer(primary_key=True) - test_string = fields.String(length=250) - test_text = fields.Text(default="") - test_bool = fields.Boolean(nullable=False) - test_float = fields.Float() - test_datetime = fields.DateTime(default=datetime.datetime.now) - test_date = fields.Date(default=datetime.date.today) - test_time = fields.Time(default=datetime.time) - test_json = fields.JSON(default={}) - test_bigint = fields.BigInteger(default=0) - test_decimal = fields.Decimal(length=10, precision=2) + class Meta: + tablename = "example" + metadata = metadata + + test: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250) + test_text: fields.Text(default="") + test_bool: fields.Boolean(nullable=False) + test_float: fields.Float() = None + test_datetime: fields.DateTime(default=datetime.datetime.now) + test_date: fields.Date(default=datetime.date.today) + test_time: fields.Time(default=datetime.time) + test_json: fields.JSON(default={}) + test_bigint: fields.BigInteger(default=0) + test_decimal: fields.Decimal(scale=10, precision=2) fields_to_check = [ @@ -41,15 +44,17 @@ fields_to_check = [ class ExampleModel2(Model): - __tablename__ = "example2" - __metadata__ = metadata - test = fields.Integer(primary_key=True) - test_string = fields.String(length=250) + class Meta: + tablename = "example2" + metadata = metadata + + test: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250) @pytest.fixture() def example(): - return ExampleModel(pk=1, test_string="test", test_bool=True) + return ExampleModel(pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5)) def test_not_nullable_field_is_required(): @@ -83,60 +88,65 @@ def test_primary_key_access_and_setting(example): def test_pydantic_model_is_created(example): - assert issubclass(example.values.__class__, pydantic.BaseModel) - assert all([field in example.values.__fields__ for field in fields_to_check]) - assert example.values.test == 1 + assert issubclass(example.__class__, pydantic.BaseModel) + assert all([field in example.__fields__ for field in fields_to_check]) + assert example.test == 1 def test_sqlalchemy_table_is_created(example): - assert issubclass(example.__table__.__class__, sqlalchemy.Table) - assert all([field in example.__table__.columns for field in fields_to_check]) + assert issubclass(example.Meta.table.__class__, sqlalchemy.Table) + assert all([field in example.Meta.table.columns for field in fields_to_check]) def test_no_pk_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example3" - __metadata__ = metadata - test_string = fields.String(length=250) + class Meta: + tablename = "example3" + metadata = metadata + + test_string: fields.String(max_length=250) def test_two_pks_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example3" - __metadata__ = metadata - id = fields.Integer(primary_key=True) - test_string = fields.String(length=250, primary_key=True) + class Meta: + tablename = "example3" + metadata = metadata + + id: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250, primary_key=True) def test_setting_pk_column_as_pydantic_only_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example4" - __metadata__ = metadata - test = fields.Integer(primary_key=True, pydantic_only=True) + class Meta: + tablename = "example4" + metadata = metadata + + test: fields.Integer(primary_key=True, pydantic_only=True) def test_decimal_error_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example4" - __metadata__ = metadata - test = fields.Decimal(primary_key=True) + class Meta: + tablename = "example5" + metadata = metadata + + test: fields.Decimal(primary_key=True) def test_string_error_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example4" - __metadata__ = metadata - test = fields.String(primary_key=True) + class Meta: + tablename = "example6" + metadata = metadata + + test: fields.String(primary_key=True) def test_json_conversion_in_model(): diff --git a/tests/test_models.py b/tests/test_models.py index 69d421c..42c6b54 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,4 +1,5 @@ import databases +import pydantic import pytest import sqlalchemy @@ -11,23 +12,25 @@ metadata = sqlalchemy.MetaData() class User(ormar.Model): - __tablename__ = "users" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "users" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100, default='') class Product(ormar.Model): - __tablename__ = "product" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "product" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - rating = ormar.Integer(minimum=1, maximum=5) - in_stock = ormar.Boolean(default=False) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + rating: ormar.Integer(minimum=1, maximum=5) + in_stock: ormar.Boolean(default=False) @pytest.fixture(autouse=True, scope="module") @@ -39,12 +42,12 @@ def create_test_database(): def test_model_class(): - assert list(User.__model_fields__.keys()) == ["id", "name"] - assert isinstance(User.__model_fields__["id"], ormar.Integer) - assert User.__model_fields__["id"].primary_key is True - assert isinstance(User.__model_fields__["name"], ormar.String) - assert User.__model_fields__["name"].length == 100 - assert isinstance(User.__table__, sqlalchemy.Table) + assert list(User.Meta.model_fields.keys()) == ["id", "name"] + assert issubclass(User.Meta.model_fields["id"], pydantic.ConstrainedInt) + assert User.Meta.model_fields["id"].primary_key is True + assert issubclass(User.Meta.model_fields["name"], pydantic.ConstrainedStr) + assert User.Meta.model_fields["name"].max_length == 100 + assert isinstance(User.Meta.table, sqlalchemy.Table) def test_model_pk(): diff --git a/tests/test_more_reallife_fastapi.py b/tests/test_more_reallife_fastapi.py index d7670cc..53ff9d4 100644 --- a/tests/test_more_reallife_fastapi.py +++ b/tests/test_more_reallife_fastapi.py @@ -30,22 +30,24 @@ async def shutdown() -> None: class Category(ormar.Model): - __tablename__ = "categories" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "categories" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Item(ormar.Model): - __tablename__ = "items" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "items" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + category: ormar.ForeignKey(Category, nullable=True) @pytest.fixture(autouse=True, scope="module") diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 60ddddc..166c943 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -12,53 +12,58 @@ metadata = sqlalchemy.MetaData() class Department(ormar.Model): - __tablename__ = "departments" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "departments" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True, autoincrement=False) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True, autoincrement=False) + name: ormar.String(max_length=100) class SchoolClass(ormar.Model): - __tablename__ = "schoolclasses" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "schoolclasses" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - department = ormar.ForeignKey(Department, nullable=False) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + department: ormar.ForeignKey(Department, nullable=False) class Category(ormar.Model): - __tablename__ = "categories" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "categories" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Student(ormar.Model): - __tablename__ = "students" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "students" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - schoolclass = ormar.ForeignKey(SchoolClass) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + schoolclass: ormar.ForeignKey(SchoolClass) + category: ormar.ForeignKey(Category, nullable=True) class Teacher(ormar.Model): - __tablename__ = "teachers" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "teachers" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - schoolclass = ormar.ForeignKey(SchoolClass) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + schoolclass: ormar.ForeignKey(SchoolClass) + category: ormar.ForeignKey(Category, nullable=True) @pytest.fixture(scope="module")