intorduce relation flags on basefield and simplify imports

This commit is contained in:
collerek
2021-02-21 17:46:06 +01:00
parent c560245760
commit e697235172
27 changed files with 163 additions and 149 deletions

View File

@ -36,9 +36,13 @@ class BaseField(FieldInfo):
index: bool index: bool
unique: bool unique: bool
pydantic_only: bool pydantic_only: bool
virtual: bool = False
choices: typing.Sequence choices: typing.Sequence
virtual: bool = False # ManyToManyFields and reverse ForeignKeyFields
is_multi: bool = False # ManyToManyField
is_relation: bool = False # ForeignKeyField + subclasses
is_through: bool = False # ThroughFields
owner: Type["Model"] owner: Type["Model"]
to: Type["Model"] to: Type["Model"]
through: Type["Model"] through: Type["Model"]
@ -62,7 +66,7 @@ class BaseField(FieldInfo):
:return: result of the check :return: result of the check
:rtype: bool :rtype: bool
""" """
return not issubclass(cls, ormar.fields.ManyToManyField) and not cls.virtual return not cls.is_multi and not cls.virtual
@classmethod @classmethod
def get_alias(cls) -> str: def get_alias(cls) -> str:

View File

@ -48,7 +48,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
**{ **{
k: create_dummy_instance(v.to) k: create_dummy_instance(v.to)
for k, v in fk.Meta.model_fields.items() for k, v in fk.Meta.model_fields.items()
if isinstance(v, ForeignKeyField) and not v.nullable and not v.virtual if v.is_relation and not v.nullable and not v.virtual
}, },
} }
return fk(**init_dict) return fk(**init_dict)
@ -217,6 +217,7 @@ def ForeignKey( # noqa CFQ002
ondelete=ondelete, ondelete=ondelete,
owner=owner, owner=owner,
self_reference=self_reference, self_reference=self_reference,
is_relation=True,
) )
return type("ForeignKey", (ForeignKeyField, BaseField), namespace) return type("ForeignKey", (ForeignKeyField, BaseField), namespace)

View File

@ -103,6 +103,8 @@ def ManyToMany(
server_default=None, server_default=None,
owner=owner, owner=owner,
self_reference=self_reference, self_reference=self_reference,
is_relation=True,
is_multi=True,
) )
return type("ManyToMany", (ManyToManyField, BaseField), namespace) return type("ManyToMany", (ManyToManyField, BaseField), namespace)

View File

@ -15,12 +15,7 @@ if TYPE_CHECKING: # pragma no cover
def Through( # noqa CFQ002 def Through( # noqa CFQ002
to: "ToType", to: "ToType", *, name: str = None, related_name: str = None, **kwargs: Any,
*,
name: str = None,
related_name: str = None,
virtual: bool = True,
**kwargs: Any,
) -> Any: ) -> Any:
# TODO: clean docstring # TODO: clean docstring
""" """
@ -52,7 +47,7 @@ def Through( # noqa CFQ002
alias=name, alias=name,
name=kwargs.pop("real_name", None), name=kwargs.pop("real_name", None),
related_name=related_name, related_name=related_name,
virtual=virtual, virtual=True,
owner=owner, owner=owner,
nullable=False, nullable=False,
unique=False, unique=False,
@ -62,6 +57,8 @@ def Through( # noqa CFQ002
pydantic_only=False, pydantic_only=False,
default=None, default=None,
server_default=None, server_default=None,
is_relation=True,
is_through=True,
) )
return type("Through", (ThroughField, BaseField), namespace) return type("Through", (ThroughField, BaseField), namespace)

View File

@ -6,6 +6,6 @@ ass well as vast number of helper functions for pydantic, sqlalchemy and relatio
from ormar.models.newbasemodel import NewBaseModel # noqa I100 from ormar.models.newbasemodel import NewBaseModel # noqa I100
from ormar.models.model_row import ModelRow # noqa I100 from ormar.models.model_row import ModelRow # noqa I100
from ormar.models.model import Model # noqa I100 from ormar.models.model import Model, T # noqa I100
__all__ = ["NewBaseModel", "Model", "ModelRow"] __all__ = ["T", "NewBaseModel", "Model", "ModelRow"]

View File

@ -21,7 +21,7 @@ def is_field_an_forward_ref(field: Type["BaseField"]) -> bool:
:return: result of the check :return: result of the check
:rtype: bool :rtype: bool
""" """
return issubclass(field, ormar.ForeignKeyField) and ( return field.is_relation and (
field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef field.to.__class__ == ForwardRef or field.through.__class__ == ForwardRef
) )

View File

@ -6,14 +6,15 @@ from pydantic.fields import ModelField
from pydantic.utils import lenient_issubclass from pydantic.utils import lenient_issubclass
import ormar # noqa: I100, I202 import ormar # noqa: I100, I202
from ormar.fields import BaseField, ManyToManyField from ormar.fields import BaseField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
from ormar.fields import ManyToManyField
def create_pydantic_field( def create_pydantic_field(
field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] field_name: str, model: Type["Model"], model_field: Type["ManyToManyField"]
) -> None: ) -> None:
""" """
Registers pydantic field on through model that leads to passed model Registers pydantic field on through model that leads to passed model
@ -59,7 +60,7 @@ def get_pydantic_field(field_name: str, model: Type["Model"]) -> "ModelField":
def populate_default_pydantic_field_value( def populate_default_pydantic_field_value(
ormar_field: Type[BaseField], field_name: str, attrs: dict ormar_field: Type["BaseField"], field_name: str, attrs: dict
) -> dict: ) -> dict:
""" """
Grabs current value of the ormar Field in class namespace Grabs current value of the ormar Field in class namespace

View File

