From 8b794d07f98aa12aa9e3e622cfeb77de60e21c2a Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 8 Jan 2021 18:19:26 +0100 Subject: [PATCH] WIP working self fk, adjusting m2m to work with self ref --- ormar/fields/base.py | 15 +++ ormar/fields/foreign_key.py | 98 +++++++++++++++---- ormar/fields/many_to_many.py | 66 ++++++++++--- ormar/models/helpers/models.py | 25 ++++- ormar/models/helpers/relations.py | 44 ++++++--- ormar/models/helpers/sqlalchemy.py | 92 ++++++++++++++---- ormar/models/metaclass.py | 1 + ormar/models/newbasemodel.py | 78 +++++++++++++-- ormar/queryset/queryset.py | 9 +- ormar/relations/alias_manager.py | 1 - tests/test_forward_refs.py | 146 +++++++++++++++++++++++++++++ 11 files changed, 507 insertions(+), 68 deletions(-) create mode 100644 tests/test_forward_refs.py diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 704b88b..fee98af 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -40,8 +40,10 @@ class BaseField(FieldInfo): pydantic_only: bool virtual: bool = False choices: typing.Sequence + to: Type["Model"] through: Type["Model"] + self_reference: bool = False default: Any server_default: Any @@ -263,3 +265,16 @@ class BaseField(FieldInfo): :rtype: Any """ return value + + @classmethod + def evaluate_forward_ref(cls, globalns: Any, localns: Any) -> None: + """ + Evaluates the ForwardRef to actual Field based on global and local namespaces + + :param globalns: global namespace + :type globalns: Any + :param localns: local namespace + :type localns: Any + :return: None + :rtype: None + """ diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 94a93d5..d8b5d8a 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -1,8 +1,9 @@ import uuid from dataclasses import dataclass -from typing import Any, List, Optional, TYPE_CHECKING, Type, Union +from typing import Any, ForwardRef, List, Optional, TYPE_CHECKING, Tuple, Type, Union from pydantic import BaseModel, create_model +from pydantic.typing import evaluate_forwardref from sqlalchemy import UniqueConstraint import ormar # noqa I101 @@ -66,6 +67,43 @@ def create_dummy_model( return dummy_model +def populate_fk_params_based_on_to_model( + to: Type["Model"], nullable: bool, onupdate: str = None, ondelete: str = None, +) -> Tuple[Any, List, Any]: + """ + Based on target to model to which relation leads to populates the type of the + pydantic field to use, ForeignKey constraint and type of the target column field. + + :param to: target related ormar Model + :type to: Model class + :param nullable: marks field as optional/ required + :type nullable: bool + :param onupdate: parameter passed to sqlalchemy.ForeignKey. + How to treat child rows on update of parent (the one where FK is defined) model. + :type onupdate: str + :param ondelete: parameter passed to sqlalchemy.ForeignKey. + How to treat child rows on delete of parent (the one where FK is defined) model. + :type ondelete: str + :return: tuple with target pydantic type, list of fk constraints and target col type + :rtype: Tuple[Any, List, Any] + """ + fk_string = to.Meta.tablename + "." + to.get_column_alias(to.Meta.pkname) + to_field = to.Meta.model_fields[to.Meta.pkname] + pk_only_model = create_dummy_model(to, to_field) + __type__ = ( + Union[to_field.__type__, to, pk_only_model] + if not nullable + else Optional[Union[to_field.__type__, to, pk_only_model]] + ) + constraints = [ + ForeignKeyConstraint( + name=fk_string, ondelete=ondelete, onupdate=onupdate # type: ignore + ) + ] + column_type = to_field.column_type + return __type__, constraints, column_type + + class UniqueColumns(UniqueConstraint): """ Subclass of sqlalchemy.UniqueConstraint. @@ -86,7 +124,7 @@ class ForeignKeyConstraint: def ForeignKey( # noqa CFQ002 - to: Type["Model"], + to: Union[Type["Model"], ForwardRef], *, name: str = None, unique: bool = False, @@ -127,27 +165,26 @@ def ForeignKey( # noqa CFQ002 :return: ormar ForeignKeyField with relation to selected model :rtype: ForeignKeyField """ - fk_string = to.Meta.tablename + "." + to.get_column_alias(to.Meta.pkname) - to_field = to.Meta.model_fields[to.Meta.pkname] - pk_only_model = create_dummy_model(to, to_field) - __type__ = ( - Union[to_field.__type__, to, pk_only_model] - if not nullable - else Optional[Union[to_field.__type__, to, pk_only_model]] - ) + + if isinstance(to, ForwardRef): + __type__ = to if not nullable else Optional[to] + constraints: List = [] + column_type = None + else: + __type__, constraints, column_type = populate_fk_params_based_on_to_model( + to=to, nullable=nullable, ondelete=ondelete, onupdate=onupdate + ) + namespace = dict( __type__=__type__, to=to, + through=None, alias=name, name=kwargs.pop("real_name", None), nullable=nullable, - constraints=[ - ForeignKeyConstraint( - name=fk_string, ondelete=ondelete, onupdate=onupdate # type: ignore - ) - ], + constraints=constraints, unique=unique, - column_type=to_field.column_type, + column_type=column_type, related_name=related_name, virtual=virtual, primary_key=False, @@ -155,6 +192,8 @@ def ForeignKey( # noqa CFQ002 pydantic_only=False, default=None, server_default=None, + onupdate=onupdate, + ondelete=ondelete, ) return type("ForeignKey", (ForeignKeyField, BaseField), namespace) @@ -169,6 +208,33 @@ class ForeignKeyField(BaseField): name: str related_name: str virtual: bool + ondelete: str + onupdate: str + + @classmethod + def evaluate_forward_ref(cls, globalns: Any, localns: Any) -> None: + """ + Evaluates the ForwardRef to actual Field based on global and local namespaces + + :param globalns: global namespace + :type globalns: Any + :param localns: local namespace + :type localns: Any + :return: None + :rtype: None + """ + if isinstance(cls.to, ForwardRef): + cls.to = evaluate_forwardref(cls.to, globalns, localns or None) + ( + cls.__type__, + cls.constraints, + cls.column_type, + ) = populate_fk_params_based_on_to_model( + to=cls.to, + nullable=cls.nullable, + ondelete=cls.ondelete, + onupdate=cls.onupdate, + ) @classmethod def _extract_model_from_sequence( diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py index 5039bfd..4a7161f 100644 --- a/ormar/fields/many_to_many.py +++ b/ormar/fields/many_to_many.py @@ -1,4 +1,6 @@ -from typing import Any, List, Optional, TYPE_CHECKING, Type, Union +from typing import Any, ForwardRef, List, Optional, TYPE_CHECKING, Tuple, Type, Union + +from pydantic.typing import evaluate_forwardref import ormar from ormar.fields import BaseField @@ -10,6 +12,30 @@ if TYPE_CHECKING: # pragma no cover REF_PREFIX = "#/components/schemas/" +def populate_m2m_params_based_on_to_model( + to: Type["Model"], nullable: bool +) -> Tuple[List, Any]: + """ + Based on target to model to which relation leads to populates the type of the + pydantic field to use and type of the target column field. + + :param to: target related ormar Model + :type to: Model class + :param nullable: marks field as optional/ required + :type nullable: bool + :return: Tuple[List, Any] + :rtype: tuple with target pydantic type and target col type + """ + to_field = to.Meta.model_fields[to.Meta.pkname] + __type__ = ( + Union[to_field.__type__, to, List[to]] # type: ignore + if not nullable + else Optional[Union[to_field.__type__, to, List[to]]] # type: ignore + ) + column_type = to_field.column_type + return __type__, column_type + + def ManyToMany( to: Type["Model"], through: Type["Model"], @@ -42,23 +68,25 @@ def ManyToMany( :return: ormar ManyToManyField with m2m relation to selected model :rtype: ManyToManyField """ - to_field = to.Meta.model_fields[to.Meta.pkname] related_name = kwargs.pop("related_name", None) nullable = kwargs.pop("nullable", True) - __type__ = ( - Union[to_field.__type__, to, List[to]] # type: ignore - if not nullable - else Optional[Union[to_field.__type__, to, List[to]]] # type: ignore - ) + + if isinstance(to, ForwardRef): + __type__ = to if not nullable else Optional[to] + column_type = None + else: + __type__, column_type = populate_m2m_params_based_on_to_model( + to=to, nullable=nullable + ) namespace = dict( __type__=__type__, to=to, through=through, alias=name, name=name, - nullable=True, + nullable=nullable, unique=unique, - column_type=to_field.column_type, + column_type=column_type, related_name=related_name, virtual=virtual, primary_key=False, @@ -76,8 +104,6 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro Actual class returned from ManyToMany function call and stored in model_fields. """ - through: Type["Model"] - @classmethod def default_target_field_name(cls) -> str: """ @@ -86,3 +112,21 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro :rtype: str """ return cls.to.get_name() + + @classmethod + def evaluate_forward_ref(cls, globalns: Any, localns: Any) -> None: + """ + Evaluates the ForwardRef to actual Field based on global and local namespaces + + :param globalns: global namespace + :type globalns: Any + :param localns: local namespace + :type localns: Any + :return: None + :rtype: None + """ + if isinstance(cls.to, ForwardRef) or isinstance(cls.through, ForwardRef): + cls.to = evaluate_forwardref(cls.to, globalns, localns or None) + (cls.__type__, cls.column_type,) = populate_m2m_params_based_on_to_model( + to=cls.to, nullable=cls.nullable, + ) diff --git a/ormar/models/helpers/models.py b/ormar/models/helpers/models.py index 4c899d3..d2e3ef3 100644 --- a/ormar/models/helpers/models.py +++ b/ormar/models/helpers/models.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, TYPE_CHECKING, Tuple, Type +from typing import Dict, ForwardRef, List, Optional, TYPE_CHECKING, Tuple, Type import ormar from ormar.fields.foreign_key import ForeignKeyField @@ -6,6 +6,22 @@ from ormar.models.helpers.pydantic import populate_pydantic_default_values if TYPE_CHECKING: # pragma no cover from ormar import Model + from ormar.fields import BaseField + + +def is_field_an_forward_ref(field: Type["BaseField"]) -> bool: + """ + Checks if field is a relation field and whether any of the referenced models + are ForwardRefs that needs to be updated before proceeding. + + :param field: model field to verify + :type field: Type[BaseField] + :return: result of the check + :rtype: bool + """ + return issubclass(field, ForeignKeyField) and ( + isinstance(field.to, ForwardRef) or isinstance(field.through, ForwardRef) + ) def populate_default_options_values( @@ -33,6 +49,13 @@ def populate_default_options_values( if not hasattr(new_model.Meta, "abstract"): new_model.Meta.abstract = False + if any( + is_field_an_forward_ref(field) for field in new_model.Meta.model_fields.values() + ): + new_model.Meta.requires_ref_update = True + else: + new_model.Meta.requires_ref_update = False + def extract_annotations_and_default_vals(attrs: Dict) -> Tuple[Dict, Dict]: """ diff --git a/ormar/models/helpers/relations.py b/ormar/models/helpers/relations.py index 4792531..60182f9 100644 --- a/ormar/models/helpers/relations.py +++ b/ormar/models/helpers/relations.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Type +from typing import ForwardRef, TYPE_CHECKING, Type import ormar from ormar import ForeignKey, ManyToMany @@ -61,6 +61,28 @@ def register_many_to_many_relation_on_build( ) +def expand_reverse_relationship( + model: Type["Model"], model_field: Type["ForeignKeyField"] +) -> None: + """ + If the reverse relation has not been set before it's set here. + + :param model: model on which relation should be checked and registered + :type model: Model class + :param model_field: + :type model_field: + :return: None + :rtype: None + """ + child_model_name = model_field.related_name or model.get_name() + "s" + parent_model = model_field.to + child = model + if reverse_field_not_already_registered(child, child_model_name, parent_model): + register_reverse_model_fields( + parent_model, child, child_model_name, model_field + ) + + def expand_reverse_relationships(model: Type["Model"]) -> None: """ Iterates through model_fields of given model and verifies if all reverse @@ -72,16 +94,12 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: :type model: Model class """ for model_field in model.Meta.model_fields.values(): - if issubclass(model_field, ForeignKeyField): - child_model_name = model_field.related_name or model.get_name() + "s" - parent_model = model_field.to - child = model - if reverse_field_not_already_registered( - child, child_model_name, parent_model - ): - register_reverse_model_fields( - parent_model, child, child_model_name, model_field - ) + if ( + issubclass(model_field, ForeignKeyField) + and not isinstance(model_field.to, ForwardRef) + and not isinstance(model_field.through, ForwardRef) + ): + expand_reverse_relationship(model=model, model_field=model_field) def register_reverse_model_fields( @@ -142,10 +160,14 @@ def register_relation_in_alias_manager( :type field_name: str """ if issubclass(field, ManyToManyField): + if isinstance(field.to, ForwardRef) or isinstance(field.through, ForwardRef): + return register_many_to_many_relation_on_build( new_model=new_model, field=field, field_name=field_name ) elif issubclass(field, ForeignKeyField): + if isinstance(field.to, ForwardRef): + return register_relation_on_build(new_model=new_model, field_name=field_name) diff --git a/ormar/models/helpers/sqlalchemy.py b/ormar/models/helpers/sqlalchemy.py index 37cdaa9..ede9dd4 100644 --- a/ormar/models/helpers/sqlalchemy.py +++ b/ormar/models/helpers/sqlalchemy.py @@ -1,20 +1,22 @@ import copy import logging -from typing import Dict, List, Optional, TYPE_CHECKING, Tuple, Type +from typing import Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union import sqlalchemy from ormar import ForeignKey, Integer, ModelDefinitionError # noqa: I202 from ormar.fields import BaseField, ManyToManyField +from ormar.fields.foreign_key import ForeignKeyField from ormar.models.helpers.models import validate_related_names_in_relations from ormar.models.helpers.pydantic import create_pydantic_field if TYPE_CHECKING: # pragma no cover from ormar import Model, ModelMeta + from ormar.models import NewBaseModel def adjust_through_many_to_many_model( - model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] ) -> None: """ Registers m2m relation on through model. @@ -29,22 +31,36 @@ def adjust_through_many_to_many_model( :param model_field: relation field defined in parent model :type model_field: ManyToManyField """ - model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( - model, real_name=model.get_name(), ondelete="CASCADE" + same_table_ref = False + if child == model or child.Meta == model.Meta: + same_table_ref = True + model_field.self_reference = True + + if same_table_ref: + parent_name = f'to_{model.get_name()}' + child_name = f'from_{child.get_name()}' + else: + parent_name = model.get_name() + child_name = child.get_name() + + model_field.through.Meta.model_fields[parent_name] = ForeignKey( + model, real_name=parent_name, ondelete="CASCADE" ) - model_field.through.Meta.model_fields[child.get_name()] = ForeignKey( - child, real_name=child.get_name(), ondelete="CASCADE" + model_field.through.Meta.model_fields[child_name] = ForeignKey( + child, real_name=child_name, ondelete="CASCADE" ) - create_and_append_m2m_fk(model, model_field) - create_and_append_m2m_fk(child, model_field) + create_and_append_m2m_fk(model=model, model_field=model_field, + field_name=parent_name) + create_and_append_m2m_fk(model=child, model_field=model_field, + field_name=child_name) - create_pydantic_field(model.get_name(), model, model_field) - create_pydantic_field(child.get_name(), child, model_field) + create_pydantic_field(parent_name, model, model_field) + create_pydantic_field(child_name, child, model_field) def create_and_append_m2m_fk( - model: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], model_field: Type[ManyToManyField], field_name: str ) -> None: """ Registers sqlalchemy Column with sqlalchemy.ForeignKey leadning to the model. @@ -63,7 +79,7 @@ def create_and_append_m2m_fk( "ManyToMany relation cannot lead to field without pk" ) column = sqlalchemy.Column( - model.get_name(), + field_name, pk_column.type, sqlalchemy.schema.ForeignKey( model.Meta.tablename + "." + pk_alias, @@ -72,12 +88,11 @@ def create_and_append_m2m_fk( ), ) model_field.through.Meta.columns.append(column) - # breakpoint() model_field.through.Meta.table.append_column(copy.deepcopy(column)) def check_pk_column_validity( - field_name: str, field: BaseField, pkname: Optional[str] + field_name: str, field: BaseField, pkname: Optional[str] ) -> Optional[str]: """ Receives the field marked as primary key and verifies if the pkname @@ -102,7 +117,7 @@ def check_pk_column_validity( def sqlalchemy_columns_from_model_fields( - model_fields: Dict, new_model: Type["Model"] + model_fields: Dict, new_model: Type["Model"] ) -> Tuple[Optional[str], List[sqlalchemy.Column]]: """ Iterates over declared on Model model fields and extracts fields that @@ -143,16 +158,16 @@ def sqlalchemy_columns_from_model_fields( if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) if ( - not field.pydantic_only - and not field.virtual - and not issubclass(field, ManyToManyField) + not field.pydantic_only + and not field.virtual + and not issubclass(field, ManyToManyField) ): columns.append(field.get_column(field.get_alias())) return pkname, columns def populate_meta_tablename_columns_and_pk( - name: str, new_model: Type["Model"] + name: str, new_model: Type["Model"] ) -> Type["Model"]: """ Sets Model tablename if it's not already set in Meta. @@ -194,6 +209,20 @@ def populate_meta_tablename_columns_and_pk( return new_model +def check_for_null_type_columns_from_forward_refs(meta: "ModelMeta") -> bool: + """ + Check is any column is of NUllType() meaning it's empty column from ForwardRef + + :param meta: Meta class of the Model without sqlalchemy table constructed + :type meta: Model class Meta + :return: result of the check + :rtype: bool + """ + return not any( + isinstance(col.type, sqlalchemy.sql.sqltypes.NullType) for col in meta.columns + ) + + def populate_meta_sqlalchemy_table_if_required(meta: "ModelMeta") -> None: """ Constructs sqlalchemy table out of columns and parameters set on Meta class. @@ -204,10 +233,33 @@ def populate_meta_sqlalchemy_table_if_required(meta: "ModelMeta") -> None: :return: class with populated Meta.table :rtype: Model class """ - if not hasattr(meta, "table"): + if not hasattr(meta, "table") and check_for_null_type_columns_from_forward_refs( + meta + ): meta.table = sqlalchemy.Table( meta.tablename, meta.metadata, *[copy.deepcopy(col) for col in meta.columns], *meta.constraints, ) + + +def update_column_definition( + model: Union[Type["Model"], Type["NewBaseModel"]], field: Type[ForeignKeyField] +) -> None: + """ + Updates a column with a new type column based on updated parameters in FK fields. + + :param model: model on which columns needs to be updated + :type model: Type["Model"] + :param field: field with column definition that requires update + :type field: Type[ForeignKeyField] + :return: None + :rtype: None + """ + columns = model.Meta.columns + for ind, column in enumerate(columns): + if column.name == field.get_alias(): + new_column = field.get_column(field.get_alias()) + columns[ind] = new_column + break diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 91340f0..f808e60 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -66,6 +66,7 @@ class ModelMeta: property_fields: Set signals: SignalEmitter abstract: bool + requires_ref_update: bool def check_if_field_has_choices(field: Type[BaseField]) -> bool: diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 6a3ea7a..f79d53c 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -1,7 +1,4 @@ -try: - import orjson as json -except ImportError: # pragma: no cover - import json # type: ignore +import sys import uuid from typing import ( AbstractSet, @@ -20,6 +17,12 @@ from typing import ( Union, ) +try: + import orjson as json +except ImportError: # pragma: no cover + import json # type: ignore + + import databases import pydantic import sqlalchemy @@ -28,6 +31,13 @@ from pydantic import BaseModel import ormar # noqa I100 from ormar.exceptions import ModelError from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField +from ormar.models.helpers import register_relation_in_alias_manager +from ormar.models.helpers.relations import expand_reverse_relationship +from ormar.models.helpers.sqlalchemy import ( + populate_meta_sqlalchemy_table_if_required, + update_column_definition, +) from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.modelproxy import ModelTableProxy from ormar.queryset.utils import translate_list_to_dict @@ -103,14 +113,14 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass should be explicitly set to None, as otherwise pydantic will try to populate them with their default values if default is set. - :raises ModelError: if abstract model is initialized or unknown field is passed + :raises ModelError: if abstract model is initialized, model has ForwardRefs + that has not been updated or unknown field is passed :param args: ignored args :type args: Any :param kwargs: keyword arguments - all fields values and some special params :type kwargs: Any """ - if self.Meta.abstract: - raise ModelError(f"You cannot initialize abstract model {self.get_name()}") + self._verify_model_can_be_initialized() object.__setattr__(self, "_orm_id", uuid.uuid4().hex) object.__setattr__(self, "_orm_saved", False) object.__setattr__(self, "_pk_column", None) @@ -265,6 +275,22 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass return value return object.__getattribute__(self, item) # pragma: no cover + def _verify_model_can_be_initialized(self) -> None: + """ + Raises exception if model is abstract or has ForwardRefs in relation fields. + + :return: None + :rtype: None + """ + if self.Meta.abstract: + raise ModelError(f"You cannot initialize abstract model {self.get_name()}") + if self.Meta.requires_ref_update: + raise ModelError( + f"Model {self.get_name()} has not updated " + f"ForwardRefs. \nBefore using the model you " + f"need to call update_forward_refs()." + ) + def _extract_related_model_instead_of_field( self, item: str ) -> Optional[Union["T", Sequence["T"]]]: @@ -398,6 +424,44 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass props = {prop for prop in props if prop not in exclude} return props + @classmethod + def update_forward_refs(cls, **localns: Any) -> None: + """ + Processes fields that are ForwardRef and need to be evaluated into actual + models. + + Expands relationships, register relation in alias manager and substitutes + sqlalchemy columns with new ones with proper column type (null before). + + Populates Meta table of the Model which is left empty before. + + Calls the pydantic method to evaluate pydantic fields. + + :param localns: local namespace + :type localns: Any + :return: None + :rtype: None + """ + globalns = sys.modules[cls.__module__].__dict__.copy() + globalns.setdefault(cls.__name__, cls) + fields_to_check = cls.Meta.model_fields.copy() + for field_name, field in fields_to_check.items(): + if issubclass(field, ForeignKeyField): + field.evaluate_forward_ref(globalns=globalns, localns=localns) + expand_reverse_relationship( + model=cls, # type: ignore + model_field=field, + ) + register_relation_in_alias_manager( + cls, # type: ignore + field, + field_name, + ) + update_column_definition(model=cls, field=field) + populate_meta_sqlalchemy_table_if_required(meta=cls.Meta) + super().update_forward_refs(**localns) + cls.Meta.requires_ref_update = False + def _get_related_not_excluded_fields( self, include: Optional[Dict], exclude: Optional[Dict], ) -> List: diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 131a243..4940265 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -6,7 +6,7 @@ from sqlalchemy import bindparam import ormar # noqa I100 from ormar import MultipleMatches, NoMatch -from ormar.exceptions import ModelPersistenceError, QueryDefinitionError +from ormar.exceptions import ModelError, ModelPersistenceError, QueryDefinitionError from ormar.queryset import FilterQuery from ormar.queryset.clause import QueryClause from ormar.queryset.prefetch_query import PrefetchQuery @@ -55,6 +55,13 @@ class QuerySet: instance: Optional[Union["QuerySet", "QuerysetProxy"]], owner: Union[Type["Model"], Type["QuerysetProxy"]], ) -> "QuerySet": + if issubclass(owner, ormar.Model): + if owner.Meta.requires_ref_update: + raise ModelError( + f"Model {owner.get_name()} has not updated " + f"ForwardRefs. \nBefore using the model you " + f"need to call update_forward_refs()." + ) if issubclass(owner, ormar.Model): return self.__class__(model_cls=owner) return self.__class__() # pragma: no cover diff --git a/ormar/relations/alias_manager.py b/ormar/relations/alias_manager.py index a990bfc..1e6a6cd 100644 --- a/ormar/relations/alias_manager.py +++ b/ormar/relations/alias_manager.py @@ -31,7 +31,6 @@ class AliasManager: """ def __init__(self) -> None: - self._aliases: Dict[str, str] = dict() self._aliases_new: Dict[str, str] = dict() @staticmethod diff --git a/tests/test_forward_refs.py b/tests/test_forward_refs.py new file mode 100644 index 0000000..2a04f40 --- /dev/null +++ b/tests/test_forward_refs.py @@ -0,0 +1,146 @@ +# type: ignore +from typing import ForwardRef, List + +import databases +import pytest +import sqlalchemy +import sqlalchemy as sa +from sqlalchemy import create_engine + +import ormar +from ormar import ModelMeta +from ormar.exceptions import ModelError +from tests.settings import DATABASE_URL + +metadata = sa.MetaData() +db = databases.Database(DATABASE_URL) +engine = create_engine(DATABASE_URL) + +Person = ForwardRef("Person") + + +class Person(ormar.Model): + class Meta(ModelMeta): + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + supervisor: Person = ormar.ForeignKey(Person, related_name="employees") + + +Person.update_forward_refs() + +Game = ForwardRef("Game") +Child = ForwardRef("Child") + + +class ChildFriends(ormar.Model): + class Meta(ModelMeta): + metadata = metadata + database = db + + +class Child(ormar.Model): + class Meta(ModelMeta): + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + favourite_game: Game = ormar.ForeignKey(Game, related_name="liked_by") + least_favourite_game: Game = ormar.ForeignKey(Game, related_name="not_liked_by") + friends: List[Child] = ormar.ManyToMany(Child, through=ChildFriends) + + +class Game(ormar.Model): + class Meta(ModelMeta): + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +Child.update_forward_refs() + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_not_uprated_model_raises_errors(): + Person2 = ForwardRef("Person2") + + class Person2(ormar.Model): + class Meta(ModelMeta): + metadata = metadata + database = db + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + supervisor: Person2 = ormar.ForeignKey(Person2, related_name="employees") + + with pytest.raises(ModelError): + await Person2.objects.create(name="Test") + + with pytest.raises(ModelError): + Person2(name="Test") + + with pytest.raises(ModelError): + await Person2.objects.get() + + +def test_proper_field_init(): + assert "supervisor" in Person.Meta.model_fields + assert Person.Meta.model_fields["supervisor"].to == Person + + assert "supervisor" in Person.__fields__ + assert Person.__fields__["supervisor"].type_ == Person + + assert "supervisor" in Person.Meta.table.columns + assert isinstance( + Person.Meta.table.columns["supervisor"].type, sqlalchemy.sql.sqltypes.Integer + ) + assert len(Person.Meta.table.columns["supervisor"].foreign_keys) > 0 + + assert "person_supervisor" in Person.Meta.alias_manager._aliases_new + + +@pytest.mark.asyncio +async def test_self_relation(): + sam = await Person.objects.create(name="Sam") + joe = await Person(name="Joe", supervisor=sam).save() + assert joe.supervisor.name == "Sam" + + joe_check = await Person.objects.select_related("supervisor").get(name="Joe") + assert joe_check.supervisor.name == "Sam" + + sam_check = await Person.objects.select_related("employees").get(name="Sam") + assert sam_check.name == "Sam" + assert sam_check.employees[0].name == "Joe" + + +@pytest.mark.asyncio +async def test_other_forwardref_relation(): + checkers = await Game.objects.create(name="checkers") + uno = await Game(name="Uno").save() + + await Child(name="Billy", favourite_game=uno, least_favourite_game=checkers).save() + await Child(name="Kate", favourite_game=checkers, least_favourite_game=uno).save() + + billy_check = await Child.objects.select_related( + ["favourite_game", "least_favourite_game"] + ).get(name="Billy") + assert billy_check.favourite_game == uno + assert billy_check.least_favourite_game == checkers + + uno_check = await Game.objects.select_related(["liked_by", "not_liked_by"]).get( + name="Uno" + ) + assert uno_check.liked_by[0].name == "Billy" + assert uno_check.not_liked_by[0].name == "Kate"