Merge pull request #132 from collerek/fields_instances

Fields instances
This commit is contained in:
collerek
2021-03-23 10:04:09 +01:00
committed by GitHub
40 changed files with 886 additions and 543 deletions

View File

@ -6,6 +6,12 @@ checks:
argument-count: argument-count:
config: config:
threshold: 6 threshold: 6
method-count:
config:
threshold: 25
method-length:
config:
threshold: 35
file-lines: file-lines:
config: config:
threshold: 500 threshold: 500

View File

@ -3,6 +3,7 @@
To join one table to another, so load also related models you can use following methods. To join one table to another, so load also related models you can use following methods.
* `select_related(related: Union[List, str]) -> QuerySet` * `select_related(related: Union[List, str]) -> QuerySet`
* `select_all(follow: bool = True) -> QuerySet`
* `prefetch_related(related: Union[List, str]) -> QuerySet` * `prefetch_related(related: Union[List, str]) -> QuerySet`
@ -12,6 +13,7 @@ To join one table to another, so load also related models you can use following
* `QuerysetProxy` * `QuerysetProxy`
* `QuerysetProxy.select_related(related: Union[List, str])` method * `QuerysetProxy.select_related(related: Union[List, str])` method
* `QuerysetProxy.select_all(follow: bool=True)` method
* `QuerysetProxy.prefetch_related(related: Union[List, str])` method * `QuerysetProxy.prefetch_related(related: Union[List, str])` method
## select_related ## select_related
@ -142,6 +144,92 @@ fields and the final `Models` are fetched for you.
Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()`
## select_all
`select_all(follow: bool = False) -> QuerySet`
By default when you select `all()` none of the relations are loaded, likewise,
when `select_related()` is used you need to explicitly specify all relations that should
be loaded. If you want to include also nested relations this can be cumberstone.
That's why `select_all()` was introduced, so by default load all relations of a model
(so kind of opposite as with `all()` approach).
By default adds only directly related models of a parent model (from which the query is run).
If `follow=True` is set it adds also related models of related models.
!!!info
To not get stuck in an infinite loop as related models also keep a relation
to parent model visited models set is kept.
That way already visited models that are nested are loaded, but the load do not
follow them inside. So Model A -> Model B -> Model C -> Model A -> Model X
will load second Model A but will never follow into Model X.
Nested relations of those kind need to be loaded manually.
With sample date like follow:
```python
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class BaseMeta(ormar.ModelMeta):
database = database
metadata = metadata
class Address(ormar.Model):
class Meta(BaseMeta):
tablename = "addresses"
id: int = ormar.Integer(primary_key=True)
street: str = ormar.String(max_length=100, nullable=False)
number: int = ormar.Integer(nullable=False)
post_code: str = ormar.String(max_length=20, nullable=False)
class Branch(ormar.Model):
class Meta(BaseMeta):
tablename = "branches"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False)
address = ormar.ForeignKey(Address)
class Company(ormar.Model):
class Meta(BaseMeta):
tablename = "companies"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="company_name")
founded: int = ormar.Integer(nullable=True)
branches = ormar.ManyToMany(Branch)
```
To select all `Companies` with all `Branches` and `Addresses` you can simply query:
```python
companies = await Company.objects.select_all(follow=True).all()
# which is equivalent to:
companies = await Company.objects.select_related('branches__address').all()
```
Of course in this case it's quite easy to issue explicit relation names in `select_related`,
but the benefit of `select_all()` shows when you have multiple relations.
If for example `Company` would have 3 relations and all of those 3 relations have it's own
3 relations you would have to issue 9 relation strings to `select_related`, `select_all()`
is also resistant to change in names of relations.
!!!note
Note that you can chain `select_all()` with other `QuerySet` methods like `filter`, `exclude_fields` etc.
To exclude relations use `exclude_fields()` call with names of relations (also nested) to exclude.
## prefetch_related ## prefetch_related
`prefetch_related(related: Union[List, str]) -> QuerySet` `prefetch_related(related: Union[List, str]) -> QuerySet`
@ -404,6 +492,15 @@ objects from other side of the relation.
!!!tip !!!tip
To read more about `QuerysetProxy` visit [querysetproxy][querysetproxy] section To read more about `QuerysetProxy` visit [querysetproxy][querysetproxy] section
### select_all
Works exactly the same as [select_all](./#select_all) function above but allows you to fetch related
objects from other side of the relation.
!!!tip
To read more about `QuerysetProxy` visit [querysetproxy][querysetproxy] section
### prefetch_related ### prefetch_related
Works exactly the same as [prefetch_related](./#prefetch_related) function above but allows you to fetch related Works exactly the same as [prefetch_related](./#prefetch_related) function above but allows you to fetch related

View File

@ -1,3 +1,26 @@
# 0.10.0
## Breaking
* Dropped supported for long deprecated notation of field definition in which you use ormar fields as type hints i.e. `test_field: ormar.Integger() = None`
* Improved type hints -> `mypy` can properly resolve related models fields (`ForeignKey` and `ManyToMany`) as well as return types of `QuerySet` methods.
Those mentioned are now returning proper model (i.e. `Book`) instead or `ormar.Model` type. There is still problem with reverse sides of relation and `QuerysetProxy` methods,
to ease type hints now those return `Any`. Partially fixes #112.
## Features
* add `select_all(follow: bool = False)` method to `QuerySet` and `QuerysetProxy`.
It is kind of equivalent of the Model's `load_all()` method but can be used directly in a query.
By default `select_all()` adds only directly related models, with `follow=True` also related models
of related models are added without loops in relations. Note that it's not and end `async` model
so you still have to issue `get()`, `all()` etc. as `select_all()` returns a QuerySet (or proxy)
like `fields()` or `order_by()`.
## Internals
* `ormar` fields are no longer stored as classes in `Meta.model_fields` dictionary
but instead they are stored as instances.
# 0.9.9 # 0.9.9
## Features ## Features

View File

@ -75,7 +75,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType() Undefined = UndefinedType()
__version__ = "0.9.9" __version__ = "0.10.0"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",

View File

@ -1,7 +1,7 @@
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, Union
import sqlalchemy import sqlalchemy
from pydantic import Field, Json, typing from pydantic import Json, typing
from pydantic.fields import FieldInfo, Required, Undefined from pydantic.fields import FieldInfo, Required, Undefined
import ormar # noqa I101 import ormar # noqa I101
@ -28,44 +28,62 @@ class BaseField(FieldInfo):
to pydantic field types like ConstrainedStr to pydantic field types like ConstrainedStr
""" """
__type__ = None def __init__(self, **kwargs: Any) -> None:
related_name = None self.__type__: type = kwargs.pop("__type__", None)
self.related_name = kwargs.pop("related_name", None)
column_type: sqlalchemy.Column self.column_type: sqlalchemy.Column = kwargs.pop("column_type", None)
constraints: List = [] self.constraints: List = kwargs.pop("constraints", list())
name: str self.name: str = kwargs.pop("name", None)
alias: str self.db_alias: str = kwargs.pop("alias", None)
primary_key: bool self.primary_key: bool = kwargs.pop("primary_key", False)
autoincrement: bool self.autoincrement: bool = kwargs.pop("autoincrement", False)
nullable: bool self.nullable: bool = kwargs.pop("nullable", False)
index: bool self.index: bool = kwargs.pop("index", False)
unique: bool self.unique: bool = kwargs.pop("unique", False)
pydantic_only: bool self.pydantic_only: bool = kwargs.pop("pydantic_only", False)
choices: typing.Sequence self.choices: typing.Sequence = kwargs.pop("choices", False)
virtual: bool = False # ManyToManyFields and reverse ForeignKeyFields self.virtual: bool = kwargs.pop(
is_multi: bool = False # ManyToManyField "virtual", None
is_relation: bool = False # ForeignKeyField + subclasses ) # ManyToManyFields and reverse ForeignKeyFields
is_through: bool = False # ThroughFields self.is_multi: bool = kwargs.pop("is_multi", None) # ManyToManyField
self.is_relation: bool = kwargs.pop(
"is_relation", None
) # ForeignKeyField + subclasses
self.is_through: bool = kwargs.pop("is_through", False) # ThroughFields
owner: Type["Model"] self.owner: Type["Model"] = kwargs.pop("owner", None)
to: Type["Model"] self.to: Type["Model"] = kwargs.pop("to", None)
through: Type["Model"] self.through: Type["Model"] = kwargs.pop("through", None)
self_reference: bool = False self.self_reference: bool = kwargs.pop("self_reference", False)
self_reference_primary: Optional[str] = None self.self_reference_primary: Optional[str] = kwargs.pop(
orders_by: Optional[List[str]] = None "self_reference_primary", None
related_orders_by: Optional[List[str]] = None )
self.orders_by: Optional[List[str]] = kwargs.pop("orders_by", None)
self.related_orders_by: Optional[List[str]] = kwargs.pop(
"related_orders_by", None
)
encrypt_secret: str self.encrypt_secret: str = kwargs.pop("encrypt_secret", None)
encrypt_backend: EncryptBackends = EncryptBackends.NONE self.encrypt_backend: EncryptBackends = kwargs.pop(
encrypt_custom_backend: Optional[Type[EncryptBackend]] = None "encrypt_backend", EncryptBackends.NONE
)
self.encrypt_custom_backend: Optional[Type[EncryptBackend]] = kwargs.pop(
"encrypt_custom_backend", None
)
default: Any self.ormar_default: Any = kwargs.pop("default", None)
server_default: Any self.server_default: Any = kwargs.pop("server_default", None)
@classmethod for name, value in kwargs.items():
def is_valid_uni_relation(cls) -> bool: setattr(self, name, value)
kwargs.update(self.get_pydantic_default())
super().__init__(**kwargs)
def is_valid_uni_relation(self) -> bool:
""" """
Checks if field is a relation definition but only for ForeignKey relation, Checks if field is a relation definition but only for ForeignKey relation,
so excludes ManyToMany fields, as well as virtual ForeignKey so excludes ManyToMany fields, as well as virtual ForeignKey
@ -78,10 +96,9 @@ class BaseField(FieldInfo):
:return: result of the check :return: result of the check
:rtype: bool :rtype: bool
""" """
return not cls.is_multi and not cls.virtual return not self.is_multi and not self.virtual
@classmethod def get_alias(self) -> str:
def get_alias(cls) -> str:
""" """
Used to translate Model column names to database column names during db queries. Used to translate Model column names to database column names during db queries.
@ -89,73 +106,25 @@ class BaseField(FieldInfo):
otherwise field name in ormar/pydantic otherwise field name in ormar/pydantic
:rtype: str :rtype: str
""" """
return cls.alias if cls.alias else cls.name return self.db_alias if self.db_alias else self.name
@classmethod def get_pydantic_default(self) -> Dict:
def is_valid_field_info_field(cls, field_name: str) -> bool:
"""
Checks if field belongs to pydantic FieldInfo
- used during setting default pydantic values.
Excludes defaults and alias as they are populated separately
(defaults) or not at all (alias)
:param field_name: field name of BaseFIeld
:type field_name: str
:return: True if field is present on pydantic.FieldInfo
:rtype: bool
"""
return (
field_name not in ["default", "default_factory", "alias", "allow_mutation"]
and not field_name.startswith("__")
and hasattr(cls, field_name)
and not callable(getattr(cls, field_name))
)
@classmethod
def get_base_pydantic_field_info(cls, allow_null: bool) -> FieldInfo:
""" """
Generates base pydantic.FieldInfo with only default and optionally Generates base pydantic.FieldInfo with only default and optionally
required to fix pydantic Json field being set to required=False. required to fix pydantic Json field being set to required=False.
Used in an ormar Model Metaclass. Used in an ormar Model Metaclass.
:param allow_null: flag if the default value can be None
or if it should be populated by pydantic Undefined
:type allow_null: bool
:return: instance of base pydantic.FieldInfo :return: instance of base pydantic.FieldInfo
:rtype: pydantic.FieldInfo :rtype: pydantic.FieldInfo
""" """
base = cls.default_value() base = self.default_value()
if base is None: if base is None:
base = ( base = dict(default=None) if self.nullable else dict(default=Undefined)
FieldInfo(default=None) if self.__type__ == Json and base.get("default") is Undefined:
if (cls.nullable or allow_null) base["default"] = Required
else FieldInfo(default=Undefined)
)
if cls.__type__ == Json and base.default is Undefined:
base.default = Required
return base return base
@classmethod def default_value(self, use_server: bool = False) -> Optional[Dict]:
def convert_to_pydantic_field_info(cls, allow_null: bool = False) -> FieldInfo:
"""
Converts a BaseField into pydantic.FieldInfo
that is later easily processed by pydantic.
Used in an ormar Model Metaclass.
:param allow_null: flag if the default value can be None
or if it should be populated by pydantic Undefined
:type allow_null: bool
:return: actual instance of pydantic.FieldInfo with all needed fields populated
:rtype: pydantic.FieldInfo
"""
base = cls.get_base_pydantic_field_info(allow_null=allow_null)
for attr_name in FieldInfo.__dict__.keys():
if cls.is_valid_field_info_field(attr_name):
setattr(base, attr_name, cls.__dict__.get(attr_name))
return base
@classmethod
def default_value(cls, use_server: bool = False) -> Optional[FieldInfo]:
""" """
Returns a FieldInfo instance with populated default Returns a FieldInfo instance with populated default
(static) or default_factory (function). (static) or default_factory (function).
@ -173,17 +142,20 @@ class BaseField(FieldInfo):
which is returning a FieldInfo instance which is returning a FieldInfo instance
:rtype: Optional[pydantic.FieldInfo] :rtype: Optional[pydantic.FieldInfo]
""" """
if cls.is_auto_primary_key(): if self.is_auto_primary_key():
return Field(default=None) return dict(default=None)
if cls.has_default(use_server=use_server): if self.has_default(use_server=use_server):
default = cls.default if cls.default is not None else cls.server_default default = (
self.ormar_default
if self.ormar_default is not None
else self.server_default
)
if callable(default): if callable(default):
return Field(default_factory=default) return dict(default_factory=default)
return Field(default=default) return dict(default=default)
return None return None
@classmethod def get_default(self, use_server: bool = False) -> Any: # noqa CCR001
def get_default(cls, use_server: bool = False) -> Any: # noqa CCR001
""" """
Return default value for a field. Return default value for a field.
If the field is Callable the function is called and actual result is returned. If the field is Callable the function is called and actual result is returned.
@ -195,18 +167,17 @@ class BaseField(FieldInfo):
:return: default value for the field if set, otherwise implicit None :return: default value for the field if set, otherwise implicit None
:rtype: Any :rtype: Any
""" """
if cls.has_default(): if self.has_default():
default = ( default = (
cls.default self.ormar_default
if cls.default is not None if self.ormar_default is not None
else (cls.server_default if use_server else None) else (self.server_default if use_server else None)
) )
if callable(default): if callable(default):
default = default() default = default()
return default return default
@classmethod def has_default(self, use_server: bool = True) -> bool:
def has_default(cls, use_server: bool = True) -> bool:
""" """
Checks if the field has default value set. Checks if the field has default value set.
@ -216,12 +187,11 @@ class BaseField(FieldInfo):
:return: result of the check if default value is set :return: result of the check if default value is set
:rtype: bool :rtype: bool
""" """
return cls.default is not None or ( return self.ormar_default is not None or (
cls.server_default is not None and use_server self.server_default is not None and use_server
) )
@classmethod def is_auto_primary_key(self) -> bool:
def is_auto_primary_key(cls) -> bool:
""" """
Checks if field is first a primary key and if it, Checks if field is first a primary key and if it,
it's than check if it's set to autoincrement. it's than check if it's set to autoincrement.
@ -230,12 +200,11 @@ class BaseField(FieldInfo):
:return: result of the check for primary key and autoincrement :return: result of the check for primary key and autoincrement
:rtype: bool :rtype: bool
""" """
if cls.primary_key: if self.primary_key:
return cls.autoincrement return self.autoincrement
return False return False
@classmethod def construct_constraints(self) -> List:
def construct_constraints(cls) -> List:
""" """
Converts list of ormar constraints into sqlalchemy ForeignKeys. Converts list of ormar constraints into sqlalchemy ForeignKeys.
Has to be done dynamically as sqlalchemy binds ForeignKey to the table. Has to be done dynamically as sqlalchemy binds ForeignKey to the table.
@ -249,15 +218,14 @@ class BaseField(FieldInfo):
con.reference, con.reference,
ondelete=con.ondelete, ondelete=con.ondelete,
onupdate=con.onupdate, onupdate=con.onupdate,
name=f"fk_{cls.owner.Meta.tablename}_{cls.to.Meta.tablename}" name=f"fk_{self.owner.Meta.tablename}_{self.to.Meta.tablename}"
f"_{cls.to.get_column_alias(cls.to.Meta.pkname)}_{cls.name}", f"_{self.to.get_column_alias(self.to.Meta.pkname)}_{self.name}",
) )
for con in cls.constraints for con in self.constraints
] ]
return constraints return constraints
@classmethod def get_column(self, name: str) -> sqlalchemy.Column:
def get_column(cls, name: str) -> sqlalchemy.Column:
""" """
Returns definition of sqlalchemy.Column used in creation of sqlalchemy.Table. Returns definition of sqlalchemy.Column used in creation of sqlalchemy.Table.
Populates name, column type constraints, as well as a number of parameters like Populates name, column type constraints, as well as a number of parameters like
@ -268,24 +236,23 @@ class BaseField(FieldInfo):
:return: actual definition of the database column as sqlalchemy requires. :return: actual definition of the database column as sqlalchemy requires.
:rtype: sqlalchemy.Column :rtype: sqlalchemy.Column
""" """
if cls.encrypt_backend == EncryptBackends.NONE: if self.encrypt_backend == EncryptBackends.NONE:
column = sqlalchemy.Column( column = sqlalchemy.Column(
cls.alias or name, self.db_alias or name,
cls.column_type, self.column_type,
*cls.construct_constraints(), *self.construct_constraints(),
primary_key=cls.primary_key, primary_key=self.primary_key,
nullable=cls.nullable and not cls.primary_key, nullable=self.nullable and not self.primary_key,
index=cls.index, index=self.index,
unique=cls.unique, unique=self.unique,
default=cls.default, default=self.ormar_default,
server_default=cls.server_default, server_default=self.server_default,
) )
else: else:
column = cls._get_encrypted_column(name=name) column = self._get_encrypted_column(name=name)
return column return column
@classmethod def _get_encrypted_column(self, name: str) -> sqlalchemy.Column:
def _get_encrypted_column(cls, name: str) -> sqlalchemy.Column:
""" """
Returns EncryptedString column type instead of actual column. Returns EncryptedString column type instead of actual column.
@ -294,29 +261,28 @@ class BaseField(FieldInfo):
:return: newly defined column :return: newly defined column
:rtype: sqlalchemy.Column :rtype: sqlalchemy.Column
""" """
if cls.primary_key or cls.is_relation: if self.primary_key or self.is_relation:
raise ModelDefinitionError( raise ModelDefinitionError(
"Primary key field and relations fields" "cannot be encrypted!" "Primary key field and relations fields" "cannot be encrypted!"
) )
column = sqlalchemy.Column( column = sqlalchemy.Column(
cls.alias or name, self.db_alias or name,
EncryptedString( EncryptedString(
_field_type=cls, _field_type=self,
encrypt_secret=cls.encrypt_secret, encrypt_secret=self.encrypt_secret,
encrypt_backend=cls.encrypt_backend, encrypt_backend=self.encrypt_backend,
encrypt_custom_backend=cls.encrypt_custom_backend, encrypt_custom_backend=self.encrypt_custom_backend,
), ),
nullable=cls.nullable, nullable=self.nullable,
index=cls.index, index=self.index,
unique=cls.unique, unique=self.unique,
default=cls.default, default=self.ormar_default,
server_default=cls.server_default, server_default=self.server_default,
) )
return column return column
@classmethod
def expand_relationship( def expand_relationship(
cls, self,
value: Any, value: Any,
child: Union["Model", "NewBaseModel"], child: Union["Model", "NewBaseModel"],
to_register: bool = True, to_register: bool = True,
@ -339,21 +305,19 @@ class BaseField(FieldInfo):
""" """
return value return value
@classmethod def set_self_reference_flag(self) -> None:
def set_self_reference_flag(cls) -> None:
""" """
Sets `self_reference` to True if field to and owner are same model. Sets `self_reference` to True if field to and owner are same model.
:return: None :return: None
:rtype: None :rtype: None
""" """
if cls.owner is not None and ( if self.owner is not None and (
cls.owner == cls.to or cls.owner.Meta == cls.to.Meta self.owner == self.to or self.owner.Meta == self.to.Meta
): ):
cls.self_reference = True self.self_reference = True
cls.self_reference_primary = cls.name self.self_reference_primary = self.name
@classmethod def has_unresolved_forward_refs(self) -> bool:
def has_unresolved_forward_refs(cls) -> bool:
""" """
Verifies if the filed has any ForwardRefs that require updating before the Verifies if the filed has any ForwardRefs that require updating before the
model can be used. model can be used.
@ -363,8 +327,7 @@ class BaseField(FieldInfo):
""" """
return False return False
@classmethod def evaluate_forward_ref(self, globalns: Any, localns: Any) -> None:
def evaluate_forward_ref(cls, globalns: Any, localns: Any) -> None:
""" """
Evaluates the ForwardRef to actual Field based on global and local namespaces Evaluates the ForwardRef to actual Field based on global and local namespaces
@ -376,8 +339,7 @@ class BaseField(FieldInfo):
:rtype: None :rtype: None
""" """
@classmethod def get_related_name(self) -> str:
def get_related_name(cls) -> str:
""" """
Returns name to use for reverse relation. Returns name to use for reverse relation.
It's either set as `related_name` or by default it's owner model. get_name + 's' It's either set as `related_name` or by default it's owner model. get_name + 's'

View File

@ -3,7 +3,17 @@ import sys
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from random import choices from random import choices
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union from typing import (
Any,
Dict,
List,
Optional,
TYPE_CHECKING,
Tuple,
Type,
Union,
overload,
)
import sqlalchemy import sqlalchemy
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
@ -15,16 +25,16 @@ from ormar.exceptions import ModelDefinitionError, RelationshipInstanceError
from ormar.fields.base import BaseField from ormar.fields.base import BaseField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models import Model, NewBaseModel from ormar.models import Model, NewBaseModel, T
from ormar.fields import ManyToManyField from ormar.fields import ManyToManyField
if sys.version_info < (3, 7): if sys.version_info < (3, 7):
ToType = Type["Model"] ToType = Type["T"]
else: else:
ToType = Union[Type["Model"], "ForwardRef"] ToType = Union[Type["T"], "ForwardRef"]
def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": def create_dummy_instance(fk: Type["T"], pk: Any = None) -> "T":
""" """
Ormar never returns you a raw data. Ormar never returns you a raw data.
So if you have a related field that has a value populated So if you have a related field that has a value populated
@ -55,8 +65,8 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
def create_dummy_model( def create_dummy_model(
base_model: Type["Model"], base_model: Type["T"],
pk_field: Type[Union[BaseField, "ForeignKeyField", "ManyToManyField"]], pk_field: Union[BaseField, "ForeignKeyField", "ManyToManyField"],
) -> Type["BaseModel"]: ) -> Type["BaseModel"]:
""" """
Used to construct a dummy pydantic model for type hints and pydantic validation. Used to construct a dummy pydantic model for type hints and pydantic validation.
@ -65,7 +75,7 @@ def create_dummy_model(
:param base_model: class of target dummy model :param base_model: class of target dummy model
:type base_model: Model class :type base_model: Model class
:param pk_field: ormar Field to be set on pydantic Model :param pk_field: ormar Field to be set on pydantic Model
:type pk_field: Type[Union[BaseField, "ForeignKeyField", "ManyToManyField"]] :type pk_field: Union[BaseField, "ForeignKeyField", "ManyToManyField"]
:return: constructed dummy model :return: constructed dummy model
:rtype: pydantic.BaseModel :rtype: pydantic.BaseModel
""" """
@ -83,7 +93,7 @@ def create_dummy_model(
def populate_fk_params_based_on_to_model( def populate_fk_params_based_on_to_model(
to: Type["Model"], nullable: bool, onupdate: str = None, ondelete: str = None, to: Type["T"], nullable: bool, onupdate: str = None, ondelete: str = None,
) -> Tuple[Any, List, Any]: ) -> Tuple[Any, List, Any]:
""" """
Based on target to model to which relation leads to populates the type of the Based on target to model to which relation leads to populates the type of the
@ -168,6 +178,16 @@ class ForeignKeyConstraint:
onupdate: Optional[str] onupdate: Optional[str]
@overload
def ForeignKey(to: Type["T"], **kwargs: Any) -> "T": # pragma: no cover
...
@overload
def ForeignKey(to: ForwardRef, **kwargs: Any) -> "Model": # pragma: no cover
...
def ForeignKey( # noqa CFQ002 def ForeignKey( # noqa CFQ002
to: "ToType", to: "ToType",
*, *,
@ -179,7 +199,7 @@ def ForeignKey( # noqa CFQ002
onupdate: str = None, onupdate: str = None,
ondelete: str = None, ondelete: str = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> "T":
""" """
Despite a name it's a function that returns constructed ForeignKeyField. Despite a name it's a function that returns constructed ForeignKeyField.
This function is actually used in model declaration (as ormar.ForeignKey(ToModel)). This function is actually used in model declaration (as ormar.ForeignKey(ToModel)).
@ -256,7 +276,8 @@ def ForeignKey( # noqa CFQ002
related_orders_by=related_orders_by, related_orders_by=related_orders_by,
) )
return type("ForeignKey", (ForeignKeyField, BaseField), namespace) Field = type("ForeignKey", (ForeignKeyField, BaseField), {})
return Field(**namespace)
class ForeignKeyField(BaseField): class ForeignKeyField(BaseField):
@ -264,15 +285,15 @@ class ForeignKeyField(BaseField):
Actual class returned from ForeignKey function call and stored in model_fields. Actual class returned from ForeignKey function call and stored in model_fields.
""" """
to: Type["Model"] def __init__(self, **kwargs: Any) -> None:
name: str if TYPE_CHECKING: # pragma: no cover
related_name: str # type: ignore self.__type__: type
virtual: bool self.to: Type["Model"]
ondelete: str self.ondelete: str = kwargs.pop("ondelete", None)
onupdate: str self.onupdate: str = kwargs.pop("onupdate", None)
super().__init__(**kwargs)
@classmethod def get_source_related_name(self) -> str:
def get_source_related_name(cls) -> str:
""" """
Returns name to use for source relation name. Returns name to use for source relation name.
For FK it's the same, differs for m2m fields. For FK it's the same, differs for m2m fields.
@ -280,20 +301,18 @@ class ForeignKeyField(BaseField):
:return: name of the related_name or default related name. :return: name of the related_name or default related name.
:rtype: str :rtype: str
""" """
return cls.get_related_name() return self.get_related_name()
@classmethod def get_related_name(self) -> str:
def get_related_name(cls) -> str:
""" """
Returns name to use for reverse relation. Returns name to use for reverse relation.
It's either set as `related_name` or by default it's owner model. get_name + 's' It's either set as `related_name` or by default it's owner model. get_name + 's'
:return: name of the related_name or default related name. :return: name of the related_name or default related name.
:rtype: str :rtype: str
""" """
return cls.related_name or cls.owner.get_name() + "s" return self.related_name or self.owner.get_name() + "s"
@classmethod def evaluate_forward_ref(self, globalns: Any, localns: Any) -> None:
def evaluate_forward_ref(cls, globalns: Any, localns: Any) -> None:
""" """
Evaluates the ForwardRef to actual Field based on global and local namespaces Evaluates the ForwardRef to actual Field based on global and local namespaces
@ -304,26 +323,25 @@ class ForeignKeyField(BaseField):
:return: None :return: None
:rtype: None :rtype: None
""" """
if cls.to.__class__ == ForwardRef: if self.to.__class__ == ForwardRef:
cls.to = evaluate_forwardref( self.to = evaluate_forwardref(
cls.to, # type: ignore self.to, # type: ignore
globalns, globalns,
localns or None, localns or None,
) )
( (
cls.__type__, self.__type__,
cls.constraints, self.constraints,
cls.column_type, self.column_type,
) = populate_fk_params_based_on_to_model( ) = populate_fk_params_based_on_to_model(
to=cls.to, to=self.to,
nullable=cls.nullable, nullable=self.nullable,
ondelete=cls.ondelete, ondelete=self.ondelete,
onupdate=cls.onupdate, onupdate=self.onupdate,
) )
@classmethod
def _extract_model_from_sequence( def _extract_model_from_sequence(
cls, value: List, child: "Model", to_register: bool, self, value: List, child: "Model", to_register: bool,
) -> List["Model"]: ) -> List["Model"]:
""" """
Takes a list of Models and registers them on parent. Takes a list of Models and registers them on parent.
@ -341,15 +359,14 @@ class ForeignKeyField(BaseField):
:rtype: List["Model"] :rtype: List["Model"]
""" """
return [ return [
cls.expand_relationship( # type: ignore self.expand_relationship( # type: ignore
value=val, child=child, to_register=to_register, value=val, child=child, to_register=to_register,
) )
for val in value for val in value
] ]
@classmethod
def _register_existing_model( def _register_existing_model(
cls, value: "Model", child: "Model", to_register: bool, self, value: "Model", child: "Model", to_register: bool,
) -> "Model": ) -> "Model":
""" """
Takes already created instance and registers it for parent. Takes already created instance and registers it for parent.
@ -367,12 +384,11 @@ class ForeignKeyField(BaseField):
:rtype: Model :rtype: Model
""" """
if to_register: if to_register:
cls.register_relation(model=value, child=child) self.register_relation(model=value, child=child)
return value return value
@classmethod
def _construct_model_from_dict( def _construct_model_from_dict(
cls, value: dict, child: "Model", to_register: bool self, value: dict, child: "Model", to_register: bool
) -> "Model": ) -> "Model":
""" """
Takes a dictionary, creates a instance and registers it for parent. Takes a dictionary, creates a instance and registers it for parent.
@ -390,16 +406,15 @@ class ForeignKeyField(BaseField):
:return: (if needed) registered Model :return: (if needed) registered Model
:rtype: Model :rtype: Model
""" """
if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname: if len(value.keys()) == 1 and list(value.keys())[0] == self.to.Meta.pkname:
value["__pk_only__"] = True value["__pk_only__"] = True
model = cls.to(**value) model = self.to(**value)
if to_register: if to_register:
cls.register_relation(model=model, child=child) self.register_relation(model=model, child=child)
return model return model
@classmethod
def _construct_model_from_pk( def _construct_model_from_pk(
cls, value: Any, child: "Model", to_register: bool self, value: Any, child: "Model", to_register: bool
) -> "Model": ) -> "Model":
""" """
Takes a pk value, creates a dummy instance and registers it for parent. Takes a pk value, creates a dummy instance and registers it for parent.
@ -416,21 +431,20 @@ class ForeignKeyField(BaseField):
:return: (if needed) registered Model :return: (if needed) registered Model
:rtype: Model :rtype: Model
""" """
if cls.to.pk_type() == uuid.UUID and isinstance(value, str): # pragma: nocover if self.to.pk_type() == uuid.UUID and isinstance(value, str): # pragma: nocover
value = uuid.UUID(value) value = uuid.UUID(value)
if not isinstance(value, cls.to.pk_type()): if not isinstance(value, self.to.pk_type()):
raise RelationshipInstanceError( raise RelationshipInstanceError(
f"Relationship error - ForeignKey {cls.to.__name__} " f"Relationship error - ForeignKey {self.to.__name__} "
f"is of type {cls.to.pk_type()} " f"is of type {self.to.pk_type()} "
f"while {type(value)} passed as a parameter." f"while {type(value)} passed as a parameter."
) )
model = create_dummy_instance(fk=cls.to, pk=value) model = create_dummy_instance(fk=self.to, pk=value)
if to_register: if to_register:
cls.register_relation(model=model, child=child) self.register_relation(model=model, child=child)
return model return model
@classmethod def register_relation(self, model: "Model", child: "Model") -> None:
def register_relation(cls, model: "Model", child: "Model") -> None:
""" """
Registers relation between parent and child in relation manager. Registers relation between parent and child in relation manager.
Relation manager is kep on each model (different instance). Relation manager is kep on each model (different instance).
@ -444,11 +458,10 @@ class ForeignKeyField(BaseField):
:type child: Model class :type child: Model class
""" """
model._orm.add( model._orm.add(
parent=model, child=child, field=cls, parent=model, child=child, field=self,
) )
@classmethod def has_unresolved_forward_refs(self) -> bool:
def has_unresolved_forward_refs(cls) -> bool:
""" """
Verifies if the filed has any ForwardRefs that require updating before the Verifies if the filed has any ForwardRefs that require updating before the
model can be used. model can be used.
@ -456,11 +469,10 @@ class ForeignKeyField(BaseField):
:return: result of the check :return: result of the check
:rtype: bool :rtype: bool
""" """
return cls.to.__class__ == ForwardRef return self.to.__class__ == ForwardRef
@classmethod
def expand_relationship( def expand_relationship(
cls, self,
value: Any, value: Any,
child: Union["Model", "NewBaseModel"], child: Union["Model", "NewBaseModel"],
to_register: bool = True, to_register: bool = True,
@ -483,20 +495,19 @@ class ForeignKeyField(BaseField):
:rtype: Optional[Union["Model", List["Model"]]] :rtype: Optional[Union["Model", List["Model"]]]
""" """
if value is None: if value is None:
return None if not cls.virtual else [] return None if not self.virtual else []
constructors = { constructors = {
f"{cls.to.__name__}": cls._register_existing_model, f"{self.to.__name__}": self._register_existing_model,
"dict": cls._construct_model_from_dict, "dict": self._construct_model_from_dict,
"list": cls._extract_model_from_sequence, "list": self._extract_model_from_sequence,
} }
model = constructors.get( # type: ignore model = constructors.get( # type: ignore
value.__class__.__name__, cls._construct_model_from_pk value.__class__.__name__, self._construct_model_from_pk
)(value, child, to_register) )(value, child, to_register)
return model return model
@classmethod def get_relation_name(self) -> str: # pragma: no cover
def get_relation_name(cls) -> str: # pragma: no cover
""" """
Returns name of the relation, which can be a own name or through model Returns name of the relation, which can be a own name or through model
names for m2m models names for m2m models
@ -504,14 +515,13 @@ class ForeignKeyField(BaseField):
:return: result of the check :return: result of the check
:rtype: bool :rtype: bool
""" """
return cls.name return self.name
@classmethod def get_source_model(self) -> Type["Model"]: # pragma: no cover
def get_source_model(cls) -> Type["Model"]: # pragma: no cover
""" """
Returns model from which the relation comes -> either owner or through model Returns model from which the relation comes -> either owner or through model
:return: source model :return: source model
:rtype: Type["Model"] :rtype: Type["Model"]
""" """
return cls.owner return self.owner

View File

@ -1,5 +1,15 @@
import sys import sys
from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, Union, cast from typing import (
Any,
List,
Optional,
TYPE_CHECKING,
Tuple,
Type,
Union,
cast,
overload,
)
from pydantic.typing import ForwardRef, evaluate_forwardref from pydantic.typing import ForwardRef, evaluate_forwardref
import ormar # noqa: I100 import ormar # noqa: I100
@ -8,12 +18,13 @@ from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField, validate_not_allowed_fields from ormar.fields.foreign_key import ForeignKeyField, validate_not_allowed_fields
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models import Model from ormar.models import Model, T
from ormar.relations.relation_proxy import RelationProxy
if sys.version_info < (3, 7): if sys.version_info < (3, 7):
ToType = Type["Model"] ToType = Type["T"]
else: else:
ToType = Union[Type["Model"], "ForwardRef"] ToType = Union[Type["T"], "ForwardRef"]
REF_PREFIX = "#/components/schemas/" REF_PREFIX = "#/components/schemas/"
@ -57,6 +68,16 @@ def populate_m2m_params_based_on_to_model(
return __type__, column_type return __type__, column_type
@overload
def ManyToMany(to: Type["T"], **kwargs: Any) -> "RelationProxy[T]": # pragma: no cover
...
@overload
def ManyToMany(to: ForwardRef, **kwargs: Any) -> "RelationProxy": # pragma: no cover
...
def ManyToMany( def ManyToMany(
to: "ToType", to: "ToType",
through: Optional["ToType"] = None, through: Optional["ToType"] = None,
@ -65,7 +86,7 @@ def ManyToMany(
unique: bool = False, unique: bool = False,
virtual: bool = False, virtual: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> "RelationProxy[T]":
""" """
Despite a name it's a function that returns constructed ManyToManyField. Despite a name it's a function that returns constructed ManyToManyField.
This function is actually used in model declaration This function is actually used in model declaration
@ -132,7 +153,8 @@ def ManyToMany(
related_orders_by=related_orders_by, related_orders_by=related_orders_by,
) )
return type("ManyToMany", (ManyToManyField, BaseField), namespace) Field = type("ManyToMany", (ManyToManyField, BaseField), {})
return Field(**namespace)
class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationProtocol): class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationProtocol):
@ -140,8 +162,14 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro
Actual class returned from ManyToMany function call and stored in model_fields. Actual class returned from ManyToMany function call and stored in model_fields.
""" """
@classmethod def __init__(self, **kwargs: Any) -> None:
def get_source_related_name(cls) -> str: if TYPE_CHECKING: # pragma: no cover
self.__type__: type
self.to: Type["Model"]
self.through: Type["Model"]
super().__init__(**kwargs)
def get_source_related_name(self) -> str:
""" """
Returns name to use for source relation name. Returns name to use for source relation name.
For FK it's the same, differs for m2m fields. For FK it's the same, differs for m2m fields.
@ -150,32 +178,31 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro
:rtype: str :rtype: str
""" """
return ( return (
cls.through.Meta.model_fields[cls.default_source_field_name()].related_name self.through.Meta.model_fields[
or cls.name self.default_source_field_name()
].related_name
or self.name
) )
@classmethod def default_target_field_name(self) -> str:
def default_target_field_name(cls) -> str:
""" """
Returns default target model name on through model. Returns default target model name on through model.
:return: name of the field :return: name of the field
:rtype: str :rtype: str
""" """
prefix = "from_" if cls.self_reference else "" prefix = "from_" if self.self_reference else ""
return f"{prefix}{cls.to.get_name()}" return f"{prefix}{self.to.get_name()}"
@classmethod def default_source_field_name(self) -> str:
def default_source_field_name(cls) -> str:
""" """
Returns default target model name on through model. Returns default target model name on through model.
:return: name of the field :return: name of the field
:rtype: str :rtype: str
""" """
prefix = "to_" if cls.self_reference else "" prefix = "to_" if self.self_reference else ""
return f"{prefix}{cls.owner.get_name()}" return f"{prefix}{self.owner.get_name()}"
@classmethod def has_unresolved_forward_refs(self) -> bool:
def has_unresolved_forward_refs(cls) -> bool:
""" """
Verifies if the filed has any ForwardRefs that require updating before the Verifies if the filed has any ForwardRefs that require updating before the
model can be used. model can be used.
@ -183,10 +210,9 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro
:return: result of the check :return: result of the check
:rtype: bool :rtype: bool
""" """
return cls.to.__class__ == ForwardRef or cls.through.__class__ == ForwardRef return self.to.__class__ == ForwardRef or self.through.__class__ == ForwardRef
@classmethod def evaluate_forward_ref(self, globalns: Any, localns: Any) -> None:
def evaluate_forward_ref(cls, globalns: Any, localns: Any) -> None:
""" """
Evaluates the ForwardRef to actual Field based on global and local namespaces Evaluates the ForwardRef to actual Field based on global and local namespaces
@ -197,27 +223,26 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro
:return: None :return: None
:rtype: None :rtype: None
""" """
if cls.to.__class__ == ForwardRef: if self.to.__class__ == ForwardRef:
cls.to = evaluate_forwardref( self.to = evaluate_forwardref(
cls.to, # type: ignore self.to, # type: ignore
globalns, globalns,
localns or None, localns or None,
) )
(cls.__type__, cls.column_type,) = populate_m2m_params_based_on_to_model( (self.__type__, self.column_type,) = populate_m2m_params_based_on_to_model(
to=cls.to, nullable=cls.nullable, to=self.to, nullable=self.nullable,
) )
if cls.through.__class__ == ForwardRef: if self.through.__class__ == ForwardRef:
cls.through = evaluate_forwardref( self.through = evaluate_forwardref(
cls.through, # type: ignore self.through, # type: ignore
globalns, globalns,
localns or None, localns or None,
) )
forbid_through_relations(cls.through) forbid_through_relations(self.through)
@classmethod def get_relation_name(self) -> str:
def get_relation_name(cls) -> str:
""" """
Returns name of the relation, which can be a own name or through model Returns name of the relation, which can be a own name or through model
names for m2m models names for m2m models
@ -225,34 +250,32 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro
:return: result of the check :return: result of the check
:rtype: bool :rtype: bool
""" """
if cls.self_reference and cls.name == cls.self_reference_primary: if self.self_reference and self.name == self.self_reference_primary:
return cls.default_source_field_name() return self.default_source_field_name()
return cls.default_target_field_name() return self.default_target_field_name()
@classmethod def get_source_model(self) -> Type["Model"]:
def get_source_model(cls) -> Type["Model"]:
""" """
Returns model from which the relation comes -> either owner or through model Returns model from which the relation comes -> either owner or through model
:return: source model :return: source model
:rtype: Type["Model"] :rtype: Type["Model"]
""" """
return cls.through return self.through
@classmethod def create_default_through_model(self) -> None:
def create_default_through_model(cls) -> None:
""" """
Creates default empty through model if no additional fields are required. Creates default empty through model if no additional fields are required.
""" """
owner_name = cls.owner.get_name(lower=False) owner_name = self.owner.get_name(lower=False)
to_name = cls.to.get_name(lower=False) to_name = self.to.get_name(lower=False)
class_name = f"{owner_name}{to_name}" class_name = f"{owner_name}{to_name}"
table_name = f"{owner_name.lower()}s_{to_name.lower()}s" table_name = f"{owner_name.lower()}s_{to_name.lower()}s"
new_meta_namespace = { new_meta_namespace = {
"tablename": table_name, "tablename": table_name,
"database": cls.owner.Meta.database, "database": self.owner.Meta.database,
"metadata": cls.owner.Meta.metadata, "metadata": self.owner.Meta.metadata,
} }
new_meta = type("Meta", (), new_meta_namespace) new_meta = type("Meta", (), new_meta_namespace)
through_model = type(class_name, (ormar.Model,), {"Meta": new_meta}) through_model = type(class_name, (ormar.Model,), {"Meta": new_meta})
cls.through = cast(Type["Model"], through_model) self.through = cast(Type["Model"], through_model)

View File

@ -1,7 +1,7 @@
import datetime import datetime
import decimal import decimal
import uuid import uuid
from typing import Any, Optional, TYPE_CHECKING, Type from typing import Any, Optional, TYPE_CHECKING
import pydantic import pydantic
import sqlalchemy import sqlalchemy
@ -63,7 +63,7 @@ class ModelFieldFactory:
_bases: Any = (BaseField,) _bases: Any = (BaseField,)
_type: Any = None _type: Any = None
def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: # type: ignore def __new__(cls, *args: Any, **kwargs: Any) -> BaseField: # type: ignore
cls.validate(**kwargs) cls.validate(**kwargs)
default = kwargs.pop("default", None) default = kwargs.pop("default", None)
@ -77,7 +77,6 @@ class ModelFieldFactory:
encrypt_secret = kwargs.pop("encrypt_secret", None) encrypt_secret = kwargs.pop("encrypt_secret", None)
encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE) encrypt_backend = kwargs.pop("encrypt_backend", EncryptBackends.NONE)
encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None) encrypt_custom_backend = kwargs.pop("encrypt_custom_backend", None)
encrypt_max_length = kwargs.pop("encrypt_max_length", 5000)
namespace = dict( namespace = dict(
__type__=cls._type, __type__=cls._type,
@ -97,10 +96,10 @@ class ModelFieldFactory:
encrypt_secret=encrypt_secret, encrypt_secret=encrypt_secret,
encrypt_backend=encrypt_backend, encrypt_backend=encrypt_backend,
encrypt_custom_backend=encrypt_custom_backend, encrypt_custom_backend=encrypt_custom_backend,
encrypt_max_length=encrypt_max_length,
**kwargs **kwargs
) )
return type(cls.__name__, cls._bases, namespace) Field = type(cls.__name__, cls._bases, {})
return Field(**namespace)
@classmethod @classmethod
def get_column_type(cls, **kwargs: Any) -> Any: # pragma no cover def get_column_type(cls, **kwargs: Any) -> Any: # pragma no cover
@ -141,7 +140,7 @@ class String(ModelFieldFactory, str):
curtail_length: int = None, curtail_length: int = None,
regex: str = None, regex: str = None,
**kwargs: Any **kwargs: Any
) -> Type[BaseField]: # type: ignore ) -> BaseField: # type: ignore
kwargs = { kwargs = {
**kwargs, **kwargs,
**{ **{
@ -194,7 +193,7 @@ class Integer(ModelFieldFactory, int):
maximum: int = None, maximum: int = None,
multiple_of: int = None, multiple_of: int = None,
**kwargs: Any **kwargs: Any
) -> Type[BaseField]: ) -> BaseField:
autoincrement = kwargs.pop("autoincrement", None) autoincrement = kwargs.pop("autoincrement", None)
autoincrement = ( autoincrement = (
autoincrement autoincrement
@ -236,7 +235,7 @@ class Text(ModelFieldFactory, str):
def __new__( # type: ignore def __new__( # type: ignore
cls, *, allow_blank: bool = True, strip_whitespace: bool = False, **kwargs: Any cls, *, allow_blank: bool = True, strip_whitespace: bool = False, **kwargs: Any
) -> Type[BaseField]: ) -> BaseField:
kwargs = { kwargs = {
**kwargs, **kwargs,
**{ **{
@ -276,7 +275,7 @@ class Float(ModelFieldFactory, float):
maximum: float = None, maximum: float = None,
multiple_of: int = None, multiple_of: int = None,
**kwargs: Any **kwargs: Any
) -> Type[BaseField]: ) -> BaseField:
kwargs = { kwargs = {
**kwargs, **kwargs,
**{ **{
@ -430,7 +429,7 @@ class BigInteger(Integer, int):
maximum: int = None, maximum: int = None,
multiple_of: int = None, multiple_of: int = None,
**kwargs: Any **kwargs: Any
) -> Type[BaseField]: ) -> BaseField:
autoincrement = kwargs.pop("autoincrement", None) autoincrement = kwargs.pop("autoincrement", None)
autoincrement = ( autoincrement = (
autoincrement autoincrement
@ -481,7 +480,7 @@ class Decimal(ModelFieldFactory, decimal.Decimal):
max_digits: int = None, max_digits: int = None,
decimal_places: int = None, decimal_places: int = None,
**kwargs: Any **kwargs: Any
) -> Type[BaseField]: ) -> BaseField:
kwargs = { kwargs = {
**kwargs, **kwargs,
**{ **{
@ -544,7 +543,7 @@ class UUID(ModelFieldFactory, uuid.UUID):
def __new__( # type: ignore # noqa CFQ002 def __new__( # type: ignore # noqa CFQ002
cls, *, uuid_format: str = "hex", **kwargs: Any cls, *, uuid_format: str = "hex", **kwargs: Any
) -> Type[BaseField]: ) -> BaseField:
kwargs = { kwargs = {
**kwargs, **kwargs,
**{ **{

View File

@ -133,7 +133,7 @@ class EncryptedString(types.TypeDecorator):
raise ModelDefinitionError("Wrong or no encrypt backend provided!") raise ModelDefinitionError("Wrong or no encrypt backend provided!")
self.backend: EncryptBackend = backend() self.backend: EncryptBackend = backend()
self._field_type: Type["BaseField"] = _field_type self._field_type: "BaseField" = _field_type
self._underlying_type: Any = _field_type.column_type self._underlying_type: Any = _field_type.column_type
self._key: Union[str, Callable] = encrypt_secret self._key: Union[str, Callable] = encrypt_secret
type_ = self._field_type.__type__ type_ = self._field_type.__type__

View File

@ -57,7 +57,8 @@ def Through( # noqa CFQ002
is_through=True, is_through=True,
) )
return type("Through", (ThroughField, BaseField), namespace) Field = type("Through", (ThroughField, BaseField), {})
return Field(**namespace)
class ThroughField(ForeignKeyField): class ThroughField(ForeignKeyField):

View File

@ -6,7 +6,7 @@ ass well as vast number of helper functions for pydantic, sqlalchemy and relatio
from ormar.models.newbasemodel import NewBaseModel # noqa I100 from ormar.models.newbasemodel import NewBaseModel # noqa I100
from ormar.models.model_row import ModelRow # noqa I100 from ormar.models.model_row import ModelRow # noqa I100
from ormar.models.model import Model # noqa I100 from ormar.models.model import Model, T # noqa I100
from ormar.models.excludable import ExcludableItems # noqa I100 from ormar.models.excludable import ExcludableItems # noqa I100
__all__ = ["NewBaseModel", "Model", "ModelRow", "ExcludableItems"] __all__ = ["NewBaseModel", "Model", "ModelRow", "ExcludableItems", "T"]

View File

@ -12,7 +12,7 @@ if TYPE_CHECKING: # pragma no cover
from ormar.fields import BaseField from ormar.fields import BaseField
def is_field_an_forward_ref(field: Type["BaseField"]) -> bool: def is_field_an_forward_ref(field: "BaseField") -> bool:
""" """
Checks if field is a relation field and whether any of the referenced models Checks if field is a relation field and whether any of the referenced models
are ForwardRefs that needs to be updated before proceeding. are ForwardRefs that needs to be updated before proceeding.

View File

@ -1,12 +1,10 @@
import warnings
from typing import Dict, Optional, TYPE_CHECKING, Tuple, Type from typing import Dict, Optional, TYPE_CHECKING, Tuple, Type
import pydantic import pydantic
from pydantic.fields import ModelField from pydantic.fields import ModelField
from pydantic.utils import lenient_issubclass from pydantic.utils import lenient_issubclass
import ormar # noqa: I100, I202 from ormar.fields import BaseField # noqa: I100, I202
from ormar.fields import BaseField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -14,7 +12,7 @@ if TYPE_CHECKING: # pragma no cover
def create_pydantic_field( def create_pydantic_field(
field_name: str, model: Type["Model"], model_field: Type["ManyToManyField"] field_name: str, model: Type["Model"], model_field: "ManyToManyField"
) -> None: ) -> None:
""" """
Registers pydantic field on through model that leads to passed model Registers pydantic field on through model that leads to passed model
@ -59,38 +57,6 @@ def get_pydantic_field(field_name: str, model: Type["Model"]) -> "ModelField":
) )
def populate_default_pydantic_field_value(
ormar_field: Type["BaseField"], field_name: str, attrs: dict
) -> dict:
"""
Grabs current value of the ormar Field in class namespace
(so the default_value declared on ormar model if set)
and converts it to pydantic.FieldInfo
that pydantic is able to extract later.
On FieldInfo there are saved all needed params like max_length of the string
and other constraints that pydantic can use to build
it's own field validation used by ormar.
:param ormar_field: field to convert
:type ormar_field: ormar Field
:param field_name: field to convert name
:type field_name: str
:param attrs: current class namespace
:type attrs: Dict
:return: updated namespace dict
:rtype: Dict
"""
curr_def_value = attrs.get(field_name, ormar.Undefined)
if lenient_issubclass(curr_def_value, ormar.fields.BaseField):
curr_def_value = ormar.Undefined
if curr_def_value is None:
attrs[field_name] = ormar_field.convert_to_pydantic_field_info(allow_null=True)
else:
attrs[field_name] = ormar_field.convert_to_pydantic_field_info()
return attrs
def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]: def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]:
""" """
Extracts ormar fields from annotations (deprecated) and from namespace Extracts ormar fields from annotations (deprecated) and from namespace
@ -110,22 +76,11 @@ def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]:
:rtype: Tuple[Dict, Dict] :rtype: Tuple[Dict, Dict]
""" """
model_fields = {} model_fields = {}
potential_fields = { potential_fields = {}
k: v
for k, v in attrs["__annotations__"].items()
if lenient_issubclass(v, BaseField)
}
if potential_fields:
warnings.warn(
"Using ormar.Fields as type Model annotation has been deprecated,"
" check documentation of current version",
DeprecationWarning,
)
potential_fields.update(get_potential_fields(attrs)) potential_fields.update(get_potential_fields(attrs))
for field_name, field in potential_fields.items(): for field_name, field in potential_fields.items():
field.name = field_name field.name = field_name
attrs = populate_default_pydantic_field_value(field, field_name, attrs)
model_fields[field_name] = field model_fields[field_name] = field
attrs["__annotations__"][field_name] = ( attrs["__annotations__"][field_name] = (
field.__type__ if not field.nullable else Optional[field.__type__] field.__type__ if not field.nullable else Optional[field.__type__]
@ -156,4 +111,8 @@ def get_potential_fields(attrs: Dict) -> Dict:
:return: extracted fields that are ormar Fields :return: extracted fields that are ormar Fields
:rtype: Dict :rtype: Dict
""" """
return {k: v for k, v in attrs.items() if lenient_issubclass(v, BaseField)} return {
k: v
for k, v in attrs.items()
if (lenient_issubclass(v, BaseField) or isinstance(v, BaseField))
}

View File

@ -13,7 +13,7 @@ if TYPE_CHECKING: # pragma no cover
alias_manager = AliasManager() alias_manager = AliasManager()
def register_relation_on_build(field: Type["ForeignKeyField"]) -> None: def register_relation_on_build(field: "ForeignKeyField") -> None:
""" """
Registers ForeignKey relation in alias_manager to set a table_prefix. Registers ForeignKey relation in alias_manager to set a table_prefix.
Registration include also reverse relation side to be able to join both sides. Registration include also reverse relation side to be able to join both sides.
@ -32,7 +32,7 @@ def register_relation_on_build(field: Type["ForeignKeyField"]) -> None:
) )
def register_many_to_many_relation_on_build(field: Type["ManyToManyField"]) -> None: def register_many_to_many_relation_on_build(field: "ManyToManyField") -> None:
""" """
Registers connection between through model and both sides of the m2m relation. Registers connection between through model and both sides of the m2m relation.
Registration include also reverse relation side to be able to join both sides. Registration include also reverse relation side to be able to join both sides.
@ -58,7 +58,7 @@ def register_many_to_many_relation_on_build(field: Type["ManyToManyField"]) -> N
) )
def expand_reverse_relationship(model_field: Type["ForeignKeyField"]) -> None: def expand_reverse_relationship(model_field: "ForeignKeyField") -> None:
""" """
If the reverse relation has not been set before it's set here. If the reverse relation has not been set before it's set here.
@ -84,11 +84,11 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
model_fields = list(model.Meta.model_fields.values()) model_fields = list(model.Meta.model_fields.values())
for model_field in model_fields: for model_field in model_fields:
if model_field.is_relation and not model_field.has_unresolved_forward_refs(): if model_field.is_relation and not model_field.has_unresolved_forward_refs():
model_field = cast(Type["ForeignKeyField"], model_field) model_field = cast("ForeignKeyField", model_field)
expand_reverse_relationship(model_field=model_field) expand_reverse_relationship(model_field=model_field)
def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None: def register_reverse_model_fields(model_field: "ForeignKeyField") -> None:
""" """
Registers reverse ForeignKey field on related model. Registers reverse ForeignKey field on related model.
By default it's name.lower()+'s' of the model on which relation is defined. By default it's name.lower()+'s' of the model on which relation is defined.
@ -101,7 +101,7 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None:
""" """
related_name = model_field.get_related_name() related_name = model_field.get_related_name()
if model_field.is_multi: if model_field.is_multi:
model_field.to.Meta.model_fields[related_name] = ManyToMany( model_field.to.Meta.model_fields[related_name] = ManyToMany( # type: ignore
model_field.owner, model_field.owner,
through=model_field.through, through=model_field.through,
name=related_name, name=related_name,
@ -113,11 +113,11 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None:
orders_by=model_field.related_orders_by, orders_by=model_field.related_orders_by,
) )
# register foreign keys on through model # register foreign keys on through model
model_field = cast(Type["ManyToManyField"], model_field) model_field = cast("ManyToManyField", model_field)
register_through_shortcut_fields(model_field=model_field) register_through_shortcut_fields(model_field=model_field)
adjust_through_many_to_many_model(model_field=model_field) adjust_through_many_to_many_model(model_field=model_field)
else: else:
model_field.to.Meta.model_fields[related_name] = ForeignKey( model_field.to.Meta.model_fields[related_name] = ForeignKey( # type: ignore
model_field.owner, model_field.owner,
real_name=related_name, real_name=related_name,
virtual=True, virtual=True,
@ -128,7 +128,7 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None:
) )
def register_through_shortcut_fields(model_field: Type["ManyToManyField"]) -> None: def register_through_shortcut_fields(model_field: "ManyToManyField") -> None:
""" """
Registers m2m relation through shortcut on both ends of the relation. Registers m2m relation through shortcut on both ends of the relation.
@ -156,7 +156,7 @@ def register_through_shortcut_fields(model_field: Type["ManyToManyField"]) -> No
) )
def register_relation_in_alias_manager(field: Type["ForeignKeyField"]) -> None: def register_relation_in_alias_manager(field: "ForeignKeyField") -> None:
""" """
Registers the relation (and reverse relation) in alias manager. Registers the relation (and reverse relation) in alias manager.
The m2m relations require registration of through model between The m2m relations require registration of through model between
@ -172,7 +172,7 @@ def register_relation_in_alias_manager(field: Type["ForeignKeyField"]) -> None:
if field.is_multi: if field.is_multi:
if field.has_unresolved_forward_refs(): if field.has_unresolved_forward_refs():
return return
field = cast(Type["ManyToManyField"], field) field = cast("ManyToManyField", field)
register_many_to_many_relation_on_build(field=field) register_many_to_many_relation_on_build(field=field)
elif field.is_relation and not field.is_through: elif field.is_relation and not field.is_through:
if field.has_unresolved_forward_refs(): if field.has_unresolved_forward_refs():
@ -181,7 +181,7 @@ def register_relation_in_alias_manager(field: Type["ForeignKeyField"]) -> None:
def verify_related_name_dont_duplicate( def verify_related_name_dont_duplicate(
related_name: str, model_field: Type["ForeignKeyField"] related_name: str, model_field: "ForeignKeyField"
) -> None: ) -> None:
""" """
Verifies whether the used related_name (regardless of the fact if user defined or Verifies whether the used related_name (regardless of the fact if user defined or
@ -213,7 +213,7 @@ def verify_related_name_dont_duplicate(
) )
def reverse_field_not_already_registered(model_field: Type["ForeignKeyField"]) -> bool: def reverse_field_not_already_registered(model_field: "ForeignKeyField") -> bool:
""" """
Checks if child is already registered in parents pydantic fields. Checks if child is already registered in parents pydantic fields.

View File

@ -14,7 +14,7 @@ if TYPE_CHECKING: # pragma no cover
from ormar.models import NewBaseModel from ormar.models import NewBaseModel
def adjust_through_many_to_many_model(model_field: Type["ManyToManyField"]) -> None: def adjust_through_many_to_many_model(model_field: "ManyToManyField") -> None:
""" """
Registers m2m relation on through model. Registers m2m relation on through model.
Sets ormar.ForeignKey from through model to both child and parent models. Sets ormar.ForeignKey from through model to both child and parent models.
@ -26,14 +26,14 @@ def adjust_through_many_to_many_model(model_field: Type["ManyToManyField"]) -> N
""" """
parent_name = model_field.default_target_field_name() parent_name = model_field.default_target_field_name()
child_name = model_field.default_source_field_name() child_name = model_field.default_source_field_name()
model_fields = model_field.through.Meta.model_fields
model_field.through.Meta.model_fields[parent_name] = ormar.ForeignKey( model_fields[parent_name] = ormar.ForeignKey( # type: ignore
model_field.to, model_field.to,
real_name=parent_name, real_name=parent_name,
ondelete="CASCADE", ondelete="CASCADE",
owner=model_field.through, owner=model_field.through,
) )
model_field.through.Meta.model_fields[child_name] = ormar.ForeignKey( model_fields[child_name] = ormar.ForeignKey( # type: ignore
model_field.owner, model_field.owner,
real_name=child_name, real_name=child_name,
ondelete="CASCADE", ondelete="CASCADE",
@ -52,7 +52,7 @@ def adjust_through_many_to_many_model(model_field: Type["ManyToManyField"]) -> N
def create_and_append_m2m_fk( def create_and_append_m2m_fk(
model: Type["Model"], model_field: Type["ManyToManyField"], field_name: str model: Type["Model"], model_field: "ManyToManyField", field_name: str
) -> None: ) -> None:
""" """
Registers sqlalchemy Column with sqlalchemy.ForeignKey leading to the model. Registers sqlalchemy Column with sqlalchemy.ForeignKey leading to the model.
@ -190,22 +190,22 @@ def _process_fields(
return pkname, columns return pkname, columns
def _is_through_model_not_set(field: Type["BaseField"]) -> bool: def _is_through_model_not_set(field: "BaseField") -> bool:
""" """
Alias to if check that verifies if through model was created. Alias to if check that verifies if through model was created.
:param field: field to check :param field: field to check
:type field: Type["BaseField"] :type field: "BaseField"
:return: result of the check :return: result of the check
:rtype: bool :rtype: bool
""" """
return field.is_multi and not field.through return field.is_multi and not field.through
def _is_db_field(field: Type["BaseField"]) -> bool: def _is_db_field(field: "BaseField") -> bool:
""" """
Alias to if check that verifies if field should be included in database. Alias to if check that verifies if field should be included in database.
:param field: field to check :param field: field to check
:type field: Type["BaseField"] :type field: "BaseField"
:return: result of the check :return: result of the check
:rtype: bool :rtype: bool
""" """
@ -298,7 +298,7 @@ def populate_meta_sqlalchemy_table_if_required(meta: "ModelMeta") -> None:
def update_column_definition( def update_column_definition(
model: Union[Type["Model"], Type["NewBaseModel"]], field: Type["ForeignKeyField"] model: Union[Type["Model"], Type["NewBaseModel"]], field: "ForeignKeyField"
) -> None: ) -> None:
""" """
Updates a column with a new type column based on updated parameters in FK fields. Updates a column with a new type column based on updated parameters in FK fields.
@ -306,7 +306,7 @@ def update_column_definition(
:param model: model on which columns needs to be updated :param model: model on which columns needs to be updated
:type model: Type["Model"] :type model: Type["Model"]
:param field: field with column definition that requires update :param field: field with column definition that requires update
:type field: Type[ForeignKeyField] :type field: ForeignKeyField
:return: None :return: None
:rtype: None :rtype: None
""" """

View File

@ -20,7 +20,7 @@ if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
def check_if_field_has_choices(field: Type[BaseField]) -> bool: def check_if_field_has_choices(field: BaseField) -> bool:
""" """
Checks if given field has choices populated. Checks if given field has choices populated.
A if it has one, a validator for this field needs to be attached. A if it has one, a validator for this field needs to be attached.
@ -34,7 +34,7 @@ def check_if_field_has_choices(field: Type[BaseField]) -> bool:
def convert_choices_if_needed( # noqa: CCR001 def convert_choices_if_needed( # noqa: CCR001
field: Type["BaseField"], value: Any field: "BaseField", value: Any
) -> Tuple[Any, List]: ) -> Tuple[Any, List]:
""" """
Converts dates to isoformat as fastapi can check this condition in routes Converts dates to isoformat as fastapi can check this condition in routes
@ -47,7 +47,7 @@ def convert_choices_if_needed( # noqa: CCR001
Converts decimal to float with given scale. Converts decimal to float with given scale.
:param field: ormar field to check with choices :param field: ormar field to check with choices
:type field: Type[BaseField] :type field: BaseField
:param values: current values of the model to verify :param values: current values of the model to verify
:type values: Dict :type values: Dict
:return: value, choices list :return: value, choices list
@ -77,13 +77,13 @@ def convert_choices_if_needed( # noqa: CCR001
return value, choices return value, choices
def validate_choices(field: Type["BaseField"], value: Any) -> None: def validate_choices(field: "BaseField", value: Any) -> None:
""" """
Validates if given value is in provided choices. Validates if given value is in provided choices.
:raises ValueError: If value is not in choices. :raises ValueError: If value is not in choices.
:param field:field to validate :param field:field to validate
:type field: Type[BaseField] :type field: BaseField
:param value: value of the field :param value: value of the field
:type value: Any :type value: Any
""" """

View File

@ -18,6 +18,7 @@ from sqlalchemy.sql.schema import ColumnCollectionConstraint
import ormar # noqa I100 import ormar # noqa I100
from ormar import ModelDefinitionError # noqa I100 from ormar import ModelDefinitionError # noqa I100
from ormar.exceptions import ModelError
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.fields.many_to_many import ManyToManyField from ormar.fields.many_to_many import ManyToManyField
@ -44,6 +45,7 @@ from ormar.signals import Signal, SignalEmitter
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
from ormar.models import T
CONFIG_KEY = "Config" CONFIG_KEY = "Config"
PARSED_FIELDS_KEY = "__parsed_fields__" PARSED_FIELDS_KEY = "__parsed_fields__"
@ -63,9 +65,7 @@ class ModelMeta:
columns: List[sqlalchemy.Column] columns: List[sqlalchemy.Column]
constraints: List[ColumnCollectionConstraint] constraints: List[ColumnCollectionConstraint]
pkname: str pkname: str
model_fields: Dict[ model_fields: Dict[str, Union[BaseField, ForeignKeyField, ManyToManyField]]
str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]]
]
alias_manager: AliasManager alias_manager: AliasManager
property_fields: Set property_fields: Set
signals: SignalEmitter signals: SignalEmitter
@ -215,7 +215,7 @@ def update_attrs_from_base_meta( # noqa: CCR001
def copy_and_replace_m2m_through_model( # noqa: CFQ002 def copy_and_replace_m2m_through_model( # noqa: CFQ002
field: Type[ManyToManyField], field: ManyToManyField,
field_name: str, field_name: str,
table_name: str, table_name: str,
parent_fields: Dict, parent_fields: Dict,
@ -238,7 +238,7 @@ def copy_and_replace_m2m_through_model( # noqa: CFQ002
:param base_class: base class model :param base_class: base class model
:type base_class: Type["Model"] :type base_class: Type["Model"]
:param field: field with relations definition :param field: field with relations definition
:type field: Type[ManyToManyField] :type field: ManyToManyField
:param field_name: name of the relation field :param field_name: name of the relation field
:type field_name: str :type field_name: str
:param table_name: name of the table :param table_name: name of the table
@ -250,9 +250,10 @@ def copy_and_replace_m2m_through_model( # noqa: CFQ002
:param meta: metaclass of currently created model :param meta: metaclass of currently created model
:type meta: ModelMeta :type meta: ModelMeta
""" """
copy_field: Type[BaseField] = type( # type: ignore Field: Type[BaseField] = type( # type: ignore
field.__name__, (ManyToManyField, BaseField), dict(field.__dict__) field.__class__.__name__, (ManyToManyField, BaseField), {}
) )
copy_field = Field(**dict(field.__dict__))
related_name = field.related_name + "_" + table_name related_name = field.related_name + "_" + table_name
copy_field.related_name = related_name # type: ignore copy_field.related_name = related_name # type: ignore
@ -293,9 +294,7 @@ def copy_data_from_parent_model( # noqa: CCR001
base_class: Type["Model"], base_class: Type["Model"],
curr_class: type, curr_class: type,
attrs: Dict, attrs: Dict,
model_fields: Dict[ model_fields: Dict[str, Union[BaseField, ForeignKeyField, ManyToManyField]],
str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]]
],
) -> Tuple[Dict, Dict]: ) -> Tuple[Dict, Dict]:
""" """
Copy the key parameters [databse, metadata, property_fields and constraints] Copy the key parameters [databse, metadata, property_fields and constraints]
@ -342,7 +341,7 @@ def copy_data_from_parent_model( # noqa: CCR001
) )
for field_name, field in base_class.Meta.model_fields.items(): for field_name, field in base_class.Meta.model_fields.items():
if field.is_multi: if field.is_multi:
field = cast(Type["ManyToManyField"], field) field = cast(ManyToManyField, field)
copy_and_replace_m2m_through_model( copy_and_replace_m2m_through_model(
field=field, field=field,
field_name=field_name, field_name=field_name,
@ -354,9 +353,10 @@ def copy_data_from_parent_model( # noqa: CCR001
) )
elif field.is_relation and field.related_name: elif field.is_relation and field.related_name:
copy_field = type( # type: ignore Field = type( # type: ignore
field.__name__, (ForeignKeyField, BaseField), dict(field.__dict__) field.__class__.__name__, (ForeignKeyField, BaseField), {}
) )
copy_field = Field(**dict(field.__dict__))
related_name = field.related_name + "_" + table_name related_name = field.related_name + "_" + table_name
copy_field.related_name = related_name # type: ignore copy_field.related_name = related_name # type: ignore
parent_fields[field_name] = copy_field parent_fields[field_name] = copy_field
@ -372,9 +372,7 @@ def extract_from_parents_definition( # noqa: CCR001
base_class: type, base_class: type,
curr_class: type, curr_class: type,
attrs: Dict, attrs: Dict,
model_fields: Dict[ model_fields: Dict[str, Union[BaseField, ForeignKeyField, ManyToManyField]],
str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]]
],
) -> Tuple[Dict, Dict]: ) -> Tuple[Dict, Dict]:
""" """
Extracts fields from base classes if they have valid oramr fields. Extracts fields from base classes if they have valid oramr fields.
@ -549,6 +547,15 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
field_name=field_name, model=new_model field_name=field_name, model=new_model
) )
new_model.Meta.alias_manager = alias_manager new_model.Meta.alias_manager = alias_manager
new_model.objects = QuerySet(new_model)
return new_model return new_model
@property
def objects(cls: Type["T"]) -> "QuerySet[T]": # type: ignore
if cls.Meta.requires_ref_update:
raise ModelError(
f"Model {cls.get_name()} has not updated "
f"ForwardRefs. \nBefore using the model you "
f"need to call update_forward_refs()."
)
return QuerySet(model_cls=cls)

View File

@ -67,6 +67,6 @@ class AliasMixin:
:rtype: Dict :rtype: Dict
""" """
for field_name, field in cls.Meta.model_fields.items(): for field_name, field in cls.Meta.model_fields.items():
if field.alias and field.alias in new_kwargs: if field.get_alias() and field.get_alias() in new_kwargs:
new_kwargs[field_name] = new_kwargs.pop(field.alias) new_kwargs[field_name] = new_kwargs.pop(field.get_alias())
return new_kwargs return new_kwargs

View File

@ -41,7 +41,7 @@ class PrefetchQueryMixin(RelationMixin):
field_name = parent_model.Meta.model_fields[related].get_related_name() field_name = parent_model.Meta.model_fields[related].get_related_name()
field = target_model.Meta.model_fields[field_name] field = target_model.Meta.model_fields[field_name]
if field.is_multi: if field.is_multi:
field = cast(Type["ManyToManyField"], field) field = cast("ManyToManyField", field)
field_name = field.default_target_field_name() field_name = field.default_target_field_name()
sub_field = field.through.Meta.model_fields[field_name] sub_field = field.through.Meta.model_fields[field_name]
return field.through, sub_field.get_alias() return field.through, sub_field.get_alias()
@ -78,7 +78,7 @@ class PrefetchQueryMixin(RelationMixin):
return column.get_alias() if use_raw else column.name return column.get_alias() if use_raw else column.name
@classmethod @classmethod
def get_related_field_name(cls, target_field: Type["ForeignKeyField"]) -> str: def get_related_field_name(cls, target_field: "ForeignKeyField") -> str:
""" """
Returns name of the relation field that should be used in prefetch query. Returns name of the relation field that should be used in prefetch query.
This field is later used to register relation in prefetch query, This field is later used to register relation in prefetch query,

View File

@ -1,4 +1,3 @@
import inspect
from typing import ( from typing import (
Callable, Callable,
List, List,
@ -9,6 +8,8 @@ from typing import (
Union, Union,
) )
from ormar import BaseField
class RelationMixin: class RelationMixin:
""" """
@ -85,7 +86,11 @@ class RelationMixin:
related_names = set() related_names = set()
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field) and field.is_relation and not field.is_through: if (
isinstance(field, BaseField)
and field.is_relation
and not field.is_through
):
related_names.add(name) related_names.add(name)
cls._related_names = related_names cls._related_names = related_names

View File

@ -15,8 +15,6 @@ from ormar.models import NewBaseModel # noqa I100
from ormar.models.metaclass import ModelMeta from ormar.models.metaclass import ModelMeta
from ormar.models.model_row import ModelRow from ormar.models.model_row import ModelRow
if TYPE_CHECKING: # pragma nocover
from ormar import QuerySet
T = TypeVar("T", bound="Model") T = TypeVar("T", bound="Model")
@ -25,7 +23,6 @@ class Model(ModelRow):
__abstract__ = False __abstract__ = False
if TYPE_CHECKING: # pragma nocover if TYPE_CHECKING: # pragma nocover
Meta: ModelMeta Meta: ModelMeta
objects: "QuerySet"
def __repr__(self) -> str: # pragma nocover def __repr__(self) -> str: # pragma nocover
_repr = {k: getattr(self, k) for k, v in self.Meta.model_fields.items()} _repr = {k: getattr(self, k) for k, v in self.Meta.model_fields.items()}

View File

@ -29,7 +29,7 @@ class ModelRow(NewBaseModel):
source_model: Type["Model"], source_model: Type["Model"],
select_related: List = None, select_related: List = None,
related_models: Any = None, related_models: Any = None,
related_field: Type["ForeignKeyField"] = None, related_field: "ForeignKeyField" = None,
excludable: ExcludableItems = None, excludable: ExcludableItems = None,
current_relation_str: str = "", current_relation_str: str = "",
proxy_source_model: Optional[Type["Model"]] = None, proxy_source_model: Optional[Type["Model"]] = None,
@ -65,7 +65,7 @@ class ModelRow(NewBaseModel):
:param related_models: list or dict of related models :param related_models: list or dict of related models
:type related_models: Union[List, Dict] :type related_models: Union[List, Dict]
:param related_field: field with relation declaration :param related_field: field with relation declaration
:type related_field: Type[ForeignKeyField] :type related_field: ForeignKeyField
:return: returns model if model is populated from database :return: returns model if model is populated from database
:rtype: Optional[Model] :rtype: Optional[Model]
""" """
@ -116,7 +116,7 @@ class ModelRow(NewBaseModel):
cls, cls,
source_model: Type["Model"], source_model: Type["Model"],
current_relation_str: str, current_relation_str: str,
related_field: Type["ForeignKeyField"], related_field: "ForeignKeyField",
used_prefixes: List[str], used_prefixes: List[str],
) -> str: ) -> str:
""" """
@ -126,7 +126,7 @@ class ModelRow(NewBaseModel):
:param current_relation_str: current relation string :param current_relation_str: current relation string
:type current_relation_str: str :type current_relation_str: str
:param related_field: field with relation declaration :param related_field: field with relation declaration
:type related_field: Type["ForeignKeyField"] :type related_field: "ForeignKeyField"
:param used_prefixes: list of already extracted prefixes :param used_prefixes: list of already extracted prefixes
:type used_prefixes: List[str] :type used_prefixes: List[str]
:return: table_prefix to use :return: table_prefix to use
@ -193,7 +193,7 @@ class ModelRow(NewBaseModel):
for related in related_models: for related in related_models:
field = cls.Meta.model_fields[related] field = cls.Meta.model_fields[related]
field = cast(Type["ForeignKeyField"], field) field = cast("ForeignKeyField", field)
model_cls = field.to model_cls = field.to
model_excludable = excludable.get( model_excludable = excludable.get(
model_cls=cast(Type["Model"], cls), alias=table_prefix model_cls=cast(Type["Model"], cls), alias=table_prefix

View File

@ -67,7 +67,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
__slots__ = ("_orm_id", "_orm_saved", "_orm", "_pk_column") __slots__ = ("_orm_id", "_orm_saved", "_orm", "_pk_column")
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, Type[BaseField]] pk: Any
__model_fields__: Dict[str, BaseField]
__table__: sqlalchemy.Table __table__: sqlalchemy.Table
__fields__: Dict[str, pydantic.fields.ModelField] __fields__: Dict[str, pydantic.fields.ModelField]
__pydantic_model__: Type[BaseModel] __pydantic_model__: Type[BaseModel]
@ -77,6 +78,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
__database__: databases.Database __database__: databases.Database
_orm_relationship_manager: AliasManager _orm_relationship_manager: AliasManager
_orm: RelationsManager _orm: RelationsManager
_orm_id: int
_orm_saved: bool _orm_saved: bool
_related_names: Optional[Set] _related_names: Optional[Set]
_related_names_hash: str _related_names_hash: str
@ -265,13 +267,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if item == "pk": if item == "pk":
return object.__getattribute__(self, "__dict__").get(self.Meta.pkname, None) return object.__getattribute__(self, "__dict__").get(self.Meta.pkname, None)
if item in object.__getattribute__(self, "extract_related_names")(): if item in object.__getattribute__(self, "extract_related_names")():
return object.__getattribute__( return self._extract_related_model_instead_of_field(item)
self, "_extract_related_model_instead_of_field"
)(item)
if item in object.__getattribute__(self, "extract_through_names")(): if item in object.__getattribute__(self, "extract_through_names")():
return object.__getattribute__( return self._extract_related_model_instead_of_field(item)
self, "_extract_related_model_instead_of_field"
)(item)
if item in object.__getattribute__(self, "Meta").property_fields: if item in object.__getattribute__(self, "Meta").property_fields:
value = object.__getattribute__(self, item) value = object.__getattribute__(self, item)
return value() if callable(value) else value return value() if callable(value) else value
@ -455,7 +453,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
fields_to_check = cls.Meta.model_fields.copy() fields_to_check = cls.Meta.model_fields.copy()
for field in fields_to_check.values(): for field in fields_to_check.values():
if field.has_unresolved_forward_refs(): if field.has_unresolved_forward_refs():
field = cast(Type[ForeignKeyField], field) field = cast(ForeignKeyField, field)
field.evaluate_forward_ref(globalns=globalns, localns=localns) field.evaluate_forward_ref(globalns=globalns, localns=localns)
field.set_self_reference_flag() field.set_self_reference_flag()
expand_reverse_relationship(model_field=field) expand_reverse_relationship(model_field=field)
@ -747,12 +745,12 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
) )
return self_fields return self_fields
def get_relation_model_id(self, target_field: Type["BaseField"]) -> Optional[int]: def get_relation_model_id(self, target_field: "BaseField") -> Optional[int]:
""" """
Returns an id of the relation side model to use in prefetch query. Returns an id of the relation side model to use in prefetch query.
:param target_field: field with relation definition :param target_field: field with relation definition
:type target_field: Type["BaseField"] :type target_field: "BaseField"
:return: value of pk if set :return: value of pk if set
:rtype: Optional[int] :rtype: Optional[int]
""" """

View File

@ -23,6 +23,7 @@ quick_access_set = {
"_extract_nested_models", "_extract_nested_models",
"_extract_nested_models_from_list", "_extract_nested_models_from_list",
"_extract_own_model_fields", "_extract_own_model_fields",
"_extract_related_model_instead_of_field",
"_get_related_not_excluded_fields", "_get_related_not_excluded_fields",
"_get_value", "_get_value",
"_is_conversion_to_json_needed", "_is_conversion_to_json_needed",

View File

@ -292,7 +292,7 @@ class PrefetchQuery:
for related in related_to_extract: for related in related_to_extract:
target_field = model.Meta.model_fields[related] target_field = model.Meta.model_fields[related]
target_field = cast(Type["ForeignKeyField"], target_field) target_field = cast("ForeignKeyField", target_field)
target_model = target_field.to.get_name() target_model = target_field.to.get_name()
model_id = model.get_relation_model_id(target_field=target_field) model_id = model.get_relation_model_id(target_field=target_field)
@ -394,7 +394,7 @@ class PrefetchQuery:
:rtype: None :rtype: None
""" """
target_field = target_model.Meta.model_fields[related] target_field = target_model.Meta.model_fields[related]
target_field = cast(Type["ForeignKeyField"], target_field) target_field = cast("ForeignKeyField", target_field)
reverse = False reverse = False
if target_field.virtual or target_field.is_multi: if target_field.virtual or target_field.is_multi:
reverse = True reverse = True
@ -461,7 +461,7 @@ class PrefetchQuery:
async def _run_prefetch_query( async def _run_prefetch_query(
self, self,
target_field: Type["BaseField"], target_field: "BaseField",
excludable: "ExcludableItems", excludable: "ExcludableItems",
filter_clauses: List, filter_clauses: List,
related_field_name: str, related_field_name: str,
@ -474,7 +474,7 @@ class PrefetchQuery:
models. models.
:param target_field: ormar field with relation definition :param target_field: ormar field with relation definition
:type target_field: Type["BaseField"] :type target_field: "BaseField"
:param filter_clauses: list of clauses, actually one clause with ids of relation :param filter_clauses: list of clauses, actually one clause with ids of relation
:type filter_clauses: List[sqlalchemy.sql.elements.TextClause] :type filter_clauses: List[sqlalchemy.sql.elements.TextClause]
:return: table prefix and raw rows from sql response :return: table prefix and raw rows from sql response
@ -540,13 +540,13 @@ class PrefetchQuery:
) )
def _update_already_loaded_rows( # noqa: CFQ002 def _update_already_loaded_rows( # noqa: CFQ002
self, target_field: Type["BaseField"], prefetch_dict: Dict, orders_by: Dict, self, target_field: "BaseField", prefetch_dict: Dict, orders_by: Dict,
) -> None: ) -> None:
""" """
Updates models that are already loaded, usually children of children. Updates models that are already loaded, usually children of children.
:param target_field: ormar field with relation definition :param target_field: ormar field with relation definition
:type target_field: Type["BaseField"] :type target_field: "BaseField"
:param prefetch_dict: dictionaries of related models to prefetch :param prefetch_dict: dictionaries of related models to prefetch
:type prefetch_dict: Dict :type prefetch_dict: Dict
:param orders_by: dictionary of order by clauses by model :param orders_by: dictionary of order by clauses by model
@ -561,7 +561,7 @@ class PrefetchQuery:
def _populate_rows( # noqa: CFQ002 def _populate_rows( # noqa: CFQ002
self, self,
rows: List, rows: List,
target_field: Type["ForeignKeyField"], target_field: "ForeignKeyField",
parent_model: Type["Model"], parent_model: Type["Model"],
table_prefix: str, table_prefix: str,
exclude_prefix: str, exclude_prefix: str,
@ -584,7 +584,7 @@ class PrefetchQuery:
:param rows: raw sql response from the prefetch query :param rows: raw sql response from the prefetch query
:type rows: List[sqlalchemy.engine.result.RowProxy] :type rows: List[sqlalchemy.engine.result.RowProxy]
:param target_field: field with relation definition from parent model :param target_field: field with relation definition from parent model
:type target_field: Type["BaseField"] :type target_field: "BaseField"
:param parent_model: model with relation definition :param parent_model: model with relation definition
:type parent_model: Type[Model] :type parent_model: Type[Model]
:param table_prefix: prefix of the target table from current relation :param table_prefix: prefix of the target table from current relation

View File

@ -1,12 +1,14 @@
from typing import ( from typing import (
Any, Any,
Dict, Dict,
Generic,
List, List,
Optional, Optional,
Sequence, Sequence,
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
Type, Type,
TypeVar,
Union, Union,
cast, cast,
) )
@ -17,7 +19,7 @@ from sqlalchemy import bindparam
import ormar # noqa I100 import ormar # noqa I100
from ormar import MultipleMatches, NoMatch from ormar import MultipleMatches, NoMatch
from ormar.exceptions import ModelError, ModelPersistenceError, QueryDefinitionError from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
from ormar.queryset import FilterQuery, SelectAction from ormar.queryset import FilterQuery, SelectAction
from ormar.queryset.actions.order_action import OrderAction from ormar.queryset.actions.order_action import OrderAction
from ormar.queryset.clause import FilterGroup, QueryClause from ormar.queryset.clause import FilterGroup, QueryClause
@ -26,19 +28,21 @@ from ormar.queryset.query import Query
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
from ormar.models import T
from ormar.models.metaclass import ModelMeta from ormar.models.metaclass import ModelMeta
from ormar.relations.querysetproxy import QuerysetProxy
from ormar.models.excludable import ExcludableItems from ormar.models.excludable import ExcludableItems
else:
T = TypeVar("T", bound="Model")
class QuerySet: class QuerySet(Generic[T]):
""" """
Main class to perform database queries, exposed on each model as objects attribute. Main class to perform database queries, exposed on each model as objects attribute.
""" """
def __init__( # noqa CFQ002 def __init__( # noqa CFQ002
self, self,
model_cls: Optional[Type["Model"]] = None, model_cls: Optional[Type["T"]] = None,
filter_clauses: List = None, filter_clauses: List = None,
exclude_clauses: List = None, exclude_clauses: List = None,
select_related: List = None, select_related: List = None,
@ -62,22 +66,6 @@ class QuerySet:
self.order_bys = order_bys or [] self.order_bys = order_bys or []
self.limit_sql_raw = limit_raw_sql self.limit_sql_raw = limit_raw_sql
def __get__(
self,
instance: Optional[Union["QuerySet", "QuerysetProxy"]],
owner: Union[Type["Model"], Type["QuerysetProxy"]],
) -> "QuerySet":
if issubclass(owner, ormar.Model):
if owner.Meta.requires_ref_update:
raise ModelError(
f"Model {owner.get_name()} has not updated "
f"ForwardRefs. \nBefore using the model you "
f"need to call update_forward_refs()."
)
owner = cast(Type["Model"], owner)
return self.__class__(model_cls=owner)
return self.__class__() # pragma: no cover
@property @property
def model_meta(self) -> "ModelMeta": def model_meta(self) -> "ModelMeta":
""" """
@ -91,7 +79,7 @@ class QuerySet:
return self.model_cls.Meta return self.model_cls.Meta
@property @property
def model(self) -> Type["Model"]: def model(self) -> Type["T"]:
""" """
Shortcut to model class set on QuerySet. Shortcut to model class set on QuerySet.
@ -148,8 +136,8 @@ class QuerySet:
) )
async def _prefetch_related_models( async def _prefetch_related_models(
self, models: List[Optional["Model"]], rows: List self, models: List[Optional["T"]], rows: List
) -> List[Optional["Model"]]: ) -> List[Optional["T"]]:
""" """
Performs prefetch query for selected models names. Performs prefetch query for selected models names.
@ -169,7 +157,7 @@ class QuerySet:
) )
return await query.prefetch_related(models=models, rows=rows) # type: ignore return await query.prefetch_related(models=models, rows=rows) # type: ignore
def _process_query_result_rows(self, rows: List) -> List[Optional["Model"]]: def _process_query_result_rows(self, rows: List) -> List[Optional["T"]]:
""" """
Process database rows and initialize ormar Model from each of the rows. Process database rows and initialize ormar Model from each of the rows.
@ -190,7 +178,7 @@ class QuerySet:
] ]
if result_rows: if result_rows:
return self.model.merge_instances_list(result_rows) # type: ignore return self.model.merge_instances_list(result_rows) # type: ignore
return result_rows return cast(List[Optional["T"]], result_rows)
def _resolve_filter_groups(self, groups: Any) -> List[FilterGroup]: def _resolve_filter_groups(self, groups: Any) -> List[FilterGroup]:
""" """
@ -221,7 +209,7 @@ class QuerySet:
return filter_groups return filter_groups
@staticmethod @staticmethod
def check_single_result_rows_count(rows: Sequence[Optional["Model"]]) -> None: def check_single_result_rows_count(rows: Sequence[Optional["T"]]) -> None:
""" """
Verifies if the result has one and only one row. Verifies if the result has one and only one row.
@ -286,7 +274,7 @@ class QuerySet:
def filter( # noqa: A003 def filter( # noqa: A003
self, *args: Any, _exclude: bool = False, **kwargs: Any self, *args: Any, _exclude: bool = False, **kwargs: Any
) -> "QuerySet": ) -> "QuerySet[T]":
""" """
Allows you to filter by any `Model` attribute/field Allows you to filter by any `Model` attribute/field
as well as to fetch instances, with a filter across an FK relationship. as well as to fetch instances, with a filter across an FK relationship.
@ -337,7 +325,7 @@ class QuerySet:
select_related=select_related, select_related=select_related,
) )
def exclude(self, *args: Any, **kwargs: Any) -> "QuerySet": # noqa: A003 def exclude(self, *args: Any, **kwargs: Any) -> "QuerySet[T]": # noqa: A003
""" """
Works exactly the same as filter and all modifiers (suffixes) are the same, Works exactly the same as filter and all modifiers (suffixes) are the same,
but returns a *not* condition. but returns a *not* condition.
@ -358,7 +346,7 @@ class QuerySet:
""" """
return self.filter(_exclude=True, *args, **kwargs) return self.filter(_exclude=True, *args, **kwargs)
def select_related(self, related: Union[List, str]) -> "QuerySet": def select_related(self, related: Union[List, str]) -> "QuerySet[T]":
""" """
Allows to prefetch related models during the same query. Allows to prefetch related models during the same query.
@ -381,7 +369,33 @@ class QuerySet:
related = sorted(list(set(list(self._select_related) + related))) related = sorted(list(set(list(self._select_related) + related)))
return self.rebuild_self(select_related=related,) return self.rebuild_self(select_related=related,)
def prefetch_related(self, related: Union[List, str]) -> "QuerySet": def select_all(self, follow: bool = False) -> "QuerySet[T]":
"""
By default adds only directly related models.
If follow=True is set it adds also related models of related models.
To not get stuck in an infinite loop as related models also keep a relation
to parent model visited models set is kept.
That way already visited models that are nested are loaded, but the load do not
follow them inside. So Model A -> Model B -> Model C -> Model A -> Model X
will load second Model A but will never follow into Model X.
Nested relations of those kind need to be loaded manually.
:param follow: flag to trigger deep save -
by default only directly related models are saved
with follow=True also related models of related models are saved
:type follow: bool
:return: reloaded Model
:rtype: Model
"""
relations = list(self.model.extract_related_names())
if follow:
relations = self.model._iterate_related_models()
return self.rebuild_self(select_related=relations,)
def prefetch_related(self, related: Union[List, str]) -> "QuerySet[T]":
""" """
Allows to prefetch related models during query - but opposite to Allows to prefetch related models during query - but opposite to
`select_related` each subsequent model is fetched in a separate database query. `select_related` each subsequent model is fetched in a separate database query.
@ -407,7 +421,7 @@ class QuerySet:
def fields( def fields(
self, columns: Union[List, str, Set, Dict], _is_exclude: bool = False self, columns: Union[List, str, Set, Dict], _is_exclude: bool = False
) -> "QuerySet": ) -> "QuerySet[T]":
""" """
With `fields()` you can select subset of model columns to limit the data load. With `fields()` you can select subset of model columns to limit the data load.
@ -461,7 +475,7 @@ class QuerySet:
return self.rebuild_self(excludable=excludable,) return self.rebuild_self(excludable=excludable,)
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet[T]":
""" """
With `exclude_fields()` you can select subset of model columns that will With `exclude_fields()` you can select subset of model columns that will
be excluded to limit the data load. be excluded to limit the data load.
@ -490,7 +504,7 @@ class QuerySet:
""" """
return self.fields(columns=columns, _is_exclude=True) return self.fields(columns=columns, _is_exclude=True)
def order_by(self, columns: Union[List, str]) -> "QuerySet": def order_by(self, columns: Union[List, str]) -> "QuerySet[T]":
""" """
With `order_by()` you can order the results from database based on your With `order_by()` you can order the results from database based on your
choice of fields. choice of fields.
@ -680,7 +694,7 @@ class QuerySet:
) )
return await self.database.execute(expr) return await self.database.execute(expr)
def paginate(self, page: int, page_size: int = 20) -> "QuerySet": def paginate(self, page: int, page_size: int = 20) -> "QuerySet[T]":
""" """
You can paginate the result which is a combination of offset and limit clauses. You can paginate the result which is a combination of offset and limit clauses.
Limit is set to page size and offset is set to (page-1) * page_size. Limit is set to page size and offset is set to (page-1) * page_size.
@ -699,7 +713,7 @@ class QuerySet:
query_offset = (page - 1) * page_size query_offset = (page - 1) * page_size
return self.rebuild_self(limit_count=limit_count, offset=query_offset,) return self.rebuild_self(limit_count=limit_count, offset=query_offset,)
def limit(self, limit_count: int, limit_raw_sql: bool = None) -> "QuerySet": def limit(self, limit_count: int, limit_raw_sql: bool = None) -> "QuerySet[T]":
""" """
You can limit the results to desired number of parent models. You can limit the results to desired number of parent models.
@ -716,7 +730,7 @@ class QuerySet:
limit_raw_sql = self.limit_sql_raw if limit_raw_sql is None else limit_raw_sql limit_raw_sql = self.limit_sql_raw if limit_raw_sql is None else limit_raw_sql
return self.rebuild_self(limit_count=limit_count, limit_raw_sql=limit_raw_sql,) return self.rebuild_self(limit_count=limit_count, limit_raw_sql=limit_raw_sql,)
def offset(self, offset: int, limit_raw_sql: bool = None) -> "QuerySet": def offset(self, offset: int, limit_raw_sql: bool = None) -> "QuerySet[T]":
""" """
You can also offset the results by desired number of main models. You can also offset the results by desired number of main models.
@ -733,7 +747,7 @@ class QuerySet:
limit_raw_sql = self.limit_sql_raw if limit_raw_sql is None else limit_raw_sql limit_raw_sql = self.limit_sql_raw if limit_raw_sql is None else limit_raw_sql
return self.rebuild_self(offset=offset, limit_raw_sql=limit_raw_sql,) return self.rebuild_self(offset=offset, limit_raw_sql=limit_raw_sql,)
async def first(self, **kwargs: Any) -> "Model": async def first(self, **kwargs: Any) -> "T":
""" """
Gets the first row from the db ordered by primary key column ascending. Gets the first row from the db ordered by primary key column ascending.
@ -764,7 +778,7 @@ class QuerySet:
self.check_single_result_rows_count(processed_rows) self.check_single_result_rows_count(processed_rows)
return processed_rows[0] # type: ignore return processed_rows[0] # type: ignore
async def get(self, **kwargs: Any) -> "Model": async def get(self, **kwargs: Any) -> "T":
""" """
Get's the first row from the db meeting the criteria set by kwargs. Get's the first row from the db meeting the criteria set by kwargs.
@ -803,7 +817,7 @@ class QuerySet:
self.check_single_result_rows_count(processed_rows) self.check_single_result_rows_count(processed_rows)
return processed_rows[0] # type: ignore return processed_rows[0] # type: ignore
async def get_or_create(self, **kwargs: Any) -> "Model": async def get_or_create(self, **kwargs: Any) -> "T":
""" """
Combination of create and get methods. Combination of create and get methods.
@ -821,7 +835,7 @@ class QuerySet:
except NoMatch: except NoMatch:
return await self.create(**kwargs) return await self.create(**kwargs)
async def update_or_create(self, **kwargs: Any) -> "Model": async def update_or_create(self, **kwargs: Any) -> "T":
""" """
Updates the model, or in case there is no match in database creates a new one. Updates the model, or in case there is no match in database creates a new one.
@ -838,7 +852,7 @@ class QuerySet:
model = await self.get(pk=kwargs[pk_name]) model = await self.get(pk=kwargs[pk_name])
return await model.update(**kwargs) return await model.update(**kwargs)
async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003 async def all(self, **kwargs: Any) -> List[Optional["T"]]: # noqa: A003
""" """
Returns all rows from a database for given model for set filter options. Returns all rows from a database for given model for set filter options.
@ -862,7 +876,7 @@ class QuerySet:
return result_rows return result_rows
async def create(self, **kwargs: Any) -> "Model": async def create(self, **kwargs: Any) -> "T":
""" """
Creates the model instance, saves it in a database and returns the updates model Creates the model instance, saves it in a database and returns the updates model
(with pk populated if not passed and autoincrement is set). (with pk populated if not passed and autoincrement is set).
@ -905,7 +919,7 @@ class QuerySet:
) )
return instance return instance
async def bulk_create(self, objects: List["Model"]) -> None: async def bulk_create(self, objects: List["T"]) -> None:
""" """
Performs a bulk update in one database session to speed up the process. Performs a bulk update in one database session to speed up the process.
@ -931,7 +945,7 @@ class QuerySet:
objt.set_save_status(True) objt.set_save_status(True)
async def bulk_update( # noqa: CCR001 async def bulk_update( # noqa: CCR001
self, objects: List["Model"], columns: List[str] = None self, objects: List["T"], columns: List[str] = None
) -> None: ) -> None:
""" """
Performs bulk update in one database session to speed up the process. Performs bulk update in one database session to speed up the process.

View File

@ -264,7 +264,7 @@ def get_relationship_alias_model_and_str(
def _process_through_field( def _process_through_field(
related_parts: List, related_parts: List,
relation: Optional[str], relation: Optional[str],
related_field: Type["BaseField"], related_field: "BaseField",
previous_model: Type["Model"], previous_model: Type["Model"],
previous_models: List[Type["Model"]], previous_models: List[Type["Model"]],
) -> Tuple[Type["Model"], Optional[str], bool]: ) -> Tuple[Type["Model"], Optional[str], bool]:
@ -276,7 +276,7 @@ def _process_through_field(
:param relation: relation name :param relation: relation name
:type relation: str :type relation: str
:param related_field: field with relation declaration :param related_field: field with relation declaration
:type related_field: Type["ForeignKeyField"] :type related_field: "ForeignKeyField"
:param previous_model: model from which relation is coming :param previous_model: model from which relation is coming
:type previous_model: Type["Model"] :type previous_model: Type["Model"]
:param previous_models: list of already visited models in relation chain :param previous_models: list of already visited models in relation chain

View File

@ -151,7 +151,7 @@ class AliasManager:
self, self,
source_model: Union[Type["Model"], Type["ModelRow"]], source_model: Union[Type["Model"], Type["ModelRow"]],
relation_str: str, relation_str: str,
relation_field: Type["ForeignKeyField"], relation_field: "ForeignKeyField",
) -> str: ) -> str:
""" """
Given source model and relation string returns the alias for this complex Given source model and relation string returns the alias for this complex
@ -159,7 +159,7 @@ class AliasManager:
field definition. field definition.
:param relation_field: field with direct relation definition :param relation_field: field with direct relation definition
:type relation_field: Type["ForeignKeyField"] :type relation_field: "ForeignKeyField"
:param source_model: model with query starts :param source_model: model with query starts
:type source_model: source Model :type source_model: source Model
:param relation_str: string with relation joins defined :param relation_str: string with relation joins defined

View File

@ -2,28 +2,32 @@ from _weakref import CallableProxyType
from typing import ( # noqa: I100, I201 from typing import ( # noqa: I100, I201
Any, Any,
Dict, Dict,
Generic,
List, List,
MutableSequence, MutableSequence,
Optional, Optional,
Sequence, Sequence,
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
Type,
TypeVar,
Union, Union,
cast, cast,
) )
import ormar # noqa: I100, I202 import ormar # noqa: I100, I202
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.relations import Relation from ormar.relations import Relation
from ormar.models import Model from ormar.models import Model, T
from ormar.queryset import QuerySet from ormar.queryset import QuerySet
from ormar import RelationType from ormar import RelationType
else:
T = TypeVar("T", bound="Model")
class QuerysetProxy: class QuerysetProxy(Generic[T]):
""" """
Exposes QuerySet methods on relations, but also handles creating and removing Exposes QuerySet methods on relations, but also handles creating and removing
of through Models for m2m relations. of through Models for m2m relations.
@ -33,16 +37,21 @@ class QuerysetProxy:
relation: "Relation" relation: "Relation"
def __init__( def __init__(
self, relation: "Relation", type_: "RelationType", qryset: "QuerySet" = None self,
relation: "Relation",
to: Type["T"],
type_: "RelationType",
qryset: "QuerySet[T]" = None,
) -> None: ) -> None:
self.relation: Relation = relation self.relation: Relation = relation
self._queryset: Optional["QuerySet"] = qryset self._queryset: Optional["QuerySet[T]"] = qryset
self.type_: "RelationType" = type_ self.type_: "RelationType" = type_
self._owner: Union[CallableProxyType, "Model"] = self.relation.manager.owner self._owner: Union[CallableProxyType, "Model"] = self.relation.manager.owner
self.related_field_name = self._owner.Meta.model_fields[ self.related_field_name = self._owner.Meta.model_fields[
self.relation.field_name self.relation.field_name
].get_related_name() ].get_related_name()
self.related_field = self.relation.to.Meta.model_fields[self.related_field_name] self.to: Type[T] = to
self.related_field = to.Meta.model_fields[self.related_field_name]
self.owner_pk_value = self._owner.pk self.owner_pk_value = self._owner.pk
self.through_model_name = ( self.through_model_name = (
self.related_field.through.get_name() self.related_field.through.get_name()
@ -51,7 +60,7 @@ class QuerysetProxy:
) )
@property @property
def queryset(self) -> "QuerySet": def queryset(self) -> "QuerySet[T]":
""" """
Returns queryset if it's set, AttributeError otherwise. Returns queryset if it's set, AttributeError otherwise.
:return: QuerySet :return: QuerySet
@ -70,7 +79,7 @@ class QuerysetProxy:
""" """
self._queryset = value self._queryset = value
def _assign_child_to_parent(self, child: Optional["Model"]) -> None: def _assign_child_to_parent(self, child: Optional["T"]) -> None:
""" """
Registers child in parents RelationManager. Registers child in parents RelationManager.
@ -82,9 +91,7 @@ class QuerysetProxy:
rel_name = self.relation.field_name rel_name = self.relation.field_name
setattr(owner, rel_name, child) setattr(owner, rel_name, child)
def _register_related( def _register_related(self, child: Union["T", Sequence[Optional["T"]]]) -> None:
self, child: Union["Model", Sequence[Optional["Model"]]]
) -> None:
""" """
Registers child/ children in parents RelationManager. Registers child/ children in parents RelationManager.
@ -96,7 +103,7 @@ class QuerysetProxy:
self._assign_child_to_parent(subchild) self._assign_child_to_parent(subchild)
else: else:
assert isinstance(child, ormar.Model) assert isinstance(child, ormar.Model)
child = cast("Model", child) child = cast("T", child)
self._assign_child_to_parent(child) self._assign_child_to_parent(child)
def _clean_items_on_load(self) -> None: def _clean_items_on_load(self) -> None:
@ -107,7 +114,7 @@ class QuerysetProxy:
for item in self.relation.related_models[:]: for item in self.relation.related_models[:]:
self.relation.remove(item) self.relation.remove(item)
async def create_through_instance(self, child: "Model", **kwargs: Any) -> None: async def create_through_instance(self, child: "T", **kwargs: Any) -> None:
""" """
Crete a through model instance in the database for m2m relations. Crete a through model instance in the database for m2m relations.
@ -129,7 +136,7 @@ class QuerysetProxy:
) )
await model_cls(**final_kwargs).save() await model_cls(**final_kwargs).save()
async def update_through_instance(self, child: "Model", **kwargs: Any) -> None: async def update_through_instance(self, child: "T", **kwargs: Any) -> None:
""" """
Updates a through model instance in the database for m2m relations. Updates a through model instance in the database for m2m relations.
@ -145,7 +152,7 @@ class QuerysetProxy:
through_model = await model_cls.objects.get(**rel_kwargs) through_model = await model_cls.objects.get(**rel_kwargs)
await through_model.update(**kwargs) await through_model.update(**kwargs)
async def delete_through_instance(self, child: "Model") -> None: async def delete_through_instance(self, child: "T") -> None:
""" """
Removes through model instance from the database for m2m relations. Removes through model instance from the database for m2m relations.
@ -254,7 +261,7 @@ class QuerysetProxy:
) )
return await queryset.delete(**kwargs) # type: ignore return await queryset.delete(**kwargs) # type: ignore
async def first(self, **kwargs: Any) -> "Model": async def first(self, **kwargs: Any) -> "T":
""" """
Gets the first row from the db ordered by primary key column ascending. Gets the first row from the db ordered by primary key column ascending.
@ -272,7 +279,7 @@ class QuerysetProxy:
self._register_related(first) self._register_related(first)
return first return first
async def get(self, **kwargs: Any) -> "Model": async def get(self, **kwargs: Any) -> "T":
""" """
Get's the first row from the db meeting the criteria set by kwargs. Get's the first row from the db meeting the criteria set by kwargs.
@ -296,7 +303,7 @@ class QuerysetProxy:
self._register_related(get) self._register_related(get)
return get return get
async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003 async def all(self, **kwargs: Any) -> List[Optional["T"]]: # noqa: A003
""" """
Returns all rows from a database for given model for set filter options. Returns all rows from a database for given model for set filter options.
@ -318,7 +325,7 @@ class QuerysetProxy:
self._register_related(all_items) self._register_related(all_items)
return all_items return all_items
async def create(self, **kwargs: Any) -> "Model": async def create(self, **kwargs: Any) -> "T":
""" """
Creates the model instance, saves it in a database and returns the updates model Creates the model instance, saves it in a database and returns the updates model
(with pk populated if not passed and autoincrement is set). (with pk populated if not passed and autoincrement is set).
@ -375,7 +382,7 @@ class QuerysetProxy:
) )
return len(children) return len(children)
async def get_or_create(self, **kwargs: Any) -> "Model": async def get_or_create(self, **kwargs: Any) -> "T":
""" """
Combination of create and get methods. Combination of create and get methods.
@ -393,7 +400,7 @@ class QuerysetProxy:
except ormar.NoMatch: except ormar.NoMatch:
return await self.create(**kwargs) return await self.create(**kwargs)
async def update_or_create(self, **kwargs: Any) -> "Model": async def update_or_create(self, **kwargs: Any) -> "T":
""" """
Updates the model, or in case there is no match in database creates a new one. Updates the model, or in case there is no match in database creates a new one.
@ -412,7 +419,9 @@ class QuerysetProxy:
model = await self.queryset.get(pk=kwargs[pk_name]) model = await self.queryset.get(pk=kwargs[pk_name])
return await model.update(**kwargs) return await model.update(**kwargs)
def filter(self, *args: Any, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001 def filter( # noqa: A003, A001
self, *args: Any, **kwargs: Any
) -> "QuerysetProxy[T]":
""" """
Allows you to filter by any `Model` attribute/field Allows you to filter by any `Model` attribute/field
as well as to fetch instances, with a filter across an FK relationship. as well as to fetch instances, with a filter across an FK relationship.
@ -443,9 +452,13 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.filter(*args, **kwargs) queryset = self.queryset.filter(*args, **kwargs)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(
relation=self.relation, type_=self.type_, to=self.to, qryset=queryset
)
def exclude(self, *args: Any, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001 def exclude(
self, *args: Any, **kwargs: Any
) -> "QuerysetProxy[T]": # noqa: A003, A001
""" """
Works exactly the same as filter and all modifiers (suffixes) are the same, Works exactly the same as filter and all modifiers (suffixes) are the same,
but returns a *not* condition. but returns a *not* condition.
@ -467,9 +480,37 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.exclude(*args, **kwargs) queryset = self.queryset.exclude(*args, **kwargs)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(
relation=self.relation, type_=self.type_, to=self.to, qryset=queryset
)
def select_related(self, related: Union[List, str]) -> "QuerysetProxy": def select_all(self, follow: bool = False) -> "QuerysetProxy[T]":
"""
By default adds only directly related models.
If follow=True is set it adds also related models of related models.
To not get stuck in an infinite loop as related models also keep a relation
to parent model visited models set is kept.
That way already visited models that are nested are loaded, but the load do not
follow them inside. So Model A -> Model B -> Model C -> Model A -> Model X
will load second Model A but will never follow into Model X.
Nested relations of those kind need to be loaded manually.
:param follow: flag to trigger deep save -
by default only directly related models are saved
with follow=True also related models of related models are saved
:type follow: bool
:return: reloaded Model
:rtype: Model
"""
queryset = self.queryset.select_all(follow=follow)
return self.__class__(
relation=self.relation, type_=self.type_, to=self.to, qryset=queryset
)
def select_related(self, related: Union[List, str]) -> "QuerysetProxy[T]":
""" """
Allows to prefetch related models during the same query. Allows to prefetch related models during the same query.
@ -489,9 +530,11 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.select_related(related) queryset = self.queryset.select_related(related)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(
relation=self.relation, type_=self.type_, to=self.to, qryset=queryset
)
def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy": def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy[T]":
""" """
Allows to prefetch related models during query - but opposite to Allows to prefetch related models during query - but opposite to
`select_related` each subsequent model is fetched in a separate database query. `select_related` each subsequent model is fetched in a separate database query.
@ -512,9 +555,11 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.prefetch_related(related) queryset = self.queryset.prefetch_related(related)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(
relation=self.relation, type_=self.type_, to=self.to, qryset=queryset
)
def paginate(self, page: int, page_size: int = 20) -> "QuerysetProxy": def paginate(self, page: int, page_size: int = 20) -> "QuerysetProxy[T]":
""" """
You can paginate the result which is a combination of offset and limit clauses. You can paginate the result which is a combination of offset and limit clauses.
Limit is set to page size and offset is set to (page-1) * page_size. Limit is set to page size and offset is set to (page-1) * page_size.
@ -529,9 +574,11 @@ class QuerysetProxy:
:rtype: QuerySet :rtype: QuerySet
""" """
queryset = self.queryset.paginate(page=page, page_size=page_size) queryset = self.queryset.paginate(page=page, page_size=page_size)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(
relation=self.relation, type_=self.type_, to=self.to, qryset=queryset
)
def limit(self, limit_count: int) -> "QuerysetProxy": def limit(self, limit_count: int) -> "QuerysetProxy[T]":
""" """
You can limit the results to desired number of parent models. You can limit the results to desired number of parent models.
@ -543,9 +590,11 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.limit(limit_count) queryset = self.queryset.limit(limit_count)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(
relation=self.relation, type_=self.type_, to=self.to, qryset=queryset
)
def offset(self, offset: int) -> "QuerysetProxy": def offset(self, offset: int) -> "QuerysetProxy[T]":
""" """
You can also offset the results by desired number of main models. You can also offset the results by desired number of main models.
@ -557,9 +606,11 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.offset(offset) queryset = self.queryset.offset(offset)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(
relation=self.relation, type_=self.type_, to=self.to, qryset=queryset
)
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy": def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy[T]":
""" """
With `fields()` you can select subset of model columns to limit the data load. With `fields()` you can select subset of model columns to limit the data load.
@ -605,9 +656,13 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.fields(columns) queryset = self.queryset.fields(columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(
relation=self.relation, type_=self.type_, to=self.to, qryset=queryset
)
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy": def exclude_fields(
self, columns: Union[List, str, Set, Dict]
) -> "QuerysetProxy[T]":
""" """
With `exclude_fields()` you can select subset of model columns that will With `exclude_fields()` you can select subset of model columns that will
be excluded to limit the data load. be excluded to limit the data load.
@ -637,9 +692,11 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.exclude_fields(columns=columns) queryset = self.queryset.exclude_fields(columns=columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(
relation=self.relation, type_=self.type_, to=self.to, qryset=queryset
)
def order_by(self, columns: Union[List, str]) -> "QuerysetProxy": def order_by(self, columns: Union[List, str]) -> "QuerysetProxy[T]":
""" """
With `order_by()` you can order the results from database based on your With `order_by()` you can order the results from database based on your
choice of fields. choice of fields.
@ -674,4 +731,6 @@ class QuerysetProxy:
:rtype: QuerysetProxy :rtype: QuerysetProxy
""" """
queryset = self.queryset.order_by(columns) queryset = self.queryset.order_by(columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset) return self.__class__(
relation=self.relation, type_=self.type_, to=self.to, qryset=queryset
)

View File

@ -1,5 +1,15 @@
from enum import Enum from enum import Enum
from typing import List, Optional, Set, TYPE_CHECKING, Type, Union from typing import (
Generic,
List,
Optional,
Set,
TYPE_CHECKING,
Type,
TypeVar,
Union,
cast,
)
import ormar # noqa I100 import ormar # noqa I100
from ormar.exceptions import RelationshipInstanceError # noqa I100 from ormar.exceptions import RelationshipInstanceError # noqa I100
@ -7,7 +17,9 @@ from ormar.relations.relation_proxy import RelationProxy
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.relations import RelationsManager from ormar.relations import RelationsManager
from ormar.models import Model, NewBaseModel from ormar.models import Model, NewBaseModel, T
else:
T = TypeVar("T", bound="Model")
class RelationType(Enum): class RelationType(Enum):
@ -25,7 +37,7 @@ class RelationType(Enum):
THROUGH = 4 THROUGH = 4
class Relation: class Relation(Generic[T]):
""" """
Keeps related Models and handles adding/removing of the children. Keeps related Models and handles adding/removing of the children.
""" """
@ -35,7 +47,7 @@ class Relation:
manager: "RelationsManager", manager: "RelationsManager",
type_: RelationType, type_: RelationType,
field_name: str, field_name: str,
to: Type["Model"], to: Type["T"],
through: Type["Model"] = None, through: Type["Model"] = None,
) -> None: ) -> None:
""" """
@ -59,11 +71,11 @@ class Relation:
self._owner: "Model" = manager.owner self._owner: "Model" = manager.owner
self._type: RelationType = type_ self._type: RelationType = type_
self._to_remove: Set = set() self._to_remove: Set = set()
self.to: Type["Model"] = to self.to: Type["T"] = to
self._through = through self._through = through
self.field_name: str = field_name self.field_name: str = field_name
self.related_models: Optional[Union[RelationProxy, "Model"]] = ( self.related_models: Optional[Union[RelationProxy, "Model"]] = (
RelationProxy(relation=self, type_=type_, field_name=field_name) RelationProxy(relation=self, type_=type_, to=to, field_name=field_name)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None else None
) )
@ -73,7 +85,8 @@ class Relation:
self.related_models = None self.related_models = None
self._owner.__dict__[self.field_name] = None self._owner.__dict__[self.field_name] = None
elif self.related_models is not None: elif self.related_models is not None:
self.related_models._clear() related_models = cast("RelationProxy", self.related_models)
related_models._clear()
self._owner.__dict__[self.field_name] = None self._owner.__dict__[self.field_name] = None
@property @property
@ -94,6 +107,7 @@ class Relation:
self.related_models = RelationProxy( self.related_models = RelationProxy(
relation=self, relation=self,
type_=self._type, type_=self._type,
to=self.to,
field_name=self.field_name, field_name=self.field_name,
data_=cleaned_data, data_=cleaned_data,
) )

View File

@ -16,7 +16,7 @@ class RelationsManager:
def __init__( def __init__(
self, self,
related_fields: List[Type["ForeignKeyField"]] = None, related_fields: List["ForeignKeyField"] = None,
owner: Optional["Model"] = None, owner: Optional["Model"] = None,
) -> None: ) -> None:
self.owner = proxy(owner) self.owner = proxy(owner)
@ -57,7 +57,7 @@ class RelationsManager:
return None # pragma nocover return None # pragma nocover
@staticmethod @staticmethod
def add(parent: "Model", child: "Model", field: Type["ForeignKeyField"],) -> None: def add(parent: "Model", child: "Model", field: "ForeignKeyField",) -> None:
""" """
Adds relation on both sides -> meaning on both child and parent models. Adds relation on both sides -> meaning on both child and parent models.
One side of the relation is always weakref proxy to avoid circular refs. One side of the relation is always weakref proxy to avoid circular refs.
@ -138,12 +138,12 @@ class RelationsManager:
return relation return relation
return None return None
def _get_relation_type(self, field: Type["BaseField"]) -> RelationType: def _get_relation_type(self, field: "BaseField") -> RelationType:
""" """
Returns type of the relation declared on a field. Returns type of the relation declared on a field.
:param field: field with relation declaration :param field: field with relation declaration
:type field: Type[BaseField] :type field: BaseField
:return: type of the relation defined on field :return: type of the relation defined on field
:rtype: RelationType :rtype: RelationType
""" """
@ -153,13 +153,13 @@ class RelationsManager:
return RelationType.THROUGH return RelationType.THROUGH
return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE
def _add_relation(self, field: Type["BaseField"]) -> None: def _add_relation(self, field: "BaseField") -> None:
""" """
Registers relation in the manager. Registers relation in the manager.
Adds Relation instance under field.name. Adds Relation instance under field.name.
:param field: field with relation declaration :param field: field with relation declaration
:type field: Type[BaseField] :type field: BaseField
""" """
self._relations[field.name] = Relation( self._relations[field.name] = Relation(
manager=self, manager=self,

View File

@ -1,4 +1,4 @@
from typing import Any, Optional, TYPE_CHECKING from typing import Any, Generic, Optional, TYPE_CHECKING, Type, TypeVar
import ormar import ormar
from ormar.exceptions import NoMatch, RelationshipInstanceError from ormar.exceptions import NoMatch, RelationshipInstanceError
@ -6,11 +6,14 @@ from ormar.relations.querysetproxy import QuerysetProxy
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model, RelationType from ormar import Model, RelationType
from ormar.models import T
from ormar.relations import Relation from ormar.relations import Relation
from ormar.queryset import QuerySet from ormar.queryset import QuerySet
else:
T = TypeVar("T", bound="Model")
class RelationProxy(list): class RelationProxy(Generic[T], list):
""" """
Proxy of the Relation that is a list with special methods. Proxy of the Relation that is a list with special methods.
""" """
@ -19,16 +22,17 @@ class RelationProxy(list):
self, self,
relation: "Relation", relation: "Relation",
type_: "RelationType", type_: "RelationType",
to: Type["T"],
field_name: str, field_name: str,
data_: Any = None, data_: Any = None,
) -> None: ) -> None:
super().__init__(data_ or ()) super().__init__(data_ or ())
self.relation: "Relation" = relation self.relation: "Relation[T]" = relation
self.type_: "RelationType" = type_ self.type_: "RelationType" = type_
self.field_name = field_name self.field_name = field_name
self._owner: "Model" = self.relation.manager.owner self._owner: "Model" = self.relation.manager.owner
self.queryset_proxy: QuerysetProxy = QuerysetProxy( self.queryset_proxy: QuerysetProxy[T] = QuerysetProxy[T](
relation=self.relation, type_=type_ relation=self.relation, to=to, type_=type_
) )
self._related_field_name: Optional[str] = None self._related_field_name: Optional[str] = None
@ -48,6 +52,9 @@ class RelationProxy(list):
return self._related_field_name return self._related_field_name
def __getitem__(self, item: Any) -> "T": # type: ignore
return super().__getitem__(item)
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:
""" """
Since some QuerySetProxy methods overwrite builtin list methods we Since some QuerySetProxy methods overwrite builtin list methods we
@ -107,7 +114,7 @@ class RelationProxy(list):
"You cannot query relationships from unsaved model." "You cannot query relationships from unsaved model."
) )
def _set_queryset(self) -> "QuerySet": def _set_queryset(self) -> "QuerySet[T]":
""" """
Creates new QuerySet with relation model and pre filters it with currents Creates new QuerySet with relation model and pre filters it with currents
parent model primary key, so all queries by definition are already related parent model primary key, so all queries by definition are already related
@ -131,7 +138,7 @@ class RelationProxy(list):
return queryset return queryset
async def remove( # type: ignore async def remove( # type: ignore
self, item: "Model", keep_reversed: bool = True self, item: "T", keep_reversed: bool = True
) -> None: ) -> None:
""" """
Removes the related from relation with parent. Removes the related from relation with parent.
@ -182,7 +189,7 @@ class RelationProxy(list):
relation_name=self.field_name, relation_name=self.field_name,
) )
async def add(self, item: "Model", **kwargs: Any) -> None: async def add(self, item: "T", **kwargs: Any) -> None:
""" """
Adds child model to relation. Adds child model to relation.

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Tuple, Type from typing import TYPE_CHECKING, Tuple
from weakref import proxy from weakref import proxy
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
@ -8,7 +8,7 @@ if TYPE_CHECKING: # pragma no cover
def get_relations_sides_and_names( def get_relations_sides_and_names(
to_field: Type[ForeignKeyField], parent: "Model", child: "Model", to_field: ForeignKeyField, parent: "Model", child: "Model",
) -> Tuple["Model", "Model", str, str]: ) -> Tuple["Model", "Model", str, str]:
""" """
Determines the names of child and parent relations names, as well as Determines the names of child and parent relations names, as well as

View File

@ -193,8 +193,8 @@ def test_field_redefining_in_concrete_models():
created_date: str = ormar.String(max_length=200, name="creation_date") created_date: str = ormar.String(max_length=200, name="creation_date")
changed_field = RedefinedField.Meta.model_fields["created_date"] changed_field = RedefinedField.Meta.model_fields["created_date"]
assert changed_field.default is None assert changed_field.ormar_default is None
assert changed_field.alias == "creation_date" assert changed_field.get_alias() == "creation_date"
assert any(x.name == "creation_date" for x in RedefinedField.Meta.table.columns) assert any(x.name == "creation_date" for x in RedefinedField.Meta.table.columns)
assert isinstance( assert isinstance(
RedefinedField.Meta.table.columns["creation_date"].type, sa.sql.sqltypes.String, RedefinedField.Meta.table.columns["creation_date"].type, sa.sql.sqltypes.String,

View File

@ -64,8 +64,10 @@ def test_field_redefining():
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
created_date: datetime.datetime = ormar.DateTime(name="creation_date") created_date: datetime.datetime = ormar.DateTime(name="creation_date")
assert RedefinedField.Meta.model_fields["created_date"].default is None assert RedefinedField.Meta.model_fields["created_date"].ormar_default is None
assert RedefinedField.Meta.model_fields["created_date"].alias == "creation_date" assert (
RedefinedField.Meta.model_fields["created_date"].get_alias() == "creation_date"
)
assert any(x.name == "creation_date" for x in RedefinedField.Meta.table.columns) assert any(x.name == "creation_date" for x in RedefinedField.Meta.table.columns)
@ -87,8 +89,10 @@ def test_field_redefining_in_second_raises_error():
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
created_date: str = ormar.String(max_length=200, name="creation_date") created_date: str = ormar.String(max_length=200, name="creation_date")
assert RedefinedField2.Meta.model_fields["created_date"].default is None assert RedefinedField2.Meta.model_fields["created_date"].ormar_default is None
assert RedefinedField2.Meta.model_fields["created_date"].alias == "creation_date" assert (
RedefinedField2.Meta.model_fields["created_date"].get_alias() == "creation_date"
)
assert any(x.name == "creation_date" for x in RedefinedField2.Meta.table.columns) assert any(x.name == "creation_date" for x in RedefinedField2.Meta.table.columns)
assert isinstance( assert isinstance(
RedefinedField2.Meta.table.columns["creation_date"].type, RedefinedField2.Meta.table.columns["creation_date"].type,

View File

@ -86,6 +86,11 @@ async def test_load_all_fk_rel():
assert hq.companies[0].name == "Banzai" assert hq.companies[0].name == "Banzai"
assert hq.companies[0].founded == 1988 assert hq.companies[0].founded == 1988
hq2 = await HQ.objects.select_all().get(name="Main")
assert hq2.companies[0] == company
assert hq2.companies[0].name == "Banzai"
assert hq2.companies[0].founded == 1988
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_all_many_to_many(): async def test_load_all_many_to_many():
@ -106,6 +111,12 @@ async def test_load_all_many_to_many():
assert hq.nicks[1] == nick2 assert hq.nicks[1] == nick2
assert hq.nicks[1].name == "Bazinga20" assert hq.nicks[1].name == "Bazinga20"
hq2 = await HQ.objects.select_all().get(name="Main")
assert hq2.nicks[0] == nick1
assert hq2.nicks[0].name == "BazingaO"
assert hq2.nicks[1] == nick2
assert hq2.nicks[1].name == "Bazinga20"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_all_with_order(): async def test_load_all_with_order():
@ -130,6 +141,16 @@ async def test_load_all_with_order():
assert hq.nicks[0] == nick1 assert hq.nicks[0] == nick1
assert hq.nicks[1] == nick2 assert hq.nicks[1] == nick2
hq2 = (
await HQ.objects.select_all().order_by("-nicks__name").get(name="Main")
)
assert hq2.nicks[0] == nick2
assert hq2.nicks[1] == nick1
hq3 = await HQ.objects.select_all().get(name="Main")
assert hq3.nicks[0] == nick1
assert hq3.nicks[1] == nick2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_loading_reversed_relation(): async def test_loading_reversed_relation():
@ -143,6 +164,9 @@ async def test_loading_reversed_relation():
assert company.hq == hq assert company.hq == hq
company2 = await Company.objects.select_all().get(name="Banzai")
assert company2.hq == hq
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_loading_nested(): async def test_loading_nested():
@ -174,11 +198,43 @@ async def test_loading_nested():
assert hq.nicks[1].level.name == "Low" assert hq.nicks[1].level.name == "Low"
assert hq.nicks[1].level.language.name == "English" assert hq.nicks[1].level.language.name == "English"
hq2 = await HQ.objects.select_all(follow=True).get(name="Main")
assert hq2.nicks[0] == nick1
assert hq2.nicks[0].name == "BazingaO"
assert hq2.nicks[0].level.name == "High"
assert hq2.nicks[0].level.language.name == "English"
assert hq2.nicks[1] == nick2
assert hq2.nicks[1].name == "Bazinga20"
assert hq2.nicks[1].level.name == "Low"
assert hq2.nicks[1].level.language.name == "English"
hq5 = await HQ.objects.select_all().get(name="Main")
assert len(hq5.nicks) == 2
await hq5.nicks.select_all(follow=True).all()
assert hq5.nicks[0] == nick1
assert hq5.nicks[0].name == "BazingaO"
assert hq5.nicks[0].level.name == "High"
assert hq5.nicks[0].level.language.name == "English"
assert hq5.nicks[1] == nick2
assert hq5.nicks[1].name == "Bazinga20"
assert hq5.nicks[1].level.name == "Low"
assert hq5.nicks[1].level.language.name == "English"
await hq.load_all(follow=True, exclude="nicks__level__language") await hq.load_all(follow=True, exclude="nicks__level__language")
assert len(hq.nicks) == 2 assert len(hq.nicks) == 2
assert hq.nicks[0].level.language is None assert hq.nicks[0].level.language is None
assert hq.nicks[1].level.language is None assert hq.nicks[1].level.language is None
hq3 = (
await HQ.objects.select_all(follow=True)
.exclude_fields("nicks__level__language")
.get(name="Main")
)
assert len(hq3.nicks) == 2
assert hq3.nicks[0].level.language is None
assert hq3.nicks[1].level.language is None
await hq.load_all(follow=True, exclude="nicks__level__language__level") await hq.load_all(follow=True, exclude="nicks__level__language__level")
assert len(hq.nicks) == 2 assert len(hq.nicks) == 2
assert hq.nicks[0].level.language is not None assert hq.nicks[0].level.language is not None

View File

@ -29,7 +29,7 @@ class ExampleModel(Model):
test_string: str = ormar.String(max_length=250) test_string: str = ormar.String(max_length=250)
test_text: str = ormar.Text(default="") test_text: str = ormar.Text(default="")
test_bool: bool = ormar.Boolean(nullable=False) test_bool: bool = ormar.Boolean(nullable=False)
test_float: ormar.Float() = None # type: ignore test_float = ormar.Float(nullable=True)
test_datetime = ormar.DateTime(default=datetime.datetime.now) test_datetime = ormar.DateTime(default=datetime.datetime.now)
test_date = ormar.Date(default=datetime.date.today) test_date = ormar.Date(default=datetime.date.today)
test_time = ormar.Time(default=datetime.time) test_time = ormar.Time(default=datetime.time)

View File

@ -120,9 +120,9 @@ async def create_test_database():
def test_model_class(): def test_model_class():
assert list(User.Meta.model_fields.keys()) == ["id", "name"] assert list(User.Meta.model_fields.keys()) == ["id", "name"]
assert issubclass(User.Meta.model_fields["id"], pydantic.fields.FieldInfo) assert issubclass(User.Meta.model_fields["id"].__class__, pydantic.fields.FieldInfo)
assert User.Meta.model_fields["id"].primary_key is True assert User.Meta.model_fields["id"].primary_key is True
assert issubclass(User.Meta.model_fields["name"], pydantic.fields.FieldInfo) assert isinstance(User.Meta.model_fields["name"], pydantic.fields.FieldInfo)
assert User.Meta.model_fields["name"].max_length == 100 assert User.Meta.model_fields["name"].max_length == 100
assert isinstance(User.Meta.table, sqlalchemy.Table) assert isinstance(User.Meta.table, sqlalchemy.Table)

View File

@ -50,7 +50,7 @@ class AliasTest(ormar.Model):
id: int = ormar.Integer(name="alias_id", primary_key=True) id: int = ormar.Integer(name="alias_id", primary_key=True)
name: str = ormar.String(name="alias_name", max_length=100) name: str = ormar.String(name="alias_name", max_length=100)
nested: str = ormar.ForeignKey(AliasNested, name="nested_alias") nested = ormar.ForeignKey(AliasNested, name="nested_alias")
class Toy(ormar.Model): class Toy(ormar.Model):

101
tests/test_types.py Normal file
View File

@ -0,0 +1,101 @@
from typing import Any, Optional, TYPE_CHECKING
import databases
import pytest
import sqlalchemy
import ormar
from ormar.relations.querysetproxy import QuerysetProxy
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
class BaseMeta(ormar.ModelMeta):
metadata = metadata
database = database
class Publisher(ormar.Model):
class Meta(BaseMeta):
tablename = "publishers"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Author(ormar.Model):
class Meta(BaseMeta):
tablename = "authors"
order_by = ["-name"]
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
publishers = ormar.ManyToMany(Publisher)
class Book(ormar.Model):
class Meta(BaseMeta):
tablename = "books"
order_by = ["year", "-ranking"]
id: int = ormar.Integer(primary_key=True)
author = ormar.ForeignKey(Author)
title: str = ormar.String(max_length=100)
year: int = ormar.Integer(nullable=True)
ranking: int = ormar.Integer(nullable=True)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
def assert_type(book: Book):
print(book)
@pytest.mark.asyncio
async def test_types() -> None:
async with database:
query = Book.objects
publisher = await Publisher(name="Test publisher").save()
author = await Author.objects.create(name="Test Author")
await author.publishers.add(publisher)
author2 = await Author.objects.select_related("publishers").get()
publishers = author2.publishers
publisher2 = await Publisher.objects.select_related("authors").get()
authors = publisher2.authors
assert authors[0] == author
for author in authors:
pass
# if TYPE_CHECKING: # pragma: no cover
# reveal_type(author) # iter of relation proxy
book = await Book.objects.create(title="Test", author=author)
book2 = await Book.objects.select_related("author").get()
books = await Book.objects.select_related("author").all()
author_books = await author.books.all()
assert book.author.name == "Test Author"
assert book2.author.name == "Test Author"
# if TYPE_CHECKING: # pragma: no cover
# reveal_type(publisher) # model method
# reveal_type(publishers) # many to many
# reveal_type(publishers[0]) # item in m2m list
# # getting relation without __getattribute__
# reveal_type(authors) # reverse many to many # TODO: wrong
# reveal_type(book2) # queryset get
# reveal_type(books) # queryset all
# reveal_type(book) # queryset - create
# reveal_type(query) # queryset itself
# reveal_type(book.author) # fk
# reveal_type(author.books) # reverse fk relation proxy # TODO: wrong
# reveal_type(author) # another test for queryset get different model
# reveal_type(book.author.name) # field on related model
# reveal_type(author_books) # querysetproxy result for fk # TODO: wrong
# reveal_type(author_books[0]) # item in qs proxy for fk # TODO: wrong
assert_type(book)