@ -25,7 +25,7 @@ def validate_related_names_in_relations( # noqa CCR001
""" """
already_registered: Dict[str, List[Optional[str]]] = dict() already_registered: Dict[str, List[Optional[str]]] = dict()
for field in model_fields.values(): for field in model_fields.values():
if issubclass(field, ormar.ForeignKeyField): if field.is_relation:
to_name = ( to_name = (
field.to.get_name() field.to.get_name()
if not field.to.__class__ == ForwardRef if not field.to.__class__ == ForwardRef

View File

@ -1,14 +1,14 @@
from typing import TYPE_CHECKING, Type from typing import TYPE_CHECKING, Type, cast
import ormar import ormar
from ormar import ForeignKey, ManyToMany from ormar import ForeignKey, ManyToMany
from ormar.fields import ManyToManyField, Through, ThroughField from ormar.fields import Through
from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.helpers.sqlalchemy import adjust_through_many_to_many_model from ormar.models.helpers.sqlalchemy import adjust_through_many_to_many_model
from ormar.relations import AliasManager from ormar.relations import AliasManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
from ormar.fields import ManyToManyField, ForeignKeyField
alias_manager = AliasManager() alias_manager = AliasManager()
@ -32,7 +32,7 @@ def register_relation_on_build(field: Type["ForeignKeyField"]) -> None:
) )
def register_many_to_many_relation_on_build(field: Type[ManyToManyField]) -> None: def register_many_to_many_relation_on_build(field: Type["ManyToManyField"]) -> None:
""" """
Registers connection between through model and both sides of the m2m relation. Registers connection between through model and both sides of the m2m relation.
Registration include also reverse relation side to be able to join both sides. Registration include also reverse relation side to be able to join both sides.
@ -83,10 +83,8 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
""" """
model_fields = list(model.Meta.model_fields.values()) model_fields = list(model.Meta.model_fields.values())
for model_field in model_fields: for model_field in model_fields:
if ( if model_field.is_relation and not model_field.has_unresolved_forward_refs():
issubclass(model_field, ForeignKeyField) model_field = cast(Type["ForeignKeyField"], model_field)
and not model_field.has_unresolved_forward_refs()
):
expand_reverse_relationship(model_field=model_field) expand_reverse_relationship(model_field=model_field)
@ -102,7 +100,7 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None:
:type model_field: relation Field :type model_field: relation Field
""" """
related_name = model_field.get_related_name() related_name = model_field.get_related_name()
if issubclass(model_field, ManyToManyField): if model_field.is_multi:
model_field.to.Meta.model_fields[related_name] = ManyToMany( model_field.to.Meta.model_fields[related_name] = ManyToMany(
model_field.owner, model_field.owner,
through=model_field.through, through=model_field.through,
@ -114,6 +112,7 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None:
self_reference_primary=model_field.self_reference_primary, self_reference_primary=model_field.self_reference_primary,
) )
# register foreign keys on through model # register foreign keys on through model
model_field = cast(Type["ManyToManyField"], model_field)
register_through_shortcut_fields(model_field=model_field) register_through_shortcut_fields(model_field=model_field)
adjust_through_many_to_many_model(model_field=model_field) adjust_through_many_to_many_model(model_field=model_field)
else: else:
@ -155,7 +154,7 @@ def register_through_shortcut_fields(model_field: Type["ManyToManyField"]) -> No
) )
def register_relation_in_alias_manager(field: Type[ForeignKeyField]) -> None: def register_relation_in_alias_manager(field: Type["ForeignKeyField"]) -> None:
""" """
Registers the relation (and reverse relation) in alias manager. Registers the relation (and reverse relation) in alias manager.
The m2m relations require registration of through model between The m2m relations require registration of through model between
@ -168,11 +167,12 @@ def register_relation_in_alias_manager(field: Type[ForeignKeyField]) -> None:
:param field: relation field :param field: relation field
:type field: ForeignKey or ManyToManyField class :type field: ForeignKey or ManyToManyField class
""" """
if issubclass(field, ManyToManyField): if field.is_multi:
if field.has_unresolved_forward_refs(): if field.has_unresolved_forward_refs():
return return
field = cast(Type["ManyToManyField"], field)
register_many_to_many_relation_on_build(field=field) register_many_to_many_relation_on_build(field=field)
elif issubclass(field, ForeignKeyField) and not issubclass(field, ThroughField): elif field.is_relation and not field.is_through:
if field.has_unresolved_forward_refs(): if field.has_unresolved_forward_refs():
return return
register_relation_on_build(field=field) register_relation_on_build(field=field)

View File

