From c5389023b80f1c1325ba895d2147276f38e60f56 Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 26 Aug 2020 22:24:25 +0200 Subject: [PATCH] add fixes for fastapi model clones, add functionality to add and remove models to relation, add relation proxy, fix all tests, adding values also to pydantic model __dict__some refactors --- .coverage | Bin 53248 -> 53248 bytes ormar/fields/base.py | 2 +- ormar/fields/foreign_key.py | 29 ++++--- ormar/models/__init__.py | 3 +- ormar/models/metaclass.py | 53 ++++++------ ormar/models/model.py | 9 ++- ormar/models/modelproxy.py | 14 +++- ormar/models/newbasemodel.py | 40 ++++++--- ormar/queryset/query.py | 5 +- ormar/queryset/queryset.py | 16 +++- ormar/relations.py | 121 ++++++++++++++++++++-------- tests/test_columns.py | 2 +- tests/test_foreign_keys.py | 65 ++++++++++++--- tests/test_model_definition.py | 9 ++- tests/test_models.py | 2 +- tests/test_more_reallife_fastapi.py | 4 +- tests/test_more_same_table_joins.py | 4 +- 17 files changed, 260 insertions(+), 118 deletions(-) diff --git a/.coverage b/.coverage index a8c10f89578082e93b09966ce034e9233c94b2b1..6c4c46660b6bd8578aefb2379417f44b551754f5 100644 GIT binary patch delta 379 zcmV->0fhd5paX!Q1F$DA2r@D{GBP?bGqW!**HAXP01x>O>JQ!z(+|iGxDTNZcn@$7 zP7gB=D-R70?+(-svksjOhz@`bat=TaBMtry=nc~iy0Z}wehrh!jfOHX6$AkZVimfZ zZ@%X{-+bpgKhNjO|NVCPzki$mdw1{7zuviXZ}0ivo%7$x-<`M52a|k`CRhI6Z{G(L z5d;AVDiQ8t^8W4b-~PAvJ^StVULOzy0SQ15e!G}+pX+aWZ|C>v`+je6zn{JLOL+V5 z{s9>e1OW*&58B0NzqjsQXG{Hk`&aw!dcSS=0h92K9)BDS1OW*;4cd$Szuo=q>*xNw z_vg>ut*-CBD!(!h1`GrN2@VVj84Cmf2{a4(zT`jiTR#5(<^OrI(``IAXWLjH3IqWO zLJAH5?-yVJK>v%o-0WZL|9zj|-SYq6{@mWJ-v$W;0SOKXE++^C0SQnDZu4`y|F1ch Z|DW#1<&%t$B?tTdeZ1%QJhQluNI;8ewFv+K delta 365 zcmV-z0h0cJpaX!Q1F$DA2r)D|Gc!6eH?uD<*HAXe01x>O>JQ!z(+|iGxDTNZcn@$7 zPY*K>D-RA1@DA1vybh-hk`9Uvd=5_zG7b$5_YL9=$+HmsW@-t)gZ=f9J`J8z#4lX8wGPJi#W?*kJN1OW*s z5$c-0CWG@d%uLY0T>Si0SPh>+QnzT zx9(nNOZ~q6t9^I9-?sac?~WdS8w~^j2{;Yfi~Yad{q5`L{=E0+&)x02ugb5?0|pEP z0SOKa3K|Op0SPt>`o82p-{ZG@{P%zPf1d1gV}57579R=(0SQ0~4gl{LU;#k?i(hVk zt^bGHo_}}w|8IYH>o*1o1OW*S2^J>^1OW+92yXLpyZ^5_m;ayc$K`L6jgKS;dw(DI L{GJE1xQ|FciRq*; diff --git a/ormar/fields/base.py b/ormar/fields/base.py index a057920..126a3a6 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -64,6 +64,6 @@ class BaseField: @classmethod def expand_relationship( - cls, value: Any, child: Union["Model", "NewBaseModel"] + cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True ) -> Any: return value diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 9fad3fe..ebd4ec6 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -68,25 +68,33 @@ class ForeignKeyField(BaseField): @classmethod def _extract_model_from_sequence( - cls, value: List, child: "Model" + cls, value: List, child: "Model", to_register: bool ) -> Union["Model", List["Model"]]: - return [cls.expand_relationship(val, child) for val in value] + return [cls.expand_relationship(val, child, to_register) for val in value] @classmethod - def _register_existing_model(cls, value: "Model", child: "Model") -> "Model": - cls.register_relation(value, child) + def _register_existing_model( + cls, value: "Model", child: "Model", to_register: bool + ) -> "Model": + if to_register: + cls.register_relation(value, child) return value @classmethod - def _construct_model_from_dict(cls, value: dict, child: "Model") -> "Model": + def _construct_model_from_dict( + cls, value: dict, child: "Model", to_register: bool + ) -> "Model": if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname: value["__pk_only__"] = True model = cls.to(**value) - cls.register_relation(model, child) + if to_register: + cls.register_relation(model, child) return model @classmethod - def _construct_model_from_pk(cls, value: Any, child: "Model") -> "Model": + def _construct_model_from_pk( + cls, value: Any, child: "Model", to_register: bool + ) -> "Model": if not isinstance(value, cls.to.pk_type()): raise RelationshipInstanceError( f"Relationship error - ForeignKey {cls.to.__name__} " @@ -94,7 +102,8 @@ class ForeignKeyField(BaseField): f"while {type(value)} passed as a parameter." ) model = create_dummy_instance(fk=cls.to, pk=value) - cls.register_relation(model, child) + if to_register: + cls.register_relation(model, child) return model @classmethod @@ -105,7 +114,7 @@ class ForeignKeyField(BaseField): @classmethod def expand_relationship( - cls, value: Any, child: "Model" + cls, value: Any, child: "Model", to_register: bool = True ) -> Optional[Union["Model", List["Model"]]]: if value is None: return None @@ -118,5 +127,5 @@ class ForeignKeyField(BaseField): model = constructors.get( value.__class__.__name__, cls._construct_model_from_pk - )(value, child) + )(value, child, to_register) return model diff --git a/ormar/models/__init__.py b/ormar/models/__init__.py index bc6e7d0..c0592aa 100644 --- a/ormar/models/__init__.py +++ b/ormar/models/__init__.py @@ -1,4 +1,5 @@ from ormar.models.newbasemodel import NewBaseModel from ormar.models.model import Model +from ormar.models.metaclass import expand_reverse_relationships -__all__ = ["NewBaseModel", "Model"] +__all__ = ["NewBaseModel", "Model", "expand_reverse_relationships"] diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 1332b33..1428fab 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -29,17 +29,8 @@ class ModelMeta: alias_manager: AliasManager -def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: - child_relation_name = ( - field.to.get_name(title=True) - + "_" - + (field.related_name or (name.lower() + "s")) - ) - reverse_name = child_relation_name - relation_name = name.lower().title() + "_" + field.to.get_name() - relationship_manager.add_relation_type( - relation_name, reverse_name, field, table_name - ) +def register_relation_on_build(table_name: str, field: ForeignKey) -> None: + relationship_manager.add_relation_type(field, table_name) def expand_reverse_relationships(model: Type["Model"]) -> None: @@ -64,15 +55,10 @@ def register_reverse_model_fields( def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str -) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: + model_fields: Dict, table_name: str +) -> Tuple[Optional[str], List[sqlalchemy.Column]]: columns = [] pkname = None - model_fields = { - field_name: field - for field_name, field in object_dict["__annotations__"].items() - if issubclass(field, BaseField) - } for field_name, field in model_fields.items(): if field.primary_key: if pkname is not None: @@ -83,9 +69,9 @@ def sqlalchemy_columns_from_model_fields( if not field.pydantic_only: columns.append(field.get_column(field_name)) if issubclass(field, ForeignKeyField): - register_relation_on_build(table_name, field, name) + register_relation_on_build(table_name, field) - return pkname, columns, model_fields + return pkname, columns def populate_pydantic_default_values(attrs: Dict) -> Dict: @@ -125,21 +111,29 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): attrs["__annotations__"] = annotations attrs = populate_pydantic_default_values(attrs) + attrs["__module__"] = attrs["__module__"] or bases[0].__module__ + attrs["__annotations__"] = ( + attrs["__annotations__"] or bases[0].__annotations__ + ) + tablename = name.lower() + "s" new_model.Meta.tablename = new_model.Meta.tablename or tablename # sqlalchemy table creation - pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( - name, attrs, new_model.Meta.tablename - ) + model_fields = { + field_name: field + for field_name, field in attrs["__annotations__"].items() + if issubclass(field, BaseField) + } - if hasattr(new_model.Meta, "model_fields") and not pkname: - model_fields = new_model.Meta.model_fields - for fieldname, field in new_model.Meta.model_fields.items(): - if field.primary_key: - pkname = fieldname + if hasattr(new_model.Meta, "columns"): columns = new_model.Meta.table.columns + pkname = new_model.Meta.pkname + else: + pkname, columns = sqlalchemy_columns_from_model_fields( + model_fields, new_model.Meta.tablename + ) if not hasattr(new_model.Meta, "table"): new_model.Meta.table = sqlalchemy.Table( @@ -153,10 +147,11 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): raise ModelDefinitionError("Table has to have a primary key.") new_model.Meta.model_fields = model_fields + expand_reverse_relationships(new_model) + new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) - expand_reverse_relationships(new_model) new_model.Meta.alias_manager = relationship_manager new_model.objects = QuerySet(new_model) diff --git a/ormar/models/model.py b/ormar/models/model.py index 338b01a..5fc6635 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -69,7 +69,8 @@ class Model(NewBaseModel): async def save(self) -> "Model": self_fields = self._extract_model_db_fields() - if self.Meta.model_fields.get(self.Meta.pkname).autoincrement: + + if not self.pk and self.Meta.model_fields.get(self.Meta.pkname).autoincrement: self_fields.pop(self.Meta.pkname, None) expr = self.Meta.table.insert() expr = expr.values(**self_fields) @@ -77,7 +78,7 @@ class Model(NewBaseModel): setattr(self, self.Meta.pkname, item_id) return self - async def update(self, **kwargs: Any) -> int: + async def update(self, **kwargs: Any) -> "Model": if kwargs: new_values = {**self.dict(), **kwargs} self.from_dict(new_values) @@ -89,8 +90,8 @@ class Model(NewBaseModel): .values(**self_fields) .where(self.pk_column == getattr(self, self.Meta.pkname)) ) - result = await self.Meta.database.execute(expr) - return result + await self.Meta.database.execute(expr) + return self async def delete(self) -> int: expr = self.Meta.table.delete() diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index c954a5f..83b6d09 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -24,7 +24,6 @@ class ModelTableProxy: @classmethod def substitute_models_with_pks(cls, model_dict: dict) -> dict: - model_dict = copy.deepcopy(model_dict) for field in cls._extract_related_names(): if field in model_dict and model_dict.get(field) is not None: target_field = cls.Meta.model_fields[field] @@ -76,10 +75,19 @@ class ModelTableProxy: } for field in self._extract_db_related_names(): target_pk_name = self.Meta.model_fields[field].to.Meta.pkname - if getattr(self, field) is not None: - self_fields[field] = getattr(getattr(self, field), target_pk_name) + target_field = getattr(self, field) + self_fields[field] = getattr(target_field, target_pk_name, None) return self_fields + @staticmethod + def resolve_relation_name(item: "Model", related: "Model"): + for name, field in item.Meta.model_fields.items(): + if issubclass(field, ForeignKeyField): + # fastapi is creating clones of response model that's why it can be a subclass + # of the original one so we need to compare Meta too + if field.to == related.__class__ or field.to.Meta == related.Meta: + return name + @classmethod def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: merged_rows = [] diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 582529b..fca1bbd 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -71,9 +71,14 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass pk_only = kwargs.pop("__pk_only__", False) if "pk" in kwargs: kwargs[self.Meta.pkname] = kwargs.pop("pk") + # build the models to set them and validate but don't register kwargs = { k: self._convert_json( - k, self.Meta.model_fields[k].expand_relationship(v, self), "dumps" + k, + self.Meta.model_fields[k].expand_relationship( + v, self, to_register=False + ), + "dumps", ) for k, v in kwargs.items() } @@ -85,13 +90,20 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass object.__setattr__(self, "__dict__", values) object.__setattr__(self, "__fields_set__", fields_set) + # register the related models after initialization + for related in self._extract_related_names(): + self.Meta.model_fields[related].expand_relationship( + kwargs.get(related), self, to_register=True + ) + def __setattr__(self, name: str, value: Any) -> None: if name in self.__slots__: object.__setattr__(self, name, value) elif name == "pk": object.__setattr__(self, self.Meta.pkname, value) elif name in self._orm: - self.Meta.model_fields[name].expand_relationship(value, self) + model = self.Meta.model_fields[name].expand_relationship(value, self) + self.__dict__[name] = model else: value = ( self._convert_json(name, value, "dumps") @@ -113,19 +125,13 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass return value return super().__getattribute__(item) - # def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]: - # return self._extract_related_model_instead_of_field(item) - def _extract_related_model_instead_of_field( self, item: str ) -> Optional[Union["Model", List["Model"]]]: - # relation_key = self.get_name(title=True) + "_" + item if item in self._orm: return self._orm.get(item) def __same__(self, other: "Model") -> bool: - if self.__class__ != other.__class__: # pragma no cover - return False return ( self._orm_id == other._orm_id or self.__dict__ == other.__dict__ @@ -137,8 +143,6 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass name = cls.__name__ if lower: name = name.lower() - if title: - name = name.title() return name @property @@ -149,6 +153,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass def pk_type(cls) -> Any: return cls.Meta.model_fields[cls.Meta.pkname].__type__ + def remove(self, name: "Model"): + self._orm.remove_parent(self, name) + def dict( # noqa A003 self, *, @@ -176,14 +183,23 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass if self.Meta.model_fields[field].virtual and nested: continue if isinstance(nested_model, list): - dict_instance[field] = [x.dict(nested=True) for x in nested_model] + result = [] + for model in nested_model: + try: + result.append(model.dict(nested=True)) + except ReferenceError: # pragma no cover + continue + dict_instance[field] = result elif nested_model is not None: dict_instance[field] = nested_model.dict(nested=True) + else: + dict_instance[field] = None return dict_instance - def from_dict(self, value_dict: Dict) -> None: + def from_dict(self, value_dict: Dict) -> "Model": for key, value in value_dict.items(): setattr(self, key, value) + return self def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]: if not self._is_conversion_to_json_needed(column_name): diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 95b2ff1..8271887 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -69,10 +69,11 @@ class Query: # print(expr.compile(compile_kwargs={"literal_binds": True})) self._reset_query_parameters() - return expr, self._select_related + return expr + @staticmethod def on_clause( - self, previous_alias: str, alias: str, from_clause: str, to_clause: str, + previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: left_part = f"{alias}_{to_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index c22b6ad..1b0db6d 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -47,7 +47,7 @@ class QuerySet: offset=self.query_offset, limit_count=self.limit_count, ) - exp, self._select_related = qry.build_select_expression() + exp = qry.build_select_expression() return exp def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 @@ -118,15 +118,25 @@ class QuerySet: async def get(self, **kwargs: Any) -> "Model": if kwargs: return await self.filter(**kwargs).get() + else: + if not self.filter_clauses: + expr = self.build_select_expression().limit(2) + else: + expr = self.build_select_expression() - expr = self.build_select_expression().limit(2) rows = await self.database.fetch_all(expr) + result_rows = [ + self.model_cls.from_row(row, select_related=self._select_related) + for row in rows + ] + rows = self.model_cls.merge_instances_list(result_rows) + if not rows: raise NoMatch() if len(rows) > 1: raise MultipleMatches() - return self.model_cls.from_row(rows[0], select_related=self._select_related) + return rows[0] async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 if kwargs: diff --git a/ormar/relations.py b/ormar/relations.py index 8909140..f9113ad 100644 --- a/ormar/relations.py +++ b/ormar/relations.py @@ -2,12 +2,13 @@ import string import uuid from enum import Enum from random import choices -from typing import List, TYPE_CHECKING, Type +from typing import List, TYPE_CHECKING, Type, Union, Optional from weakref import proxy import sqlalchemy from sqlalchemy import text +import ormar from ormar.exceptions import RelationshipInstanceError from ormar.fields.foreign_key import ForeignKeyField # noqa I100 @@ -26,7 +27,6 @@ class RelationType(Enum): class AliasManager: def __init__(self) -> None: - self._relations = dict() self._aliases = dict() @staticmethod @@ -40,54 +40,83 @@ class AliasManager: def prefixed_table_name(alias: str, name: str) -> text: return text(f"{name} {alias}_{name}") - def add_relation_type( - self, - relations_key: str, - reverse_key: str, - field: ForeignKeyField, - table_name: str, - ) -> None: - if relations_key not in self._relations: + def add_relation_type(self, field: ForeignKeyField, table_name: str,) -> None: + if f"{table_name}_{field.to.Meta.tablename}" not in self._aliases: self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias() - if reverse_key not in self._relations: + if f"{field.to.Meta.tablename}_{table_name}" not in self._aliases: self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias() def resolve_relation_join(self, from_table: str, to_table: str) -> str: return self._aliases.get(f"{from_table}_{to_table}", "") -class Relation: - def __init__(self, type_: RelationType) -> None: - self._type = type_ - self.related_models = [] if type_ == RelationType.REVERSE else None +class RelationProxy(list): + def __init__(self, relation: "Relation"): + super(RelationProxy, self).__init__() + self.relation = relation + self._owner = self.relation.manager.owner - def _find_existing(self, child): - for ind, relation_child in enumerate(self.related_models): + def remove(self, item: "Model"): + super().remove(item) + rel_name = item.resolve_relation_name(item, self._owner) + item._orm._get(rel_name).remove(self._owner) + + def append(self, item: "Model"): + super().append(item) + + def add(self, item): + rel_name = item.resolve_relation_name(item, self._owner) + setattr(item, rel_name, self._owner) + + +class Relation: + def __init__(self, manager: "RelationsManager", type_: RelationType) -> None: + self.manager = manager + self._owner = manager.owner + self._type = type_ + self.related_models = ( + RelationProxy(relation=self) if type_ == RelationType.REVERSE else None + ) + + def _find_existing(self, child) -> Optional[int]: + for ind, relation_child in enumerate(self.related_models[:]): try: if relation_child.__same__(child): return ind except ReferenceError: # pragma no cover - continue + self.related_models.pop(ind) return None def add(self, child: "Model") -> None: + relation_name = self._owner.resolve_relation_name(self._owner, child) if self._type == RelationType.PRIMARY: self.related_models = child + self._owner.__dict__[relation_name] = child else: if self._find_existing(child) is None: self.related_models.append(child) + rel = self._owner.__dict__.get(relation_name, []) + rel.append(child) + self._owner.__dict__[relation_name] = rel - # def remove(self, child: "Model") -> None: - # if self._type == RelationType.PRIMARY: - # self.related_models = None - # else: - # position = self._find_existing(child) - # if position is not None: - # self.related_models.pop(position) + def remove(self, child: "Model") -> None: + relation_name = self._owner.resolve_relation_name(self._owner, child) + if self._type == RelationType.PRIMARY: + if self.related_models.__same__(child): + self.related_models = None + del self._owner.__dict__[relation_name] + else: + position = self._find_existing(child) + if position is not None: + self.related_models.pop(position) + del self._owner.__dict__[relation_name][position] - def get(self): + def get(self) -> Union[List["Model"], "Model"]: return self.related_models + def __repr__(self): # pragma no cover + return str(self.related_models) + class RelationsManager: def __init__( @@ -98,21 +127,23 @@ class RelationsManager: self._related_names = [field.name for field in self._related_fields] self._relations = dict() for field in self._related_fields: - self._relations[field.name] = Relation( - type_=RelationType.PRIMARY - if not field.virtual - else RelationType.REVERSE - ) + self._add_relation(field) + + def _add_relation(self, field): + self._relations[field.name] = Relation( + manager=self, + type_=RelationType.PRIMARY if not field.virtual else RelationType.REVERSE, + ) def __contains__(self, item): return item in self._related_names - def get(self, name): + def get(self, name) -> Optional[Union[List["Model"], "Model"]]: relation = self._relations.get(name, None) if relation: return relation.get() - def _get(self, name): + def _get(self, name) -> Optional[Relation]: relation = self._relations.get(name, None) if relation: return relation @@ -122,7 +153,7 @@ class RelationsManager: ( field for field in child._orm._related_fields - if field.to == parent.__class__ + if field.to == parent.__class__ or field.to.Meta == parent.Meta ), None, ) @@ -140,5 +171,25 @@ class RelationsManager: child_name = child_name or child.get_name() + "s" child = proxy(child) - parent._orm._get(child_name).add(child) + parent_relation = parent._orm._get(child_name) + if not parent_relation: + ormar.models.expand_reverse_relationships(child.__class__) + name = parent.resolve_relation_name(parent, child) + field = parent.Meta.model_fields[name] + parent._orm._add_relation(field) + parent_relation = parent._orm._get(child_name) + parent_relation.add(child) child._orm._get(to_name).add(parent) + + def remove(self, name: str, child: "Model"): + relation = self._get(name) + relation.remove(child) + + @staticmethod + def remove_parent(item: "Model", name: Union[str, "Model"]): + related_model = name + name = item.resolve_relation_name(item, related_model) + if name in item._orm: + relation_name = item.resolve_relation_name(related_model, item) + item._orm.remove(name, related_model) + related_model._orm.remove(relation_name, item) diff --git a/tests/test_columns.py b/tests/test_columns.py index c8c9d3b..15382b0 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -22,7 +22,7 @@ class Example(ormar.Model): database = database id: ormar.Integer(primary_key=True) - name: ormar.String(max_length=200, default='aaa') + name: ormar.String(max_length=200, default="aaa") created: ormar.DateTime(default=datetime.datetime.now) created_day: ormar.Date(default=datetime.date.today) created_time: ormar.Time(default=time) diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 5ae364c..fb85fc7 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -1,11 +1,11 @@ +import gc + import databases import pytest import sqlalchemy -from pydantic import ValidationError import ormar from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError -from ormar.fields.foreign_key import ForeignKeyField from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -133,7 +133,9 @@ async def test_model_crud(): assert album1.pk == 1 assert album1.tracks == [] - await Track.objects.create(album={"id": track.album.pk}, title="The Bird2", position=4) + await Track.objects.create( + album={"id": track.album.pk}, title="The Bird2", position=4 + ) @pytest.mark.asyncio @@ -164,6 +166,47 @@ async def test_select_related(): assert len(tracks) == 6 +@pytest.mark.asyncio +async def test_model_removal_from_relations(): + async with database: + album = Album(name="Chichi") + await album.save() + track1 = Track(album=album, title="The Birdman", position=1) + track2 = Track(album=album, title="Superman", position=2) + track3 = Track(album=album, title="Wonder Woman", position=3) + await track1.save() + await track2.save() + await track3.save() + + assert len(album.tracks) == 3 + album.tracks.remove(track1) + assert len(album.tracks) == 2 + assert track1.album is None + + await track1.update() + track1 = await Track.objects.get(title="The Birdman") + assert track1.album is None + + album.tracks.add(track1) + assert len(album.tracks) == 3 + assert track1.album == album + + await track1.update() + track1 = await Track.objects.select_related("album__tracks").get( + title="The Birdman" + ) + album = await Album.objects.select_related("tracks").get(name="Chichi") + assert track1.album == album + + track1.remove(album) + assert track1.album is None + assert len(album.tracks) == 2 + + track2.remove(album) + assert track2.album is None + assert len(album.tracks) == 1 + + @pytest.mark.asyncio async def test_fk_filter(): async with database: @@ -182,8 +225,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -191,8 +234,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -234,8 +277,8 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: @@ -254,8 +297,8 @@ async def test_pk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1 diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index ebb0619..ab2845c 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -54,7 +54,9 @@ class ExampleModel2(Model): @pytest.fixture() def example(): - return ExampleModel(pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5)) + return ExampleModel( + pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5) + ) def test_not_nullable_field_is_required(): @@ -110,6 +112,7 @@ def test_sqlalchemy_table_is_created(example): def test_no_pk_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example3" @@ -120,6 +123,7 @@ def test_no_pk_in_model_definition(): def test_two_pks_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example3" @@ -131,6 +135,7 @@ def test_two_pks_in_model_definition(): def test_setting_pk_column_as_pydantic_only_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example4" @@ -141,6 +146,7 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition(): def test_decimal_error_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example5" @@ -151,6 +157,7 @@ def test_decimal_error_in_model_definition(): def test_string_error_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example6" diff --git a/tests/test_models.py b/tests/test_models.py index f21e70c..1c00ef3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -28,7 +28,7 @@ class User(ormar.Model): database = database id: ormar.Integer(primary_key=True) - name: ormar.String(max_length=100, default='') + name: ormar.String(max_length=100, default="") class Product(ormar.Model): diff --git a/tests/test_more_reallife_fastapi.py b/tests/test_more_reallife_fastapi.py index 31e31b9..e01ce44 100644 --- a/tests/test_more_reallife_fastapi.py +++ b/tests/test_more_reallife_fastapi.py @@ -79,7 +79,7 @@ async def create_category(category: Category): @app.put("/items/{item_id}") async def get_item(item_id: int, item: Item): item_db = await Item.objects.get(pk=item_id) - return {"updated_rows": await item_db.update(**item.dict())} + return await item_db.update(**item.dict()) @app.delete("/items/{item_id}") @@ -105,7 +105,7 @@ def test_all_endpoints(): item.name = "New name" response = client.put(f"/items/{item.pk}", json=item.dict()) - assert response.json().get("updated_rows") == 1 + assert response.json() == item.dict() response = client.get("/items/") items = [Item(**item) for item in response.json()] diff --git a/tests/test_more_same_table_joins.py b/tests/test_more_same_table_joins.py index 21bf385..3718bc0 100644 --- a/tests/test_more_same_table_joins.py +++ b/tests/test_more_same_table_joins.py @@ -101,10 +101,10 @@ async def test_model_multiple_instances_of_same_table_in_schema(): assert classes[0].name == "Math" assert classes[0].students[0].name == "Jane" assert len(classes[0].dict().get("students")) == 2 - assert classes[0].teachers[0].category.department.name == 'Law Department' + assert classes[0].teachers[0].category.department.name == "Law Department" assert classes[0].students[0].category.pk is not None assert classes[0].students[0].category.name is None await classes[0].students[0].category.load() await classes[0].students[0].category.department.load() - assert classes[0].students[0].category.department.name == 'Math Department' + assert classes[0].students[0].category.department.name == "Math Department"