diff --git a/.coverage b/.coverage index 7dffaa8..1ab8b16 100644 Binary files a/.coverage and b/.coverage differ diff --git a/docs/relations.md b/docs/relations.md index 43a6e20..3edf120 100644 --- a/docs/relations.md +++ b/docs/relations.md @@ -76,16 +76,16 @@ Since you join two times to the same table (categories) it won't work by default But don't worry - ormar can handle situations like this, as it uses the Relationship Manager which has it's aliases defined for all relationships. -Each class is registered with the same instance of the RelationshipManager that you can access like this: +Each class is registered with the same instance of the AliasManager that you can access like this: ```python -SchoolClass._orm_relationship_manager +SchoolClass.alias_manager ``` It's the same object for all `Models` ```python -print(Teacher._orm_relationship_manager == Student._orm_relationship_manager) +print(Teacher.alias_manager == Student.alias_manager) # will produce: True ``` @@ -94,11 +94,11 @@ print(Teacher._orm_relationship_manager == Student._orm_relationship_manager) You can even preview the alias used for any relation by passing two tables names. ```python -print(Teacher._orm_relationship_manager.resolve_relation_join( +print(Teacher.alias_manager.resolve_relation_join( 'students', 'categories')) # will produce: KId1c6 (sample value) -print(Teacher._orm_relationship_manager.resolve_relation_join( +print(Teacher.alias_manager.resolve_relation_join( 'categories', 'students')) # will produce: EFccd5 (sample value) ``` diff --git a/docs_src/models/docs006.py b/docs_src/models/docs006.py index b232979..368e771 100644 --- a/docs_src/models/docs006.py +++ b/docs_src/models/docs006.py @@ -36,6 +36,6 @@ print('department' in course.__dict__) # False <- related model is not stored on Course instance print(course.department) # Department(id=None, name='Science') <- Department model -# returned from RelationshipManager +# returned from AliasManager print(course.department.name) # Science \ No newline at end of file diff --git a/ormar/fields/base.py b/ormar/fields/base.py index a057920..126a3a6 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -64,6 +64,6 @@ class BaseField: @classmethod def expand_relationship( - cls, value: Any, child: Union["Model", "NewBaseModel"] + cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True ) -> Any: return value diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 9d052a7..ebd4ec6 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -54,6 +54,7 @@ def ForeignKey( class ForeignKeyField(BaseField): to: Type["Model"] + name: str related_name: str virtual: bool @@ -65,36 +66,35 @@ class ForeignKeyField(BaseField): def validate(cls, value: Any) -> Any: return value - # @property - # def __type__(self) -> Type[BaseModel]: - # return self.to.__pydantic_model__ - - # @classmethod - # def get_column_type(cls) -> sqlalchemy.Column: - # to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname] - # return to_column.column_type - @classmethod def _extract_model_from_sequence( - cls, value: List, child: "Model" + cls, value: List, child: "Model", to_register: bool ) -> Union["Model", List["Model"]]: - return [cls.expand_relationship(val, child) for val in value] + return [cls.expand_relationship(val, child, to_register) for val in value] @classmethod - def _register_existing_model(cls, value: "Model", child: "Model") -> "Model": - cls.register_relation(value, child) + def _register_existing_model( + cls, value: "Model", child: "Model", to_register: bool + ) -> "Model": + if to_register: + cls.register_relation(value, child) return value @classmethod - def _construct_model_from_dict(cls, value: dict, child: "Model") -> "Model": + def _construct_model_from_dict( + cls, value: dict, child: "Model", to_register: bool + ) -> "Model": if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname: value["__pk_only__"] = True model = cls.to(**value) - cls.register_relation(model, child) + if to_register: + cls.register_relation(model, child) return model @classmethod - def _construct_model_from_pk(cls, value: Any, child: "Model") -> "Model": + def _construct_model_from_pk( + cls, value: Any, child: "Model", to_register: bool + ) -> "Model": if not isinstance(value, cls.to.pk_type()): raise RelationshipInstanceError( f"Relationship error - ForeignKey {cls.to.__name__} " @@ -102,19 +102,19 @@ class ForeignKeyField(BaseField): f"while {type(value)} passed as a parameter." ) model = create_dummy_instance(fk=cls.to, pk=value) - cls.register_relation(model, child) + if to_register: + cls.register_relation(model, child) return model @classmethod def register_relation(cls, model: "Model", child: "Model") -> None: - child_model_name = cls.related_name or child.get_name() - model.Meta._orm_relationship_manager.add_relation( - model, child, child_model_name, virtual=cls.virtual + model._orm.add( + parent=model, child=child, child_name=cls.related_name, virtual=cls.virtual ) @classmethod def expand_relationship( - cls, value: Any, child: "Model" + cls, value: Any, child: "Model", to_register: bool = True ) -> Optional[Union["Model", List["Model"]]]: if value is None: return None @@ -127,5 +127,5 @@ class ForeignKeyField(BaseField): model = constructors.get( value.__class__.__name__, cls._construct_model_from_pk - )(value, child) + )(value, child, to_register) return model diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index dd81838..8ff1dd7 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -57,7 +57,7 @@ class String(ModelFieldFactory): _bases = (pydantic.ConstrainedStr, BaseField) _type = str - def __new__( + def __new__( # noqa CFQ002 cls, *, allow_blank: bool = False, @@ -231,7 +231,7 @@ class Decimal(ModelFieldFactory): _bases = (pydantic.ConstrainedDecimal, BaseField) _type = decimal.Decimal - def __new__( + def __new__( # noqa CFQ002 cls, *, minimum: float = None, diff --git a/ormar/models/__init__.py b/ormar/models/__init__.py index 53eaff4..e6d8bd5 100644 --- a/ormar/models/__init__.py +++ b/ormar/models/__init__.py @@ -1,4 +1,5 @@ -from ormar.models.model import Model -from ormar.models.newbasemodel import NewBaseModel +from ormar.models.newbasemodel import NewBaseModel # noqa I100 +from ormar.models.model import Model # noqa I100 +from ormar.models.metaclass import expand_reverse_relationships # noqa I100 -__all__ = ["NewBaseModel", "Model"] +__all__ = ["NewBaseModel", "Model", "expand_reverse_relationships"] diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index aa65c74..8bc820d 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -10,12 +10,12 @@ from ormar import ForeignKey, ModelDefinitionError # noqa I100 from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField from ormar.queryset import QuerySet -from ormar.relations import RelationshipManager +from ormar.relations import AliasManager if TYPE_CHECKING: # pragma no cover from ormar import Model -relationship_manager = RelationshipManager() +alias_manager = AliasManager() class ModelMeta: @@ -26,19 +26,19 @@ class ModelMeta: columns: List[sqlalchemy.Column] pkname: str model_fields: Dict[str, Union[BaseField, ForeignKey]] - _orm_relationship_manager: RelationshipManager + alias_manager: AliasManager -def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: - child_relation_name = ( - field.to.get_name(title=True) - + "_" - + (field.related_name or (name.lower() + "s")) - ) - reverse_name = child_relation_name - relation_name = name.lower().title() + "_" + field.to.get_name() - relationship_manager.add_relation_type( - relation_name, reverse_name, field, table_name +def register_relation_on_build(table_name: str, field: ForeignKey) -> None: + alias_manager.add_relation_type(field, table_name) + + +def reverse_field_not_already_registered( + child: Type["Model"], child_model_name: str, parent_model: Type["Model"] +) -> bool: + return ( + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ) @@ -48,9 +48,8 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: child_model_name = model_field.related_name or model.get_name() + "s" parent_model = model_field.to child = model - if ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + if reverse_field_not_already_registered( + child, child_model_name, parent_model ): register_reverse_model_fields(parent_model, child, child_model_name) @@ -63,29 +62,42 @@ def register_reverse_model_fields( ) +def check_pk_column_validity( + field_name: str, field: BaseField, pkname: str +) -> Optional[str]: + if pkname is not None: + raise ModelDefinitionError("Only one primary key column is allowed.") + if field.pydantic_only: + raise ModelDefinitionError("Primary key column cannot be pydantic only") + return field_name + + def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str -) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: + model_fields: Dict, table_name: str +) -> Tuple[Optional[str], List[sqlalchemy.Column]]: columns = [] pkname = None - model_fields = { - field_name: field - for field_name, field in object_dict["__annotations__"].items() - if issubclass(field, BaseField) - } for field_name, field in model_fields.items(): if field.primary_key: - if pkname is not None: - raise ModelDefinitionError("Only one primary key column is allowed.") - if field.pydantic_only: - raise ModelDefinitionError("Primary key column cannot be pydantic only") - pkname = field_name + pkname = check_pk_column_validity(field_name, field, pkname) if not field.pydantic_only: columns.append(field.get_column(field_name)) if issubclass(field, ForeignKeyField): - register_relation_on_build(table_name, field, name) + register_relation_on_build(table_name, field) - return pkname, columns, model_fields + return pkname, columns + + +def populate_default_pydantic_field_value( + type_: Type[BaseField], field: str, attrs: dict +) -> dict: + def_value = type_.default_value() + curr_def_value = attrs.get(field, "NONE") + if curr_def_value == "NONE" and isinstance(def_value, FieldInfo): + attrs[field] = def_value + elif curr_def_value == "NONE" and type_.nullable: + attrs[field] = FieldInfo(default=None) + return attrs def populate_pydantic_default_values(attrs: Dict) -> Dict: @@ -93,20 +105,70 @@ def populate_pydantic_default_values(attrs: Dict) -> Dict: if issubclass(type_, BaseField): if type_.name is None: type_.name = field - def_value = type_.default_value() - curr_def_value = attrs.get(field, "NONE") - if curr_def_value == "NONE" and isinstance(def_value, FieldInfo): - attrs[field] = def_value - elif curr_def_value == "NONE" and type_.nullable: - attrs[field] = FieldInfo(default=None) + attrs = populate_default_pydantic_field_value(type_, field, attrs) return attrs +def extract_annotations_and_module( + attrs: dict, new_model: "ModelMetaclass", bases: Tuple +) -> dict: + annotations = attrs.get("__annotations__") or new_model.__annotations__ + attrs["__annotations__"] = annotations + attrs = populate_pydantic_default_values(attrs) + + attrs["__module__"] = attrs["__module__"] or bases[0].__module__ + attrs["__annotations__"] = attrs["__annotations__"] or bases[0].__annotations__ + return attrs + + +def populate_meta_orm_model_fields( + attrs: dict, new_model: Type["Model"] +) -> Type["Model"]: + model_fields = { + field_name: field + for field_name, field in attrs["__annotations__"].items() + if issubclass(field, BaseField) + } + new_model.Meta.model_fields = model_fields + return new_model + + +def populate_meta_tablename_columns_and_pk( + name: str, new_model: Type["Model"] +) -> Type["Model"]: + tablename = name.lower() + "s" + new_model.Meta.tablename = new_model.Meta.tablename or tablename + + if hasattr(new_model.Meta, "columns"): + columns = new_model.Meta.table.columns + pkname = new_model.Meta.pkname + else: + pkname, columns = sqlalchemy_columns_from_model_fields( + new_model.Meta.model_fields, new_model.Meta.tablename + ) + new_model.Meta.columns = columns + new_model.Meta.pkname = pkname + + if not new_model.Meta.pkname: + raise ModelDefinitionError("Table has to have a primary key.") + + return new_model + + +def populate_meta_sqlalchemy_table_if_required( + new_model: Type["Model"], +) -> Type["Model"]: + if not hasattr(new_model.Meta, "table"): + new_model.Meta.table = sqlalchemy.Table( + new_model.Meta.tablename, new_model.Meta.metadata, *new_model.Meta.columns + ) + return new_model + + def get_pydantic_base_orm_config() -> Type[BaseConfig]: class Config(BaseConfig): orm_mode = True arbitrary_types_allowed = True - # extra = Extra.allow return Config @@ -121,44 +183,17 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): if hasattr(new_model, "Meta"): - annotations = attrs.get("__annotations__") or new_model.__annotations__ - attrs["__annotations__"] = annotations - attrs = populate_pydantic_default_values(attrs) + attrs = extract_annotations_and_module(attrs, new_model, bases) + new_model = populate_meta_orm_model_fields(attrs, new_model) + new_model = populate_meta_tablename_columns_and_pk(name, new_model) + new_model = populate_meta_sqlalchemy_table_if_required(new_model) + expand_reverse_relationships(new_model) - tablename = name.lower() + "s" - new_model.Meta.tablename = new_model.Meta.tablename or tablename - - # sqlalchemy table creation - - pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( - name, attrs, new_model.Meta.tablename - ) - - if hasattr(new_model.Meta, "model_fields") and not pkname: - model_fields = new_model.Meta.model_fields - for fieldname, field in new_model.Meta.model_fields.items(): - if field.primary_key: - pkname = fieldname - columns = new_model.Meta.table.columns - - if not hasattr(new_model.Meta, "table"): - new_model.Meta.table = sqlalchemy.Table( - new_model.Meta.tablename, new_model.Meta.metadata, *columns - ) - - new_model.Meta.columns = columns - new_model.Meta.pkname = pkname - - if not pkname: - raise ModelDefinitionError("Table has to have a primary key.") - - new_model.Meta.model_fields = model_fields new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) - expand_reverse_relationships(new_model) - new_model.Meta._orm_relationship_manager = relationship_manager + new_model.Meta.alias_manager = alias_manager new_model.objects = QuerySet(new_model) return new_model diff --git a/ormar/models/model.py b/ormar/models/model.py index 84c3404..1b40edb 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -1,4 +1,5 @@ -from typing import Any, List +import itertools +from typing import Any, List, Tuple, Union import sqlalchemy @@ -6,6 +7,21 @@ import ormar.queryset # noqa I100 from ormar.models import NewBaseModel # noqa I100 +def group_related_list(list_: List) -> dict: + test_dict = dict() + grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0]) + for key, group in grouped: + group_list = list(group) + new = [ + "__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1 + ] + if any("__" in x for x in new): + test_dict[key] = group_related_list(new) + else: + test_dict[key] = new + return test_dict + + class Model(NewBaseModel): __abstract__ = False @@ -14,22 +30,44 @@ class Model(NewBaseModel): cls, row: sqlalchemy.engine.ResultProxy, select_related: List = None, + related_models: Any = None, previous_table: str = None, - ) -> "Model": + ) -> Union["Model", Tuple["Model", dict]]: item = {} select_related = select_related or [] + related_models = related_models or [] + if select_related: + related_models = group_related_list(select_related) - table_prefix = cls.Meta._orm_relationship_manager.resolve_relation_join( + table_prefix = cls.Meta.alias_manager.resolve_relation_join( previous_table, cls.Meta.table.name ) + previous_table = cls.Meta.table.name - for related in select_related: - if "__" in related: - first_part, remainder = related.split("__", 1) + + item = cls.populate_nested_models_from_row( + item, row, related_models, previous_table + ) + item = cls.extract_prefixed_table_columns(item, row, table_prefix) + + instance = cls(**item) if item.get(cls.Meta.pkname, None) is not None else None + return instance + + @classmethod + def populate_nested_models_from_row( + cls, + item: dict, + row: sqlalchemy.engine.ResultProxy, + related_models: Any, + previous_table: sqlalchemy.Table, + ) -> dict: + for related in related_models: + if isinstance(related_models, dict) and related_models[related]: + first_part, remainder = related, related_models[related] model_cls = cls.Meta.model_fields[first_part].to child = model_cls.from_row( - row, select_related=[remainder], previous_table=previous_table + row, related_models=remainder, previous_table=previous_table ) item[first_part] = child else: @@ -37,17 +75,23 @@ class Model(NewBaseModel): child = model_cls.from_row(row, previous_table=previous_table) item[related] = child + return item + + @classmethod + def extract_prefixed_table_columns( + cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str + ) -> dict: for column in cls.Meta.table.columns: if column.name not in item: item[column.name] = row[ f'{table_prefix + "_" if table_prefix else ""}{column.name}' ] - - return cls(**item) + return item async def save(self) -> "Model": self_fields = self._extract_model_db_fields() - if self.Meta.model_fields.get(self.Meta.pkname).autoincrement: + + if not self.pk and self.Meta.model_fields.get(self.Meta.pkname).autoincrement: self_fields.pop(self.Meta.pkname, None) expr = self.Meta.table.insert() expr = expr.values(**self_fields) @@ -55,20 +99,18 @@ class Model(NewBaseModel): setattr(self, self.Meta.pkname, item_id) return self - async def update(self, **kwargs: Any) -> int: + async def update(self, **kwargs: Any) -> "Model": if kwargs: new_values = {**self.dict(), **kwargs} self.from_dict(new_values) self_fields = self._extract_model_db_fields() self_fields.pop(self.Meta.pkname) - expr = ( - self.Meta.table.update() - .values(**self_fields) - .where(self.pk_column == getattr(self, self.Meta.pkname)) - ) - result = await self.Meta.database.execute(expr) - return result + expr = self.Meta.table.update().values(**self_fields) + expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname)) + + await self.Meta.database.execute(expr) + return self async def delete(self) -> int: expr = self.Meta.table.delete() diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index d9b99f3..81e1ea6 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -1,6 +1,5 @@ -import copy import inspect -from typing import List, Set, TYPE_CHECKING +from typing import List, Optional, Set, TYPE_CHECKING import ormar from ormar.fields.foreign_key import ForeignKeyField @@ -24,15 +23,15 @@ class ModelTableProxy: @classmethod def substitute_models_with_pks(cls, model_dict: dict) -> dict: - model_dict = copy.deepcopy(model_dict) for field in cls._extract_related_names(): - if field in model_dict and model_dict.get(field) is not None: + field_value = model_dict.get(field, None) + if field_value is not None: target_field = cls.Meta.model_fields[field] target_pkname = target_field.to.Meta.pkname - if isinstance(model_dict.get(field), ormar.Model): - model_dict[field] = getattr(model_dict.get(field), target_pkname) + if isinstance(field_value, ormar.Model): + model_dict[field] = getattr(field_value, target_pkname) else: - model_dict[field] = model_dict.get(field).get(target_pkname) + model_dict[field] = field_value.get(target_pkname) return model_dict @classmethod @@ -43,6 +42,18 @@ class ModelTableProxy: related_names.add(name) return related_names + @classmethod + def _extract_db_related_names(cls) -> Set: + related_names = set() + for name, field in cls.Meta.model_fields.items(): + if ( + inspect.isclass(field) + and issubclass(field, ForeignKeyField) + and not field.virtual + ): + related_names.add(name) + return related_names + @classmethod def _exclude_related_names_not_required(cls, nested: bool = False) -> Set: if nested: @@ -62,18 +73,28 @@ class ModelTableProxy: self_fields = { k: v for k, v in self_fields.items() if k in self.Meta.table.columns } - for field in self._extract_related_names(): + for field in self._extract_db_related_names(): target_pk_name = self.Meta.model_fields[field].to.Meta.pkname - if getattr(self, field) is not None: - self_fields[field] = getattr(getattr(self, field), target_pk_name) + target_field = getattr(self, field) + self_fields[field] = getattr(target_field, target_pk_name, None) return self_fields + @staticmethod + def resolve_relation_name(item: "Model", related: "Model") -> Optional[str]: + for name, field in item.Meta.model_fields.items(): + if issubclass(field, ForeignKeyField): + # fastapi is creating clones of response model + # that's why it can be a subclass of the original model + # so we need to compare Meta too as this one is copied as is + if field.to == related.__class__ or field.to.Meta == related.Meta: + return name + @classmethod def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: merged_rows = [] for index, model in enumerate(result_rows): - if index > 0 and model.pk == result_rows[index - 1].pk: - result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) + if index > 0 and model.pk == merged_rows[-1].pk: + merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) else: merged_rows.append(model) return merged_rows diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 932e2c2..af5b295 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -20,9 +20,10 @@ from pydantic import BaseModel import ormar # noqa I100 from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.modelproxy import ModelTableProxy -from ormar.relations import RelationshipManager +from ormar.relations import AliasManager, RelationsManager if TYPE_CHECKING: # pragma no cover from ormar.models.model import Model @@ -34,7 +35,7 @@ if TYPE_CHECKING: # pragma no cover class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass): - __slots__ = ("_orm_id", "_orm_saved") + __slots__ = ("_orm_id", "_orm_saved", "_orm") if TYPE_CHECKING: # pragma no cover __model_fields__: Dict[str, TypeVar[BaseField]] @@ -45,7 +46,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass __tablename__: str __metadata__: sqlalchemy.MetaData __database__: databases.Database - _orm_relationship_manager: RelationshipManager + _orm_relationship_manager: AliasManager + _orm: RelationsManager Meta: ModelMeta # noinspection PyMissingConstructor @@ -53,13 +55,30 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass object.__setattr__(self, "_orm_id", uuid.uuid4().hex) object.__setattr__(self, "_orm_saved", False) + object.__setattr__( + self, + "_orm", + RelationsManager( + related_fields=[ + field + for name, field in self.Meta.model_fields.items() + if issubclass(field, ForeignKeyField) + ], + owner=self, + ), + ) pk_only = kwargs.pop("__pk_only__", False) if "pk" in kwargs: kwargs[self.Meta.pkname] = kwargs.pop("pk") + # build the models to set them and validate but don't register kwargs = { k: self._convert_json( - k, self.Meta.model_fields[k].expand_relationship(v, self), "dumps" + k, + self.Meta.model_fields[k].expand_relationship( + v, self, to_register=False + ), + "dumps", ) for k, v in kwargs.items() } @@ -71,17 +90,20 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass object.__setattr__(self, "__dict__", values) object.__setattr__(self, "__fields_set__", fields_set) - def __del__(self) -> None: - self.Meta._orm_relationship_manager.deregister(self) + # register the related models after initialization + for related in self._extract_related_names(): + self.Meta.model_fields[related].expand_relationship( + kwargs.get(related), self, to_register=True + ) def __setattr__(self, name: str, value: Any) -> None: - relation_key = self.get_name(title=True) + "_" + name if name in self.__slots__: object.__setattr__(self, name, value) elif name == "pk": object.__setattr__(self, self.Meta.pkname, value) - elif self.Meta._orm_relationship_manager.contains(relation_key, self): - self.Meta.model_fields[name].expand_relationship(value, self) + elif name in self._orm: + model = self.Meta.model_fields[name].expand_relationship(value, self) + self.__dict__[name] = model else: value = ( self._convert_json(name, value, "dumps") @@ -91,28 +113,25 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass super().__setattr__(name, value) def __getattribute__(self, item: str) -> Any: - if item != "__fields__" and item in self.__fields__: - related = self._extract_related_model_instead_of_field(item) - if related: - return related - value = object.__getattribute__(self, item) + if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"): + return object.__getattribute__(self, item) + elif item != "_extract_related_names" and item in self._extract_related_names(): + return self._extract_related_model_instead_of_field(item) + elif item == "pk": + return self.__dict__.get(self.Meta.pkname, None) + elif item != "__fields__" and item in self.__fields__: + value = self.__dict__.get(item, None) value = self._convert_json(item, value, "loads") return value return super().__getattribute__(item) - def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]: - return self._extract_related_model_instead_of_field(item) - def _extract_related_model_instead_of_field( self, item: str ) -> Optional[Union["Model", List["Model"]]]: - relation_key = self.get_name(title=True) + "_" + item - if self.Meta._orm_relationship_manager.contains(relation_key, self): - return self.Meta._orm_relationship_manager.get(relation_key, self) + if item in self._orm: + return self._orm.get(item) def __same__(self, other: "Model") -> bool: - if self.__class__ != other.__class__: # pragma no cover - return False return ( self._orm_id == other._orm_id or self.__dict__ == other.__dict__ @@ -124,14 +143,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass name = cls.__name__ if lower: name = name.lower() - if title: - name = name.title() return name - @property - def pk(self) -> Any: - return getattr(self, self.Meta.pkname) - @property def pk_column(self) -> sqlalchemy.Column: return self.Meta.table.primary_key.columns.values()[0] @@ -140,6 +153,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass def pk_type(cls) -> Any: return cls.Meta.model_fields[cls.Meta.pkname].__type__ + def remove(self, name: "Model") -> None: + self._orm.remove_parent(self, name) + def dict( # noqa A003 self, *, @@ -167,17 +183,25 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass if self.Meta.model_fields[field].virtual and nested: continue if isinstance(nested_model, list): - dict_instance[field] = [x.dict(nested=True) for x in nested_model] + result = [] + for model in nested_model: + try: + result.append(model.dict(nested=True)) + except ReferenceError: # pragma no cover + continue + dict_instance[field] = result elif nested_model is not None: dict_instance[field] = nested_model.dict(nested=True) + else: + dict_instance[field] = None return dict_instance - def from_dict(self, value_dict: Dict) -> None: + def from_dict(self, value_dict: Dict) -> "Model": for key, value in value_dict.items(): setattr(self, key, value) + return self def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]: - if not self._is_conversion_to_json_needed(column_name): return value diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index dc94e6f..f7436ac 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -37,13 +37,21 @@ class QueryClause: def filter( # noqa: A003 self, **kwargs: Any ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: - filter_clauses = self.filter_clauses - select_related = list(self._select_related) if kwargs.get("pk"): pk_name = self.model_cls.Meta.pkname kwargs[pk_name] = kwargs.pop("pk") + filter_clauses, select_related = self._populate_filter_clauses(**kwargs) + + return filter_clauses, select_related + + def _populate_filter_clauses( + self, **kwargs: Any + ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: + filter_clauses = self.filter_clauses + select_related = list(self._select_related) + for key, value in kwargs.items(): table_prefix = "" if "__" in key: @@ -73,24 +81,36 @@ class QueryClause: column = self.table.columns[key] table = self.table - value, has_escaped_character = self._escape_characters_in_clause(op, value) - - if isinstance(value, ormar.Model): - value = value.pk - - op_attr = FILTER_OPERATORS[op] - clause = getattr(column, op_attr)(value) - clause = self._compile_clause( - clause, - column, - table, - table_prefix, - modifiers={"escape": "\\" if has_escaped_character else None}, + clause = self._process_column_clause_for_operator_and_value( + value, op, column, table, table_prefix ) filter_clauses.append(clause) - return filter_clauses, select_related + def _process_column_clause_for_operator_and_value( + self, + value: Any, + op: str, + column: sqlalchemy.Column, + table: sqlalchemy.Table, + table_prefix: str, + ) -> sqlalchemy.sql.expression.TextClause: + value, has_escaped_character = self._escape_characters_in_clause(op, value) + + if isinstance(value, ormar.Model): + value = value.pk + + op_attr = FILTER_OPERATORS[op] + clause = getattr(column, op_attr)(value) + clause = self._compile_clause( + clause, + column, + table, + table_prefix, + modifiers={"escape": "\\" if has_escaped_character else None}, + ) + return clause + def _determine_filter_target_table( self, related_parts: List[str], select_related: List[str] ) -> Tuple[List[str], str, "Model"]: @@ -109,7 +129,7 @@ class QueryClause: previous_table = model_cls.Meta.tablename for part in related_parts: current_table = model_cls.Meta.model_fields[part].to.Meta.tablename - manager = model_cls.Meta._orm_relationship_manager + manager = model_cls.Meta.alias_manager table_prefix = manager.resolve_relation_join(previous_table, current_table) model_cls = model_cls.Meta.model_fields[part].to previous_table = current_table diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index b4c91c2..e216f07 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -5,8 +5,7 @@ from sqlalchemy import text import ormar # noqa I100 from ormar.fields.foreign_key import ForeignKeyField -from ormar.queryset.relationship_crawler import RelationshipCrawler -from ormar.relations import RelationshipManager +from ormar.relations import AliasManager if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -44,22 +43,19 @@ class Query: self.order_bys = None @property - def relation_manager(self) -> RelationshipManager: - return self.model_cls.Meta._orm_relationship_manager + def relation_manager(self) -> AliasManager: + return self.model_cls.Meta.alias_manager + + @property + def prefixed_pk_name(self) -> str: + return f"{self.table.name}.{self.model_cls.Meta.pkname}" def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: self.columns = list(self.table.columns) - self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")] + self.order_bys = [text(self.prefixed_pk_name)] self.select_from = self.table - start_params = JoinParameters( - self.model_cls, "", self.table.name, self.model_cls - ) - - self._select_related = RelationshipCrawler().discover_relations( - self._select_related, prev_model=start_params.prev_model - ) - self._select_related.sort(key=lambda item: (-len(item), item)) + self._select_related.sort(key=lambda item: (item, -len(item))) for item in self._select_related: join_parameters = JoinParameters( @@ -77,10 +73,11 @@ class Query: # print(expr.compile(compile_kwargs={"literal_binds": True})) self._reset_query_parameters() - return expr, self._select_related + return expr + @staticmethod def on_clause( - self, previous_alias: str, alias: str, from_clause: str, to_clause: str, + previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: left_part = f"{alias}_{to_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" @@ -92,7 +89,7 @@ class Query: model_cls = join_params.model_cls.Meta.model_fields[part].to to_table = model_cls.Meta.table.name - alias = model_cls.Meta._orm_relationship_manager.resolve_relation_join( + alias = model_cls.Meta.alias_manager.resolve_relation_join( join_params.from_table, to_table ) if alias not in self.used_aliases: diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 7ee3599..1b0db6d 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -47,7 +47,7 @@ class QuerySet: offset=self.query_offset, limit_count=self.limit_count, ) - exp, self._select_related = qry.build_select_expression() + exp = qry.build_select_expression() return exp def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 @@ -118,15 +118,25 @@ class QuerySet: async def get(self, **kwargs: Any) -> "Model": if kwargs: return await self.filter(**kwargs).get() + else: + if not self.filter_clauses: + expr = self.build_select_expression().limit(2) + else: + expr = self.build_select_expression() - expr = self.build_select_expression().limit(2) rows = await self.database.fetch_all(expr) + result_rows = [ + self.model_cls.from_row(row, select_related=self._select_related) + for row in rows + ] + rows = self.model_cls.merge_instances_list(result_rows) + if not rows: raise NoMatch() if len(rows) > 1: raise MultipleMatches() - return self.model_cls.from_row(rows[0], select_related=self._select_related) + return rows[0] async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 if kwargs: @@ -138,7 +148,6 @@ class QuerySet: self.model_cls.from_row(row, select_related=self._select_related) for row in rows ] - result_rows = self.model_cls.merge_instances_list(result_rows) return result_rows diff --git a/ormar/queryset/relationship_crawler.py b/ormar/queryset/relationship_crawler.py deleted file mode 100644 index 7f8d055..0000000 --- a/ormar/queryset/relationship_crawler.py +++ /dev/null @@ -1,87 +0,0 @@ -from typing import List, TYPE_CHECKING, Type - -from ormar.fields import BaseField -from ormar.fields.foreign_key import ForeignKeyField - -if TYPE_CHECKING: # pragma no cover - from ormar import Model - - -class RelationshipCrawler: - def __init__(self) -> None: - self._select_related = [] - self.auto_related = [] - self.already_checked = [] - - def discover_relations( - self, select_related: List, prev_model: Type["Model"] - ) -> List[str]: - self._select_related = select_related - self._extract_auto_required_relations(prev_model=prev_model) - self._include_auto_related_models() - return self._select_related - - @staticmethod - def _field_is_a_foreign_key_and_no_circular_reference( - field: Type[BaseField], field_name: str, rel_part: str - ) -> bool: - return issubclass(field, ForeignKeyField) and field_name not in rel_part - - def _field_qualifies_to_deeper_search( - self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str - ) -> bool: - prev_part_of_related = "__".join(rel_part.split("__")[:-1]) - partial_match = any( - [x.startswith(prev_part_of_related) for x in self._select_related] - ) - already_checked = any( - [x.startswith(rel_part) for x in (self.auto_related + self.already_checked)] - ) - return ( - (field.virtual and parent_virtual) - or (partial_match and not already_checked) - ) or not nested - - def _extract_auto_required_relations( - self, - prev_model: Type["Model"], - rel_part: str = "", - nested: bool = False, - parent_virtual: bool = False, - ) -> None: - for field_name, field in prev_model.Meta.model_fields.items(): - if self._field_is_a_foreign_key_and_no_circular_reference( - field, field_name, rel_part - ): - rel_part = field_name if not rel_part else rel_part + "__" + field_name - if not field.nullable: - if rel_part not in self._select_related: - split_tables = rel_part.split("__") - new_related = ( - "__".join(split_tables[:-1]) - if len(split_tables) > 1 - else rel_part - ) - self.auto_related.append(new_related) - rel_part = "" - elif self._field_qualifies_to_deeper_search( - field, parent_virtual, nested, rel_part - ): - - self._extract_auto_required_relations( - prev_model=field.to, - rel_part=rel_part, - nested=True, - parent_virtual=field.virtual, - ) - else: - self.already_checked.append(rel_part) - rel_part = "" - - def _include_auto_related_models(self) -> None: - if self.auto_related: - new_joins = [] - for join in self._select_related: - if not any([x.startswith(join) for x in self.auto_related]): - new_joins.append(join) - self._select_related = new_joins + self.auto_related diff --git a/ormar/relations.py b/ormar/relations.py index d6b7616..d8c597e 100644 --- a/ormar/relations.py +++ b/ormar/relations.py @@ -1,26 +1,33 @@ -import pprint import string import uuid +from enum import Enum from random import choices -from typing import List, TYPE_CHECKING, Union +from typing import List, Optional, TYPE_CHECKING, Type, Union from weakref import proxy import sqlalchemy from sqlalchemy import text +import ormar # noqa I100 +from ormar.exceptions import RelationshipInstanceError # noqa I100 from ormar.fields.foreign_key import ForeignKeyField # noqa I100 + if TYPE_CHECKING: # pragma no cover - from ormar.models import NewBaseModel, Model + from ormar.models import Model def get_table_alias() -> str: return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] -class RelationshipManager: +class RelationType(Enum): + PRIMARY = 1 + REVERSE = 2 + + +class AliasManager: def __init__(self) -> None: - self._relations = dict() self._aliases = dict() @staticmethod @@ -34,86 +41,158 @@ class RelationshipManager: def prefixed_table_name(alias: str, name: str) -> text: return text(f"{name} {alias}_{name}") - def add_relation_type( - self, - relations_key: str, - reverse_key: str, - field: ForeignKeyField, - table_name: str, - ) -> None: - if relations_key not in self._relations: - self._relations[relations_key] = {"type": "primary"} + def add_relation_type(self, field: ForeignKeyField, table_name: str,) -> None: + if f"{table_name}_{field.to.Meta.tablename}" not in self._aliases: self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias() - if reverse_key not in self._relations: - self._relations[reverse_key] = {"type": "reverse"} + if f"{field.to.Meta.tablename}_{table_name}" not in self._aliases: self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias() - def deregister(self, model: "NewBaseModel") -> None: - for rel_type in self._relations.keys(): - if model.get_name() in rel_type.lower(): - if model._orm_id in self._relations[rel_type]: - del self._relations[rel_type][model._orm_id] - - def add_relation( - self, - parent: "NewBaseModel", - child: "NewBaseModel", - child_model_name: str, - virtual: bool = False, - ) -> None: - parent_id, child_id = parent._orm_id, child._orm_id - parent_name = parent.get_name(title=True) - child_name = ( - child_model_name - if child.get_name() != child_model_name - else child.get_name() + "s" - ) - if virtual: - child_name, parent_name = parent_name, child.get_name() - child_id, parent_id = parent_id, child_id - child, parent = parent, proxy(child) - child_name = child_name.lower() + "s" - else: - child = proxy(child) - - parent_relation_name = parent_name.title() + "_" + child_name - parents_list = self._relations[parent_relation_name].setdefault(parent_id, []) - self.append_related_model(parents_list, child) - - child_relation_name = child.get_name(title=True) + "_" + parent_name.lower() - children_list = self._relations[child_relation_name].setdefault(child_id, []) - self.append_related_model(children_list, parent) - - @staticmethod - def append_related_model(relations_list: List["Model"], model: "Model") -> None: - for relation_child in relations_list: - try: - if relation_child.__same__(model): - return - except ReferenceError: - continue - - relations_list.append(model) - - def contains(self, relations_key: str, instance: "NewBaseModel") -> bool: - if relations_key in self._relations: - return instance._orm_id in self._relations[relations_key] - return False - - def get( - self, relations_key: str, instance: "NewBaseModel" - ) -> Union["Model", List["Model"]]: - if relations_key in self._relations: - if instance._orm_id in self._relations[relations_key]: - if self._relations[relations_key]["type"] == "primary": - return self._relations[relations_key][instance._orm_id][0] - return self._relations[relations_key][instance._orm_id] - def resolve_relation_join(self, from_table: str, to_table: str) -> str: return self._aliases.get(f"{from_table}_{to_table}", "") - def __str__(self) -> str: # pragma no cover - return pprint.pformat(self._relations, indent=4, width=1) + +class RelationProxy(list): + def __init__(self, relation: "Relation") -> None: + super(RelationProxy, self).__init__() + self.relation = relation + self._owner = self.relation.manager.owner + + def remove(self, item: "Model") -> None: + super().remove(item) + rel_name = item.resolve_relation_name(item, self._owner) + item._orm._get(rel_name).remove(self._owner) + + def append(self, item: "Model") -> None: + super().append(item) + + def add(self, item: "Model") -> None: + rel_name = item.resolve_relation_name(item, self._owner) + setattr(item, rel_name, self._owner) + + +class Relation: + def __init__(self, manager: "RelationsManager", type_: RelationType) -> None: + self.manager = manager + self._owner = manager.owner + self._type = type_ + self.related_models = ( + RelationProxy(relation=self) if type_ == RelationType.REVERSE else None + ) + + def _find_existing(self, child: "Model") -> Optional[int]: + for ind, relation_child in enumerate(self.related_models[:]): + try: + if relation_child.__same__(child): + return ind + except ReferenceError: # pragma no cover + self.related_models.pop(ind) + return None + + def add(self, child: "Model") -> None: + relation_name = self._owner.resolve_relation_name(self._owner, child) + if self._type == RelationType.PRIMARY: + self.related_models = child + self._owner.__dict__[relation_name] = child + else: + if self._find_existing(child) is None: + self.related_models.append(child) + rel = self._owner.__dict__.get(relation_name, []) + rel.append(child) + self._owner.__dict__[relation_name] = rel + + def remove(self, child: "Model") -> None: + relation_name = self._owner.resolve_relation_name(self._owner, child) + if self._type == RelationType.PRIMARY: + if self.related_models.__same__(child): + self.related_models = None + del self._owner.__dict__[relation_name] + else: + position = self._find_existing(child) + if position is not None: + self.related_models.pop(position) + del self._owner.__dict__[relation_name][position] + + def get(self) -> Union[List["Model"], "Model"]: + return self.related_models def __repr__(self) -> str: # pragma no cover - return self.__str__() + return str(self.related_models) + + +class RelationsManager: + def __init__( + self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None + ) -> None: + self.owner = owner + self._related_fields = related_fields or [] + self._related_names = [field.name for field in self._related_fields] + self._relations = dict() + for field in self._related_fields: + self._add_relation(field) + + def _add_relation(self, field: Type[ForeignKeyField]) -> None: + self._relations[field.name] = Relation( + manager=self, + type_=RelationType.PRIMARY if not field.virtual else RelationType.REVERSE, + ) + + def __contains__(self, item: str) -> bool: + return item in self._related_names + + def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]: + relation = self._relations.get(name, None) + if relation: + return relation.get() + + def _get(self, name: str) -> Optional[Relation]: + relation = self._relations.get(name, None) + if relation: + return relation + + @staticmethod + def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None: + to_field = next( + ( + field + for field in child._orm._related_fields + if field.to == parent.__class__ or field.to.Meta == parent.Meta + ), + None, + ) + + if not to_field: # pragma no cover + raise RelationshipInstanceError( + f"Model {child.__class__} does not have " + f"reference to model {parent.__class__}" + ) + + to_name = to_field.name + if virtual: + child_name, to_name = to_name, child_name or child.get_name() + child, parent = parent, proxy(child) + else: + child_name = child_name or child.get_name() + "s" + child = proxy(child) + + parent_relation = parent._orm._get(child_name) + if not parent_relation: + ormar.models.expand_reverse_relationships(child.__class__) + name = parent.resolve_relation_name(parent, child) + field = parent.Meta.model_fields[name] + parent._orm._add_relation(field) + parent_relation = parent._orm._get(child_name) + parent_relation.add(child) + child._orm._get(to_name).add(parent) + + def remove(self, name: str, child: "Model") -> None: + relation = self._get(name) + relation.remove(child) + + @staticmethod + def remove_parent(item: "Model", name: Union[str, "Model"]) -> None: + related_model = name + name = item.resolve_relation_name(item, related_model) + if name in item._orm: + relation_name = item.resolve_relation_name(related_model, item) + item._orm.remove(name, related_model) + related_model._orm.remove(relation_name, item) diff --git a/tests/test_columns.py b/tests/test_columns.py index c8c9d3b..15382b0 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -22,7 +22,7 @@ class Example(ormar.Model): database = database id: ormar.Integer(primary_key=True) - name: ormar.String(max_length=200, default='aaa') + name: ormar.String(max_length=200, default="aaa") created: ormar.DateTime(default=datetime.datetime.now) created_day: ormar.Date(default=datetime.date.today) created_time: ormar.Time(default=time) diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 6463eb6..fb85fc7 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -1,11 +1,11 @@ +import gc + import databases import pytest import sqlalchemy -from pydantic import ValidationError import ormar from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError -from ormar.fields.foreign_key import ForeignKeyField from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -131,9 +131,11 @@ async def test_model_crud(): album1 = await Album.objects.get(name="Malibu") assert album1.pk == 1 - assert album1.tracks is None + assert album1.tracks == [] - await Track.objects.create(album={"id": track.album.pk}, title="The Bird2", position=4) + await Track.objects.create( + album={"id": track.album.pk}, title="The Bird2", position=4 + ) @pytest.mark.asyncio @@ -164,6 +166,47 @@ async def test_select_related(): assert len(tracks) == 6 +@pytest.mark.asyncio +async def test_model_removal_from_relations(): + async with database: + album = Album(name="Chichi") + await album.save() + track1 = Track(album=album, title="The Birdman", position=1) + track2 = Track(album=album, title="Superman", position=2) + track3 = Track(album=album, title="Wonder Woman", position=3) + await track1.save() + await track2.save() + await track3.save() + + assert len(album.tracks) == 3 + album.tracks.remove(track1) + assert len(album.tracks) == 2 + assert track1.album is None + + await track1.update() + track1 = await Track.objects.get(title="The Birdman") + assert track1.album is None + + album.tracks.add(track1) + assert len(album.tracks) == 3 + assert track1.album == album + + await track1.update() + track1 = await Track.objects.select_related("album__tracks").get( + title="The Birdman" + ) + album = await Album.objects.select_related("tracks").get(name="Chichi") + assert track1.album == album + + track1.remove(album) + assert track1.album is None + assert len(album.tracks) == 2 + + track2.remove(album) + assert track2.album is None + assert len(album.tracks) == 1 + + @pytest.mark.asyncio async def test_fk_filter(): async with database: @@ -182,8 +225,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -191,8 +234,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -234,8 +277,8 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: @@ -254,8 +297,8 @@ async def test_pk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1 diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index ebb0619..ab2845c 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -54,7 +54,9 @@ class ExampleModel2(Model): @pytest.fixture() def example(): - return ExampleModel(pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5)) + return ExampleModel( + pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5) + ) def test_not_nullable_field_is_required(): @@ -110,6 +112,7 @@ def test_sqlalchemy_table_is_created(example): def test_no_pk_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example3" @@ -120,6 +123,7 @@ def test_no_pk_in_model_definition(): def test_two_pks_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example3" @@ -131,6 +135,7 @@ def test_two_pks_in_model_definition(): def test_setting_pk_column_as_pydantic_only_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example4" @@ -141,6 +146,7 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition(): def test_decimal_error_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example5" @@ -151,6 +157,7 @@ def test_decimal_error_in_model_definition(): def test_string_error_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): class Meta: tablename = "example6" diff --git a/tests/test_models.py b/tests/test_models.py index f21e70c..1c00ef3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -28,7 +28,7 @@ class User(ormar.Model): database = database id: ormar.Integer(primary_key=True) - name: ormar.String(max_length=100, default='') + name: ormar.String(max_length=100, default="") class Product(ormar.Model): diff --git a/tests/test_more_reallife_fastapi.py b/tests/test_more_reallife_fastapi.py index 31e31b9..e01ce44 100644 --- a/tests/test_more_reallife_fastapi.py +++ b/tests/test_more_reallife_fastapi.py @@ -79,7 +79,7 @@ async def create_category(category: Category): @app.put("/items/{item_id}") async def get_item(item_id: int, item: Item): item_db = await Item.objects.get(pk=item_id) - return {"updated_rows": await item_db.update(**item.dict())} + return await item_db.update(**item.dict()) @app.delete("/items/{item_id}") @@ -105,7 +105,7 @@ def test_all_endpoints(): item.name = "New name" response = client.put(f"/items/{item.pk}", json=item.dict()) - assert response.json().get("updated_rows") == 1 + assert response.json() == item.dict() response = client.get("/items/") items = [Item(**item) for item in response.json()] diff --git a/tests/test_more_same_table_joins.py b/tests/test_more_same_table_joins.py new file mode 100644 index 0000000..3718bc0 --- /dev/null +++ b/tests/test_more_same_table_joins.py @@ -0,0 +1,110 @@ +import asyncio + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Department(ormar.Model): + class Meta: + tablename = "departments" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True, autoincrement=False) + name: ormar.String(max_length=100) + + +class SchoolClass(ormar.Model): + class Meta: + tablename = "schoolclasses" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + + +class Category(ormar.Model): + class Meta: + tablename = "categories" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + department: ormar.ForeignKey(Department, nullable=False) + + +class Student(ormar.Model): + class Meta: + tablename = "students" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + schoolclass: ormar.ForeignKey(SchoolClass) + category: ormar.ForeignKey(Category, nullable=True) + + +class Teacher(ormar.Model): + class Meta: + tablename = "teachers" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + schoolclass: ormar.ForeignKey(SchoolClass) + category: ormar.ForeignKey(Category, nullable=True) + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + +@pytest.fixture(autouse=True, scope="module") +async def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + department = await Department.objects.create(id=1, name="Math Department") + department2 = await Department.objects.create(id=2, name="Law Department") + class1 = await SchoolClass.objects.create(name="Math") + class2 = await SchoolClass.objects.create(name="Logic") + category = await Category.objects.create(name="Foreign", department=department) + category2 = await Category.objects.create(name="Domestic", department=department2) + await Student.objects.create(name="Jane", category=category, schoolclass=class1) + await Student.objects.create(name="Judy", category=category2, schoolclass=class1) + await Student.objects.create(name="Jack", category=category2, schoolclass=class2) + await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_model_multiple_instances_of_same_table_in_schema(): + async with database: + classes = await SchoolClass.objects.select_related( + ["teachers__category__department", "students"] + ).all() + assert classes[0].name == "Math" + assert classes[0].students[0].name == "Jane" + assert len(classes[0].dict().get("students")) == 2 + assert classes[0].teachers[0].category.department.name == "Law Department" + + assert classes[0].students[0].category.pk is not None + assert classes[0].students[0].category.name is None + await classes[0].students[0].category.load() + await classes[0].students[0].category.department.load() + assert classes[0].students[0].category.department.name == "Math Department" diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 13e2185..5b8ffd0 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -79,11 +79,14 @@ async def create_test_database(): metadata.drop_all(engine) metadata.create_all(engine) department = await Department.objects.create(id=1, name="Math Department") + department2 = await Department.objects.create(id=2, name="Law Department") class1 = await SchoolClass.objects.create(name="Math", department=department) + class2 = await SchoolClass.objects.create(name="Logic", department=department2) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") await Student.objects.create(name="Jane", category=category, schoolclass=class1) - await Student.objects.create(name="Jack", category=category2, schoolclass=class1) + await Student.objects.create(name="Judy", category=category2, schoolclass=class1) + await Student.objects.create(name="Jack", category=category2, schoolclass=class2) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) yield metadata.drop_all(engine) @@ -100,15 +103,15 @@ async def test_model_multiple_instances_of_same_table_in_schema(): assert len(classes[0].dict().get("students")) == 2 - # related fields of main model are only populated by pk - # unless there is a required foreign key somewhere along the way - # since department is required for schoolclass it was pre loaded (again) - # but you can load them anytime + # since it's going from schoolclass => teacher => schoolclass (same class) department is already populated assert classes[0].students[0].schoolclass.name == "Math" assert classes[0].students[0].schoolclass.department.name is None await classes[0].students[0].schoolclass.department.load() assert classes[0].students[0].schoolclass.department.name == "Math Department" + await classes[1].students[0].schoolclass.department.load() + assert classes[1].students[0].schoolclass.department.name == "Law Department" + @pytest.mark.asyncio async def test_right_tables_join(): @@ -130,5 +133,7 @@ async def test_multiple_reverse_related_objects(): ["teachers__category", "students__category"] ).all() assert classes[0].name == "Math" - assert classes[0].students[1].name == "Jack" + assert classes[0].students[1].name == "Judy" + assert classes[0].students[0].category.name == "Foreign" + assert classes[0].students[1].category.name == "Domestic" assert classes[0].teachers[0].category.name == "Domestic"