@ -156,11 +156,7 @@ def sqlalchemy_columns_from_model_fields(
field.owner = new_model field.owner = new_model
if field.primary_key: if field.primary_key:
pkname = check_pk_column_validity(field_name, field, pkname) pkname = check_pk_column_validity(field_name, field, pkname)
if ( if not field.pydantic_only and not field.virtual and not field.is_multi:
not field.pydantic_only
and not field.virtual
and not issubclass(field, ormar.ManyToManyField)
):
columns.append(field.get_column(field.get_alias())) columns.append(field.get_column(field.get_alias()))
return pkname, columns return pkname, columns

View File

@ -262,7 +262,7 @@ def copy_and_replace_m2m_through_model(
new_meta.model_fields = { new_meta.model_fields = {
name: field name: field
for name, field in new_meta.model_fields.items() for name, field in new_meta.model_fields.items()
if not issubclass(field, ForeignKeyField) if not field.is_relation
} }
_, columns = sqlalchemy_columns_from_model_fields( _, columns = sqlalchemy_columns_from_model_fields(
new_meta.model_fields, copy_through new_meta.model_fields, copy_through
@ -329,7 +329,8 @@ def copy_data_from_parent_model( # noqa: CCR001
else attrs.get("__name__", "").lower() + "s" else attrs.get("__name__", "").lower() + "s"
) )
for field_name, field in base_class.Meta.model_fields.items(): for field_name, field in base_class.Meta.model_fields.items():
if issubclass(field, ManyToManyField): if field.is_multi:
field = cast(Type["ManyToManyField"], field)
copy_and_replace_m2m_through_model( copy_and_replace_m2m_through_model(
field=field, field=field,
field_name=field_name, field_name=field_name,
@ -339,7 +340,7 @@ def copy_data_from_parent_model( # noqa: CCR001
meta=meta, meta=meta,
) )
elif issubclass(field, ForeignKeyField) and field.related_name: elif field.is_relation and field.related_name:
copy_field = type( # type: ignore copy_field = type( # type: ignore
field.__name__, (ForeignKeyField, BaseField), dict(field.__dict__) field.__name__, (ForeignKeyField, BaseField), dict(field.__dict__)
) )

View File

@ -1,9 +1,10 @@
from typing import Callable, Dict, List, TYPE_CHECKING, Tuple, Type from typing import Callable, Dict, List, TYPE_CHECKING, Tuple, Type, cast
import ormar
from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.mixins.relation_mixin import RelationMixin from ormar.models.mixins.relation_mixin import RelationMixin
if TYPE_CHECKING:
from ormar.fields import ForeignKeyField, ManyToManyField
class PrefetchQueryMixin(RelationMixin): class PrefetchQueryMixin(RelationMixin):
""" """
@ -39,7 +40,8 @@ class PrefetchQueryMixin(RelationMixin):
if reverse: if reverse:
field_name = parent_model.Meta.model_fields[related].get_related_name() field_name = parent_model.Meta.model_fields[related].get_related_name()
field = target_model.Meta.model_fields[field_name] field = target_model.Meta.model_fields[field_name]
if issubclass(field, ormar.fields.ManyToManyField): if field.is_multi:
field = cast(Type["ManyToManyField"], field)
field_name = field.default_target_field_name() field_name = field.default_target_field_name()
sub_field = field.through.Meta.model_fields[field_name] sub_field = field.through.Meta.model_fields[field_name]
return field.through, sub_field.get_alias() return field.through, sub_field.get_alias()
@ -87,7 +89,7 @@ class PrefetchQueryMixin(RelationMixin):
:return: name of the field :return: name of the field
:rtype: str :rtype: str
""" """
if issubclass(target_field, ormar.fields.ManyToManyField): if target_field.is_multi:
return cls.get_name() return cls.get_name()
if target_field.virtual: if target_field.virtual:
return target_field.get_related_name() return target_field.get_related_name()

View File

@ -1,10 +1,6 @@
import inspect import inspect
from typing import List, Optional, Set, TYPE_CHECKING from typing import List, Optional, Set, TYPE_CHECKING
from ormar import ManyToManyField
from ormar.fields import ThroughField
from ormar.fields.foreign_key import ForeignKeyField
class RelationMixin: class RelationMixin:
""" """
@ -62,7 +58,7 @@ class RelationMixin:
related_fields = set() related_fields = set()
for name in cls.extract_related_names(): for name in cls.extract_related_names():
field = cls.Meta.model_fields[name] field = cls.Meta.model_fields[name]
if issubclass(field, ManyToManyField): if field.is_multi:
related_fields.add(field.through.get_name(lower=True)) related_fields.add(field.through.get_name(lower=True))
return related_fields return related_fields
@ -80,11 +76,7 @@ class RelationMixin:
related_names = set() related_names = set()
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if ( if inspect.isclass(field) and field.is_relation and not field.is_through:
inspect.isclass(field)
and issubclass(field, ForeignKeyField)
and not issubclass(field, ThroughField)
):
related_names.add(name) related_names.add(name)
cls._related_names = related_names cls._related_names = related_names

View File

@ -8,7 +8,6 @@ from typing import (
import ormar.queryset # noqa I100 import ormar.queryset # noqa I100
from ormar.exceptions import ModelPersistenceError, NoMatch from ormar.exceptions import ModelPersistenceError, NoMatch
from ormar.fields.many_to_many import ManyToManyField
from ormar.models import NewBaseModel # noqa I100 from ormar.models import NewBaseModel # noqa I100
from ormar.models.metaclass import ModelMeta from ormar.models.metaclass import ModelMeta
from ormar.models.model_row import ModelRow from ormar.models.model_row import ModelRow
@ -139,8 +138,9 @@ class Model(ModelRow):
visited.add(self.__class__) visited.add(self.__class__)
for related in self.extract_related_names(): for related in self.extract_related_names():
if self.Meta.model_fields[related].virtual or issubclass( if (
self.Meta.model_fields[related], ManyToManyField self.Meta.model_fields[related].virtual
or self.Meta.model_fields[related].is_multi
): ):
for rel in getattr(self, related): for rel in getattr(self, related):
update_count, visited = await self._update_and_follow( update_count, visited = await self._update_and_follow(

View File

@ -8,24 +8,26 @@ from typing import (
Type, Type,
TypeVar, TypeVar,
Union, Union,
cast,
) )
import sqlalchemy import sqlalchemy
from ormar import ManyToManyField # noqa: I202 from ormar.models import NewBaseModel # noqa: I202
from ormar.models import NewBaseModel
from ormar.models.helpers.models import group_related_list from ormar.models.helpers.models import group_related_list
T = TypeVar("T", bound="ModelRow")
if TYPE_CHECKING: if TYPE_CHECKING:
from ormar.fields import ForeignKeyField from ormar.fields import ForeignKeyField
from ormar.models import T
else:
T = TypeVar("T", bound="ModelRow")
class ModelRow(NewBaseModel): class ModelRow(NewBaseModel):
@classmethod @classmethod
def from_row( def from_row( # noqa: CFQ002
cls: Type[T], cls: Type["ModelRow"],
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
source_model: Type[T], source_model: Type[T],
select_related: List = None, select_related: List = None,
@ -75,7 +77,7 @@ class ModelRow(NewBaseModel):
table_prefix = "" table_prefix = ""
if select_related: if select_related:
source_model = cls source_model = cast(Type[T], cls)
related_models = group_related_list(select_related) related_models = group_related_list(select_related)
if related_field: if related_field:
@ -107,7 +109,7 @@ class ModelRow(NewBaseModel):
item["__excluded__"] = cls.get_names_to_exclude( item["__excluded__"] = cls.get_names_to_exclude(
fields=fields, exclude_fields=exclude_fields fields=fields, exclude_fields=exclude_fields
) )
instance = cls(**item) instance = cast(T, cls(**item))
instance.set_save_status(True) instance.set_save_status(True)
return instance return instance
@ -160,6 +162,7 @@ class ModelRow(NewBaseModel):
else related else related
) )
field = cls.Meta.model_fields[related] field = cls.Meta.model_fields[related]
field = cast(Type["ForeignKeyField"], field)
fields = cls.get_included(fields, related) fields = cls.get_included(fields, related)
exclude_fields = cls.get_excluded(exclude_fields, related) exclude_fields = cls.get_excluded(exclude_fields, related)
model_cls = field.to model_cls = field.to
@ -177,7 +180,7 @@ class ModelRow(NewBaseModel):
source_model=source_model, source_model=source_model,
) )
item[model_cls.get_column_name_from_alias(related)] = child item[model_cls.get_column_name_from_alias(related)] = child
if issubclass(field, ManyToManyField) and child: if field.is_multi and child:
# TODO: way to figure out which side should be populated? # TODO: way to figure out which side should be populated?
through_name = cls.Meta.model_fields[related].through.get_name() through_name = cls.Meta.model_fields[related].through.get_name()
# for now it's nested dict, should be instance? # for now it's nested dict, should be instance?

View File

@ -46,15 +46,15 @@ from ormar.relations.alias_manager import AliasManager
from ormar.relations.relation_manager import RelationsManager from ormar.relations.relation_manager import RelationsManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar.models import Model, T
from ormar.signals import SignalEmitter from ormar.signals import SignalEmitter
T = TypeVar("T", bound=Model)
IntStr = Union[int, str] IntStr = Union[int, str]
DictStrAny = Dict[str, Any] DictStrAny = Dict[str, Any]
AbstractSetIntStr = AbstractSet[IntStr] AbstractSetIntStr = AbstractSet[IntStr]
MappingIntStrAny = Mapping[IntStr, Any] MappingIntStrAny = Mapping[IntStr, Any]
else:
T = TypeVar("T")
class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass): class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass):
@ -89,7 +89,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
Meta: ModelMeta Meta: ModelMeta
# noinspection PyMissingConstructor # noinspection PyMissingConstructor
def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore def __init__(self: T, *args: Any, **kwargs: Any) -> None: # type: ignore
""" """
Initializer that creates a new ormar Model that is also pydantic Model at the Initializer that creates a new ormar Model that is also pydantic Model at the
same time. same time.
@ -129,7 +129,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
object.__setattr__( object.__setattr__(
self, self,
"_orm", "_orm",
RelationsManager(related_fields=self.extract_related_fields(), owner=self,), RelationsManager(
related_fields=self.extract_related_fields(), owner=cast(T, self),
),
) )
pk_only = kwargs.pop("__pk_only__", False) pk_only = kwargs.pop("__pk_only__", False)
@ -298,7 +300,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
def _extract_related_model_instead_of_field( def _extract_related_model_instead_of_field(
self, item: str self, item: str
) -> Optional[Union["T", Sequence["T"]]]: ) -> Optional[Union["Model", Sequence["Model"]]]:
""" """
Retrieves the related model/models from RelationshipManager. Retrieves the related model/models from RelationshipManager.
@ -755,9 +757,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
:return: value of pk if set :return: value of pk if set
:rtype: Optional[int] :rtype: Optional[int]
""" """
if target_field.virtual or issubclass( if target_field.virtual or target_field.is_multi:
target_field, ormar.fields.ManyToManyField
):
return self.pk return self.pk
related_name = target_field.name related_name = target_field.name
related_model = getattr(self, related_name) related_model = getattr(self, related_name)

View File

@ -5,6 +5,6 @@ from ormar.queryset.filter_query import FilterQuery
from ormar.queryset.limit_query import LimitQuery from ormar.queryset.limit_query import LimitQuery
from ormar.queryset.offset_query import OffsetQuery from ormar.queryset.offset_query import OffsetQuery
from ormar.queryset.order_query import OrderQuery from ormar.queryset.order_query import OrderQuery
from ormar.queryset.queryset import QuerySet from ormar.queryset.queryset import QuerySet, T
__all__ = ["QuerySet", "FilterQuery", "LimitQuery", "OffsetQuery", "OrderQuery"] __all__ = ["T", "QuerySet", "FilterQuery", "LimitQuery", "OffsetQuery", "OrderQuery"]

View File

@ -15,7 +15,6 @@ import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
from ormar.exceptions import RelationshipInstanceError # noqa I100 from ormar.exceptions import RelationshipInstanceError # noqa I100
from ormar.fields import BaseField, ManyToManyField # noqa I100
from ormar.relations import AliasManager from ormar.relations import AliasManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -118,7 +117,7 @@ class SqlJoin:
:return: list of used aliases, select from, list of aliased columns, sort orders :return: list of used aliases, select from, list of aliased columns, sort orders
:rtype: Tuple[List[str], Join, List[TextClause], collections.OrderedDict] :rtype: Tuple[List[str], Join, List[TextClause], collections.OrderedDict]
""" """
if issubclass(self.target_field, ManyToManyField): if self.target_field.is_multi:
self.process_m2m_through_table() self.process_m2m_through_table()
self.next_model = self.target_field.to self.next_model = self.target_field.to
@ -287,7 +286,7 @@ class SqlJoin:
) )
pkname_alias = self.next_model.get_column_alias(self.next_model.Meta.pkname) pkname_alias = self.next_model.get_column_alias(self.next_model.Meta.pkname)
if not issubclass(self.target_field, ManyToManyField): if not self.target_field.is_multi:
self.get_order_bys( self.get_order_bys(
to_table=to_table, pkname_alias=pkname_alias, to_table=to_table, pkname_alias=pkname_alias,
) )
@ -415,7 +414,7 @@ class SqlJoin:
:return: to key and from key :return: to key and from key
:rtype: Tuple[str, str] :rtype: Tuple[str, str]
""" """
if issubclass(self.target_field, ManyToManyField): if self.target_field.is_multi:
to_key = self.process_m2m_related_name_change(reverse=True) to_key = self.process_m2m_related_name_change(reverse=True)
from_key = self.main_model.get_column_alias(self.main_model.Meta.pkname) from_key = self.main_model.get_column_alias(self.main_model.Meta.pkname)

View File

@ -13,14 +13,13 @@ from typing import (
) )
import ormar import ormar
from ormar.fields import BaseField, ManyToManyField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.queryset.clause import QueryClause from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query from ormar.queryset.query import Query
from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from ormar import Model from ormar import Model
from ormar.fields import ForeignKeyField, BaseField
def add_relation_field_to_fields( def add_relation_field_to_fields(
@ -316,7 +315,7 @@ class PrefetchQuery:
for related in related_to_extract: for related in related_to_extract:
target_field = model.Meta.model_fields[related] target_field = model.Meta.model_fields[related]
target_field = cast(Type[ForeignKeyField], target_field) target_field = cast(Type["ForeignKeyField"], target_field)
target_model = target_field.to.get_name() target_model = target_field.to.get_name()
model_id = model.get_relation_model_id(target_field=target_field) model_id = model.get_relation_model_id(target_field=target_field)
@ -424,9 +423,9 @@ class PrefetchQuery:
fields = target_model.get_included(fields, related) fields = target_model.get_included(fields, related)
exclude_fields = target_model.get_excluded(exclude_fields, related) exclude_fields = target_model.get_excluded(exclude_fields, related)
target_field = target_model.Meta.model_fields[related] target_field = target_model.Meta.model_fields[related]
target_field = cast(Type[ForeignKeyField], target_field) target_field = cast(Type["ForeignKeyField"], target_field)
reverse = False reverse = False
if target_field.virtual or issubclass(target_field, ManyToManyField): if target_field.virtual or target_field.is_multi:
reverse = True reverse = True
parent_model = target_model parent_model = target_model
@ -522,7 +521,7 @@ class PrefetchQuery:
select_related = [] select_related = []
query_target = target_model query_target = target_model
table_prefix = "" table_prefix = ""
if issubclass(target_field, ManyToManyField): if target_field.is_multi:
query_target = target_field.through query_target = target_field.through
select_related = [target_name] select_related = [target_name]
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_alias( table_prefix = target_field.to.Meta.alias_manager.resolve_relation_alias(

View File

@ -1,4 +1,17 @@
from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union from typing import (
Any,
Dict,
Generic,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Type,
TypeVar,
Union,
cast,
)
import databases import databases
import sqlalchemy import sqlalchemy
@ -14,19 +27,21 @@ from ormar.queryset.query import Query
from ormar.queryset.utils import update, update_dict_from_list from ormar.queryset.utils import update, update_dict_from_list
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar.models import T
from ormar.models.metaclass import ModelMeta from ormar.models.metaclass import ModelMeta
from ormar.relations.querysetproxy import QuerysetProxy from ormar.relations.querysetproxy import QuerysetProxy
else:
T = TypeVar("T")
class QuerySet: class QuerySet(Generic[T]):
""" """
Main class to perform database queries, exposed on each model as objects attribute. Main class to perform database queries, exposed on each model as objects attribute.
""" """
def __init__( # noqa CFQ002 def __init__( # noqa CFQ002
self, self,
model_cls: Type["Model"] = None, model_cls: Optional[Type[T]] = None,
filter_clauses: List = None, filter_clauses: List = None,
exclude_clauses: List = None, exclude_clauses: List = None,
select_related: List = None, select_related: List = None,
@ -53,7 +68,7 @@ class QuerySet:
def __get__( def __get__(
self, self,
instance: Optional[Union["QuerySet", "QuerysetProxy"]], instance: Optional[Union["QuerySet", "QuerysetProxy"]],
owner: Union[Type["Model"], Type["QuerysetProxy"]], owner: Union[Type[T], Type["QuerysetProxy"]],
) -> "QuerySet": ) -> "QuerySet":
if issubclass(owner, ormar.Model): if issubclass(owner, ormar.Model):
if owner.Meta.requires_ref_update: if owner.Meta.requires_ref_update:
@ -62,7 +77,7 @@ class QuerySet:
f"ForwardRefs. \nBefore using the model you " f"ForwardRefs. \nBefore using the model you "
f"need to call update_forward_refs()." f"need to call update_forward_refs()."
) )
if issubclass(owner, ormar.Model): owner = cast(Type[T], owner)
return self.__class__(model_cls=owner) return self.__class__(model_cls=owner)
return self.__class__() # pragma: no cover return self.__class__() # pragma: no cover
@ -79,7 +94,7 @@ class QuerySet:
return self.model_cls.Meta return self.model_cls.Meta
@property @property
def model(self) -> Type["Model"]: def model(self) -> Type[T]:
""" """
Shortcut to model class set on QuerySet. Shortcut to model class set on QuerySet.
@ -91,8 +106,8 @@ class QuerySet:
return self.model_cls return self.model_cls
async def _prefetch_related_models( async def _prefetch_related_models(
self, models: Sequence[Optional["Model"]], rows: List self, models: Sequence[Optional["T"]], rows: List
) -> Sequence[Optional["Model"]]: ) -> Sequence[Optional["T"]]:
""" """
Performs prefetch query for selected models names. Performs prefetch query for selected models names.
@ -113,7 +128,7 @@ class QuerySet:
) )
return await query.prefetch_related(models=models, rows=rows) # type: ignore return await query.prefetch_related(models=models, rows=rows) # type: ignore
def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: def _process_query_result_rows(self, rows: List) -> Sequence[Optional[T]]:
""" """
Process database rows and initialize ormar Model from each of the rows. Process database rows and initialize ormar Model from each of the rows.
@ -137,7 +152,7 @@ class QuerySet:
return result_rows return result_rows
@staticmethod @staticmethod
def check_single_result_rows_count(rows: Sequence[Optional["Model"]]) -> None: def check_single_result_rows_count(rows: Sequence[Optional[T]]) -> None:
""" """
Verifies if the result has one and only one row. Verifies if the result has one and only one row.
@ -198,7 +213,7 @@ class QuerySet:
limit_raw_sql=self.limit_sql_raw, limit_raw_sql=self.limit_sql_raw,
) )
exp = qry.build_select_expression() exp = qry.build_select_expression()
print("\n", exp.compile(compile_kwargs={"literal_binds": True})) # print("\n", exp.compile(compile_kwargs={"literal_binds": True}))
return exp return exp
def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003 def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003
@ -683,7 +698,7 @@ class QuerySet:
limit_raw_sql=limit_raw_sql, limit_raw_sql=limit_raw_sql,
) )
async def first(self, **kwargs: Any) -> "Model": async def first(self, **kwargs: Any) -> T:
""" """
Gets the first row from the db ordered by primary key column ascending. Gets the first row from the db ordered by primary key column ascending.
@ -707,7 +722,7 @@ class QuerySet:
self.check_single_result_rows_count(processed_rows) self.check_single_result_rows_count(processed_rows)
return processed_rows[0] # type: ignore return processed_rows[0] # type: ignore
async def get(self, **kwargs: Any) -> "Model": async def get(self, **kwargs: Any) -> T:
""" """
Get's the first row from the db meeting the criteria set by kwargs. Get's the first row from the db meeting the criteria set by kwargs.
@ -739,7 +754,7 @@ class QuerySet:
self.check_single_result_rows_count(processed_rows) self.check_single_result_rows_count(processed_rows)
return processed_rows[0] # type: ignore return processed_rows[0] # type: ignore
async def get_or_create(self, **kwargs: Any) -> "Model": async def get_or_create(self, **kwargs: Any) -> T:
""" """
Combination of create and get methods. Combination of create and get methods.
@ -757,7 +772,7 @@ class QuerySet:
except NoMatch: except NoMatch:
return await self.create(**kwargs) return await self.create(**kwargs)
async def update_or_create(self, **kwargs: Any) -> "Model": async def update_or_create(self, **kwargs: Any) -> T:
""" """
Updates the model, or in case there is no match in database creates a new one. Updates the model, or in case there is no match in database creates a new one.
@ -774,7 +789,7 @@ class QuerySet:
model = await self.get(pk=kwargs[pk_name]) model = await self.get(pk=kwargs[pk_name])
return await model.update(**kwargs) return await model.update(**kwargs)
async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003 async def all(self, **kwargs: Any) -> Sequence[Optional[T]]: # noqa: A003
""" """
Returns all rows from a database for given model for set filter options. Returns all rows from a database for given model for set filter options.
@ -798,7 +813,7 @@ class QuerySet:
return result_rows return result_rows
async def create(self, **kwargs: Any) -> "Model": async def create(self, **kwargs: Any) -> T:
""" """
Creates the model instance, saves it in a database and returns the updates model Creates the model instance, saves it in a database and returns the updates model
(with pk populated if not passed and autoincrement is set). (with pk populated if not passed and autoincrement is set).
@ -841,7 +856,7 @@ class QuerySet:
) )
return instance return instance
async def bulk_create(self, objects: List["Model"]) -> None: async def bulk_create(self, objects: List[T]) -> None:
""" """
Performs a bulk update in one database session to speed up the process. Performs a bulk update in one database session to speed up the process.
@ -867,7 +882,7 @@ class QuerySet:
objt.set_save_status(True) objt.set_save_status(True)
async def bulk_update( # noqa: CCR001 async def bulk_update( # noqa: CCR001
self, objects: List["Model"], columns: List[str] = None self, objects: List[T], columns: List[str] = None
) -> None: ) -> None:
""" """
Performs bulk update in one database session to speed up the process. Performs bulk update in one database session to speed up the process.

View File

@ -12,7 +12,6 @@ from typing import (
Union, Union,
) )
from ormar.fields import ManyToManyField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -236,7 +235,7 @@ def get_relationship_alias_model_and_str(
manager = model_cls.Meta.alias_manager manager = model_cls.Meta.alias_manager
for relation in related_parts: for relation in related_parts:
related_field = model_cls.Meta.model_fields[relation] related_field = model_cls.Meta.model_fields[relation]
if issubclass(related_field, ManyToManyField): if related_field.is_multi:
previous_model = related_field.through previous_model = related_field.through
relation = related_field.default_target_field_name() # type: ignore relation = related_field.default_target_field_name() # type: ignore
table_prefix = manager.resolve_relation_alias( table_prefix = manager.resolve_relation_alias(

View File

@ -1,6 +1,7 @@
from typing import ( from typing import (
Any, Any,
Dict, Dict,
Generic,
List, List,
MutableSequence, MutableSequence,
Optional, Optional,
@ -9,6 +10,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
TypeVar, TypeVar,
Union, Union,
cast,
) )
import ormar import ormar
@ -16,14 +18,14 @@ from ormar.exceptions import ModelPersistenceError
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.relations import Relation from ormar.relations import Relation
from ormar.models import Model from ormar.models import Model, T
from ormar.queryset import QuerySet from ormar.queryset import QuerySet
from ormar import RelationType from ormar import RelationType
else:
T = TypeVar("T", bound=Model) T = TypeVar("T")
class QuerysetProxy(ormar.QuerySetProtocol): class QuerysetProxy(Generic[T]):
""" """
Exposes QuerySet methods on relations, but also handles creating and removing Exposes QuerySet methods on relations, but also handles creating and removing
of through Models for m2m relations. of through Models for m2m relations.
@ -47,7 +49,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
self.through_model_name = ( self.through_model_name = (
self.related_field.through.get_name() self.related_field.through.get_name()
if self.type_ == ormar.RelationType.MULTIPLE if self.type_ == ormar.RelationType.MULTIPLE
else None else ""
) )
@property @property
@ -94,6 +96,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
self._assign_child_to_parent(subchild) self._assign_child_to_parent(subchild)
else: else:
assert isinstance(child, ormar.Model) assert isinstance(child, ormar.Model)
child = cast(T, child)
self._assign_child_to_parent(child) self._assign_child_to_parent(child)
def _clean_items_on_load(self) -> None: def _clean_items_on_load(self) -> None:
@ -198,7 +201,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
) )
return await queryset.delete(**kwargs) # type: ignore return await queryset.delete(**kwargs) # type: ignore
async def first(self, **kwargs: Any) -> "Model": async def first(self, **kwargs: Any) -> T:
""" """
Gets the first row from the db ordered by primary key column ascending. Gets the first row from the db ordered by primary key column ascending.
@ -216,7 +219,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
self._register_related(first) self._register_related(first)
return first return first
async def get(self, **kwargs: Any) -> "Model": async def get(self, **kwargs: Any) -> "T":
""" """
Get's the first row from the db meeting the criteria set by kwargs. Get's the first row from the db meeting the criteria set by kwargs.
@ -240,7 +243,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
self._register_related(get) self._register_related(get)
return get return get
async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003 async def all(self, **kwargs: Any) -> Sequence[Optional["T"]]: # noqa: A003
""" """
Returns all rows from a database for given model for set filter options. Returns all rows from a database for given model for set filter options.
@ -262,7 +265,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
self._register_related(all_items) self._register_related(all_items)
return all_items return all_items
async def create(self, **kwargs: Any) -> "Model": async def create(self, **kwargs: Any) -> "T":
""" """
Creates the model instance, saves it in a database and returns the updates model Creates the model instance, saves it in a database and returns the updates model
(with pk populated if not passed and autoincrement is set). (with pk populated if not passed and autoincrement is set).
@ -287,7 +290,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
await self.create_through_instance(created, **through_kwargs) await self.create_through_instance(created, **through_kwargs)
return created return created
async def get_or_create(self, **kwargs: Any) -> "Model": async def get_or_create(self, **kwargs: Any) -> "T":
""" """
Combination of create and get methods. Combination of create and get methods.
@ -305,7 +308,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
except ormar.NoMatch: except ormar.NoMatch:
return await self.create(**kwargs) return await self.create(**kwargs)
async def update_or_create(self, **kwargs: Any) -> "Model": async def update_or_create(self, **kwargs: Any) -> "T":
""" """
Updates the model, or in case there is no match in database creates a new one. Updates the model, or in case there is no match in database creates a new one.

View File

@ -1,17 +1,13 @@
from enum import Enum from enum import Enum
from typing import List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union from typing import List, Optional, Set, TYPE_CHECKING, Type, Union
import ormar # noqa I100 import ormar # noqa I100
from ormar.exceptions import RelationshipInstanceError # noqa I100 from ormar.exceptions import RelationshipInstanceError # noqa I100
from ormar.fields.foreign_key import ForeignKeyField # noqa I100
from ormar.relations.relation_proxy import RelationProxy from ormar.relations.relation_proxy import RelationProxy
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.relations import RelationsManager from ormar.relations import RelationsManager
from ormar.models import NewBaseModel from ormar.models import Model, NewBaseModel, T
T = TypeVar("T", bound=Model)
class RelationType(Enum): class RelationType(Enum):
@ -39,7 +35,7 @@ class Relation:
manager: "RelationsManager", manager: "RelationsManager",
type_: RelationType, type_: RelationType,
field_name: str, field_name: str,
to: Type["T"], to: Type["Model"],
through: Type["T"] = None, through: Type["T"] = None,
) -> None: ) -> None:
""" """
@ -63,10 +59,10 @@ class Relation:
self._owner: "Model" = manager.owner self._owner: "Model" = manager.owner
self._type: RelationType = type_ self._type: RelationType = type_
self._to_remove: Set = set() self._to_remove: Set = set()
self.to: Type["T"] = to self.to: Type["Model"] = to
self._through: Optional[Type["T"]] = through self._through = through
self.field_name: str = field_name self.field_name: str = field_name
self.related_models: Optional[Union[RelationProxy, "T"]] = ( self.related_models: Optional[Union[RelationProxy, "Model"]] = (
RelationProxy(relation=self, type_=type_, field_name=field_name) RelationProxy(relation=self, type_=type_, field_name=field_name)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None else None
@ -161,7 +157,7 @@ class Relation:
self.related_models.pop(position) # type: ignore self.related_models.pop(position) # type: ignore
del self._owner.__dict__[relation_name][position] del self._owner.__dict__[relation_name][position]
def get(self) -> Optional[Union[List["T"], "T"]]: def get(self) -> Optional[Union[List["Model"], "Model"]]:
""" """
Return the related model or models from RelationProxy. Return the related model or models from RelationProxy.

View File

@ -1,17 +1,12 @@
from typing import Dict, List, Optional, Sequence, TYPE_CHECKING, Type, TypeVar, Union from typing import Dict, List, Optional, Sequence, TYPE_CHECKING, Type, Union
from weakref import proxy from weakref import proxy
from ormar.fields import BaseField, ThroughField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.fields.many_to_many import ManyToManyField
from ormar.relations.relation import Relation, RelationType from ormar.relations.relation import Relation, RelationType
from ormar.relations.utils import get_relations_sides_and_names from ormar.relations.utils import get_relations_sides_and_names
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar.models import NewBaseModel, T, Model
from ormar.models import NewBaseModel from ormar.fields import ForeignKeyField, BaseField
T = TypeVar("T", bound=Model)
class RelationsManager: class RelationsManager:
@ -21,8 +16,8 @@ class RelationsManager:
def __init__( def __init__(
self, self,
related_fields: List[Type[ForeignKeyField]] = None, related_fields: List[Type["ForeignKeyField"]] = None,
owner: "NewBaseModel" = None, owner: Optional["T"] = None,
) -> None: ) -> None:
self.owner = proxy(owner) self.owner = proxy(owner)
self._related_fields = related_fields or [] self._related_fields = related_fields or []
@ -31,7 +26,7 @@ class RelationsManager:
for field in self._related_fields: for field in self._related_fields:
self._add_relation(field) self._add_relation(field)
def _get_relation_type(self, field: Type[BaseField]) -> RelationType: def _get_relation_type(self, field: Type["BaseField"]) -> RelationType:
""" """
Returns type of the relation declared on a field. Returns type of the relation declared on a field.
@ -40,13 +35,13 @@ class RelationsManager:
:return: type of the relation defined on field :return: type of the relation defined on field
:rtype: RelationType :rtype: RelationType
""" """
if issubclass(field, ManyToManyField): if field.is_multi:
return RelationType.MULTIPLE return RelationType.MULTIPLE
if issubclass(field, ThroughField): if field.is_through:
return RelationType.THROUGH return RelationType.THROUGH
return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE
def _add_relation(self, field: Type[BaseField]) -> None: def _add_relation(self, field: Type["BaseField"]) -> None:
""" """
Registers relation in the manager. Registers relation in the manager.
Adds Relation instance under field.name. Adds Relation instance under field.name.
@ -73,7 +68,7 @@ class RelationsManager:
""" """
return item in self._related_names return item in self._related_names
def get(self, name: str) -> Optional[Union["T", Sequence["T"]]]: def get(self, name: str) -> Optional[Union["Model", Sequence["Model"]]]:
""" """
Returns the related model/models if relation is set. Returns the related model/models if relation is set.
Actual call is delegated to Relation instance registered under relation name. Actual call is delegated to Relation instance registered under relation name.

View File

@ -27,7 +27,9 @@ class RelationProxy(list):
self.type_: "RelationType" = type_ self.type_: "RelationType" = type_
self.field_name = field_name self.field_name = field_name
self._owner: "Model" = self.relation.manager.owner self._owner: "Model" = self.relation.manager.owner
self.queryset_proxy = QuerysetProxy(relation=self.relation, type_=type_) self.queryset_proxy: QuerysetProxy = QuerysetProxy(
relation=self.relation, type_=type_
)
self._related_field_name: Optional[str] = None self._related_field_name: Optional[str] = None
@property @property

BIN
test.db-journal Normal file

Binary file not shown.

View File

@ -1,3 +1,5 @@
from typing import Any
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
@ -19,8 +21,8 @@ class Category(ormar.Model):
class Meta(BaseMeta): class Meta(BaseMeta):
tablename = "categories" tablename = "categories"
id: int = ormar.Integer(primary_key=True) id = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=40) name = ormar.String(max_length=40)
class PostCategory(ormar.Model): class PostCategory(ormar.Model):
@ -107,8 +109,12 @@ async def test_setting_additional_fields_on_through_model_in_create():
assert postcat.sort_order == 2 assert postcat.sort_order == 2
def process_post(post: Post):
pass
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_getting_additional_fields_from_queryset(): async def test_getting_additional_fields_from_queryset() -> Any:
async with database: async with database:
post = await Post(title="Test post").save() post = await Post(title="Test post").save()
await post.categories.create( await post.categories.create(
@ -122,10 +128,11 @@ async def test_getting_additional_fields_from_queryset():
assert post.categories[0].postcategory.sort_order == 1 assert post.categories[0].postcategory.sort_order == 1
assert post.categories[1].postcategory.sort_order == 2 assert post.categories[1].postcategory.sort_order == 2
post = await Post.objects.select_related("categories").get( post2 = await Post.objects.select_related("categories").get(
categories__name="Test category2" categories__name="Test category2"
) )
assert post.categories[0].postcategory.sort_order == 2 assert post2.categories[0].postcategory.sort_order == 2
process_post(post2)
# TODO: check/ modify following # TODO: check/ modify following