added mypy checks and some typehint changes to conform

This commit is contained in:
collerek
2020-09-29 14:05:08 +02:00
parent 6d56ea5e30
commit 3caa87057e
23 changed files with 274 additions and 202 deletions

BIN
.coverage

Binary file not shown.

1
.gitignore vendored
View File

@ -1,6 +1,7 @@
p38venv p38venv
.idea .idea
.pytest_cache .pytest_cache
.mypy_cache
*.pyc *.pyc
*.log *.log
test.db test.db

5
mypy.ini Normal file
View File

@ -0,0 +1,5 @@
[mypy]
python_version = 3.8
[mypy-sqlalchemy.*]
ignore_missing_imports = True

View File

@ -1,7 +1,8 @@
from typing import Any, List, Optional, TYPE_CHECKING, Union from typing import Any, List, Optional, TYPE_CHECKING, Union, Type
import sqlalchemy import sqlalchemy
from pydantic import Field, typing from pydantic import Field, typing
from pydantic.fields import FieldInfo
from ormar import ModelDefinitionError # noqa I101 from ormar import ModelDefinitionError # noqa I101
@ -15,6 +16,7 @@ class BaseField:
column_type: sqlalchemy.Column column_type: sqlalchemy.Column
constraints: List = [] constraints: List = []
name: str
primary_key: bool primary_key: bool
autoincrement: bool autoincrement: bool
@ -24,12 +26,14 @@ class BaseField:
pydantic_only: bool pydantic_only: bool
virtual: bool = False virtual: bool = False
choices: typing.Sequence choices: typing.Sequence
to: Type["Model"]
through: Type["Model"]
default: Any default: Any
server_default: Any server_default: Any
@classmethod @classmethod
def default_value(cls) -> Optional[Field]: def default_value(cls) -> Optional[FieldInfo]:
if cls.is_auto_primary_key(): if cls.is_auto_primary_key():
return Field(default=None) return Field(default=None)
if cls.has_default(): if cls.has_default():

View File

@ -1,4 +1,4 @@
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Type, Union from typing import Any, List, Optional, TYPE_CHECKING, Type, Union, Generator
import sqlalchemy import sqlalchemy
@ -7,7 +7,7 @@ from ormar.exceptions import RelationshipInstanceError
from ormar.fields.base import BaseField from ormar.fields.base import BaseField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models import Model from ormar.models import Model, NewBaseModel
def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
@ -32,7 +32,7 @@ def ForeignKey( # noqa CFQ002
virtual: bool = False, virtual: bool = False,
onupdate: str = None, onupdate: str = None,
ondelete: str = None, ondelete: str = None,
) -> Type[object]: ) -> Type["ForeignKeyField"]:
fk_string = to.Meta.tablename + "." + to.Meta.pkname fk_string = to.Meta.tablename + "." + to.Meta.pkname
to_field = to.__fields__[to.Meta.pkname] to_field = to.__fields__[to.Meta.pkname]
namespace = dict( namespace = dict(
@ -65,7 +65,7 @@ class ForeignKeyField(BaseField):
virtual: bool virtual: bool
@classmethod @classmethod
def __get_validators__(cls) -> Callable: def __get_validators__(cls) -> Generator:
yield cls.validate yield cls.validate
@classmethod @classmethod
@ -75,8 +75,8 @@ class ForeignKeyField(BaseField):
@classmethod @classmethod
def _extract_model_from_sequence( def _extract_model_from_sequence(
cls, value: List, child: "Model", to_register: bool cls, value: List, child: "Model", to_register: bool
) -> Union["Model", List["Model"]]: ) -> List["Model"]:
return [cls.expand_relationship(val, child, to_register) for val in value] return [cls.expand_relationship(val, child, to_register) for val in value] # type: ignore
@classmethod @classmethod
def _register_existing_model( def _register_existing_model(
@ -120,7 +120,7 @@ class ForeignKeyField(BaseField):
@classmethod @classmethod
def expand_relationship( def expand_relationship(
cls, value: Any, child: "Model", to_register: bool = True cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True
) -> Optional[Union["Model", List["Model"]]]: ) -> Optional[Union["Model", List["Model"]]]:
if value is None: if value is None:
return None if not cls.virtual else [] return None if not cls.virtual else []
@ -131,7 +131,7 @@ class ForeignKeyField(BaseField):
"list": cls._extract_model_from_sequence, "list": cls._extract_model_from_sequence,
} }
model = constructors.get( model = constructors.get( # type: ignore
value.__class__.__name__, cls._construct_model_from_pk value.__class__.__name__, cls._construct_model_from_pk
)(value, child, to_register) )(value, child, to_register)
return model return model

View File

@ -15,7 +15,7 @@ def ManyToMany(
unique: bool = False, unique: bool = False,
related_name: str = None, related_name: str = None,
virtual: bool = False, virtual: bool = False,
) -> Type[object]: ) -> Type["ManyToManyField"]:
to_field = to.__fields__[to.Meta.pkname] to_field = to.__fields__[to.Meta.pkname]
namespace = dict( namespace = dict(
to=to, to=to,

View File

@ -18,10 +18,10 @@ def is_field_nullable(
class ModelFieldFactory: class ModelFieldFactory:
_bases = BaseField _bases: Any = BaseField
_type = None _type: Any = None
def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: # type: ignore
cls.validate(**kwargs) cls.validate(**kwargs)
default = kwargs.pop("default", None) default = kwargs.pop("default", None)
@ -58,7 +58,7 @@ class String(ModelFieldFactory):
_bases = (pydantic.ConstrainedStr, BaseField) _bases = (pydantic.ConstrainedStr, BaseField)
_type = str _type = str
def __new__( # noqa CFQ002 def __new__( # type: ignore # noqa CFQ002
cls, cls,
*, *,
allow_blank: bool = False, allow_blank: bool = False,
@ -68,7 +68,7 @@ class String(ModelFieldFactory):
curtail_length: int = None, curtail_length: int = None,
regex: str = None, regex: str = None,
**kwargs: Any **kwargs: Any
) -> Type[str]: ) -> Type[BaseField]: # type: ignore
kwargs = { kwargs = {
**kwargs, **kwargs,
**{ **{
@ -96,14 +96,14 @@ class Integer(ModelFieldFactory):
_bases = (pydantic.ConstrainedInt, BaseField) _bases = (pydantic.ConstrainedInt, BaseField)
_type = int _type = int
def __new__( def __new__( # type: ignore
cls, cls,
*, *,
minimum: int = None, minimum: int = None,
maximum: int = None, maximum: int = None,
multiple_of: int = None, multiple_of: int = None,
**kwargs: Any **kwargs: Any
) -> Type[int]: ) -> Type[BaseField]:
autoincrement = kwargs.pop("autoincrement", None) autoincrement = kwargs.pop("autoincrement", None)
autoincrement = ( autoincrement = (
autoincrement autoincrement
@ -131,9 +131,9 @@ class Text(ModelFieldFactory):
_bases = (pydantic.ConstrainedStr, BaseField) _bases = (pydantic.ConstrainedStr, BaseField)
_type = str _type = str
def __new__( def __new__( # type: ignore
cls, *, allow_blank: bool = False, strip_whitespace: bool = False, **kwargs: Any cls, *, allow_blank: bool = False, strip_whitespace: bool = False, **kwargs: Any
) -> Type[str]: ) -> Type[BaseField]:
kwargs = { kwargs = {
**kwargs, **kwargs,
**{ **{
@ -153,14 +153,14 @@ class Float(ModelFieldFactory):
_bases = (pydantic.ConstrainedFloat, BaseField) _bases = (pydantic.ConstrainedFloat, BaseField)
_type = float _type = float
def __new__( def __new__( # type: ignore
cls, cls,
*, *,
minimum: float = None, minimum: float = None,
maximum: float = None, maximum: float = None,
multiple_of: int = None, multiple_of: int = None,
**kwargs: Any **kwargs: Any
) -> Type[int]: ) -> Type[BaseField]:
kwargs = { kwargs = {
**kwargs, **kwargs,
**{ **{
@ -236,7 +236,7 @@ class Decimal(ModelFieldFactory):
_bases = (pydantic.ConstrainedDecimal, BaseField) _bases = (pydantic.ConstrainedDecimal, BaseField)
_type = decimal.Decimal _type = decimal.Decimal
def __new__( # noqa CFQ002 def __new__( # type: ignore # noqa CFQ002
cls, cls,
*, *,
minimum: float = None, minimum: float = None,
@ -247,7 +247,7 @@ class Decimal(ModelFieldFactory):
max_digits: int = None, max_digits: int = None,
decimal_places: int = None, decimal_places: int = None,
**kwargs: Any **kwargs: Any
) -> Type[decimal.Decimal]: ) -> Type[BaseField]:
kwargs = { kwargs = {
**kwargs, **kwargs,
**{ **{

View File

@ -28,15 +28,19 @@ class ModelMeta:
database: databases.Database database: databases.Database
columns: List[sqlalchemy.Column] columns: List[sqlalchemy.Column]
pkname: str pkname: str
model_fields: Dict[str, Union[BaseField, ForeignKey]] model_fields: Dict[
str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]]
]
alias_manager: AliasManager alias_manager: AliasManager
def register_relation_on_build(table_name: str, field: ForeignKey) -> None: def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None:
alias_manager.add_relation_type(field.to.Meta.tablename, table_name) alias_manager.add_relation_type(field.to.Meta.tablename, table_name)
def register_many_to_many_relation_on_build(table_name: str, field: ManyToMany) -> None: def register_many_to_many_relation_on_build(
table_name: str, field: Type[ManyToManyField]
) -> None:
alias_manager.add_relation_type(field.through.Meta.tablename, table_name) alias_manager.add_relation_type(field.through.Meta.tablename, table_name)
alias_manager.add_relation_type( alias_manager.add_relation_type(
field.through.Meta.tablename, field.to.Meta.tablename field.through.Meta.tablename, field.to.Meta.tablename
@ -106,7 +110,7 @@ def create_pydantic_field(
) -> None: ) -> None:
model_field.through.__fields__[field_name] = ModelField( model_field.through.__fields__[field_name] = ModelField(
name=field_name, name=field_name,
type_=Optional[model], type_=model,
model_config=model.__config__, model_config=model.__config__,
required=False, required=False,
class_validators={}, class_validators={},
@ -130,7 +134,7 @@ def create_and_append_m2m_fk(
def check_pk_column_validity( def check_pk_column_validity(
field_name: str, field: BaseField, pkname: str field_name: str, field: BaseField, pkname: Optional[str]
) -> Optional[str]: ) -> Optional[str]:
if pkname is not None: if pkname is not None:
raise ModelDefinitionError("Only one primary key column is allowed.") raise ModelDefinitionError("Only one primary key column is allowed.")
@ -218,6 +222,7 @@ def populate_meta_tablename_columns_and_pk(
) -> Type["Model"]: ) -> Type["Model"]:
tablename = name.lower() + "s" tablename = name.lower() + "s"
new_model.Meta.tablename = new_model.Meta.tablename or tablename new_model.Meta.tablename = new_model.Meta.tablename or tablename
pkname: Optional[str]
if hasattr(new_model.Meta, "columns"): if hasattr(new_model.Meta, "columns"):
columns = new_model.Meta.table.columns columns = new_model.Meta.table.columns
@ -226,12 +231,13 @@ def populate_meta_tablename_columns_and_pk(
pkname, columns = sqlalchemy_columns_from_model_fields( pkname, columns = sqlalchemy_columns_from_model_fields(
new_model.Meta.model_fields, new_model.Meta.tablename new_model.Meta.model_fields, new_model.Meta.tablename
) )
if pkname is None:
raise ModelDefinitionError("Table has to have a primary key.")
new_model.Meta.columns = columns new_model.Meta.columns = columns
new_model.Meta.pkname = pkname new_model.Meta.pkname = pkname
if not new_model.Meta.pkname:
raise ModelDefinitionError("Table has to have a primary key.")
return new_model return new_model
@ -253,8 +259,8 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
return Config return Config
def check_if_field_has_choices(field: BaseField) -> bool: def check_if_field_has_choices(field: Type[BaseField]) -> bool:
return hasattr(field, "choices") and field.choices return hasattr(field, "choices") and bool(field.choices)
def model_initialized_and_has_model_fields(model: Type["Model"]) -> bool: def model_initialized_and_has_model_fields(model: Type["Model"]) -> bool:
@ -287,7 +293,7 @@ def populate_choices_validators( # noqa CCR001
class ModelMetaclass(pydantic.main.ModelMetaclass): class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: def __new__(mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict) -> "ModelMetaclass": # type: ignore
attrs["Config"] = get_pydantic_base_orm_config() attrs["Config"] = get_pydantic_base_orm_config()
attrs["__name__"] = name attrs["__name__"] = name
attrs = extract_annotations_and_default_vals(attrs, bases) attrs = extract_annotations_and_default_vals(attrs, bases)
@ -306,7 +312,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
field_name = new_model.Meta.pkname field_name = new_model.Meta.pkname
field = Integer(name=field_name, primary_key=True) field = Integer(name=field_name, primary_key=True)
attrs["__annotations__"][field_name] = field attrs["__annotations__"][field_name] = field
populate_default_pydantic_field_value(field, field_name, attrs) populate_default_pydantic_field_value(field, field_name, attrs) # type: ignore
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs

View File

@ -1,5 +1,5 @@
import itertools import itertools
from typing import Any, List, Tuple, Union from typing import Any, List, Dict, Optional
import sqlalchemy import sqlalchemy
from databases.backends.postgres import Record from databases.backends.postgres import Record
@ -9,8 +9,8 @@ from ormar.fields.many_to_many import ManyToManyField
from ormar.models import NewBaseModel # noqa I100 from ormar.models import NewBaseModel # noqa I100
def group_related_list(list_: List) -> dict: def group_related_list(list_: List) -> Dict:
test_dict = dict() test_dict: Dict[str, Any] = dict()
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0]) grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
for key, group in grouped: for key, group in grouped:
group_list = list(group) group_list = list(group)
@ -34,9 +34,9 @@ class Model(NewBaseModel):
select_related: List = None, select_related: List = None,
related_models: Any = None, related_models: Any = None,
previous_table: str = None, previous_table: str = None,
) -> Union["Model", Tuple["Model", dict]]: ) -> Optional["Model"]:
item = {} item: Dict[str, Any] = {}
select_related = select_related or [] select_related = select_related or []
related_models = related_models or [] related_models = related_models or []
if select_related: if select_related:
@ -52,9 +52,12 @@ class Model(NewBaseModel):
previous_table previous_table
].through.Meta.tablename ].through.Meta.tablename
if previous_table:
table_prefix = cls.Meta.alias_manager.resolve_relation_join( table_prefix = cls.Meta.alias_manager.resolve_relation_join(
previous_table, cls.Meta.table.name previous_table, cls.Meta.table.name
) )
else:
table_prefix = ''
previous_table = cls.Meta.table.name previous_table = cls.Meta.table.name
item = cls.populate_nested_models_from_row( item = cls.populate_nested_models_from_row(
@ -106,7 +109,7 @@ class Model(NewBaseModel):
async def save(self) -> "Model": async def save(self) -> "Model":
self_fields = self._extract_model_db_fields() self_fields = self._extract_model_db_fields()
if not self.pk and self.Meta.model_fields.get(self.Meta.pkname).autoincrement: if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
self_fields.pop(self.Meta.pkname, None) self_fields.pop(self.Meta.pkname, None)
self_fields = self.objects._populate_default_values(self_fields) self_fields = self.objects._populate_default_values(self_fields)
expr = self.Meta.table.insert() expr = self.Meta.table.insert()
@ -138,5 +141,7 @@ class Model(NewBaseModel):
async def load(self) -> "Model": async def load(self) -> "Model":
expr = self.Meta.table.select().where(self.pk_column == self.pk) expr = self.Meta.table.select().where(self.pk_column == self.pk)
row = await self.Meta.database.fetch_one(expr) row = await self.Meta.database.fetch_one(expr)
if not row: # pragma nocover
raise ValueError('Instance was deleted from database and cannot be refreshed')
self.from_dict(dict(row)) self.from_dict(dict(row))
return self return self

View File

@ -1,5 +1,5 @@
import inspect import inspect
from typing import List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union from typing import List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union, Dict
import ormar import ormar
from ormar.exceptions import RelationshipInstanceError from ormar.exceptions import RelationshipInstanceError
@ -9,6 +9,7 @@ from ormar.models.metaclass import ModelMeta
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
from ormar.models import NewBaseModel
Field = TypeVar("Field", bound=BaseField) Field = TypeVar("Field", bound=BaseField)
@ -17,10 +18,10 @@ class ModelTableProxy:
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
Meta: ModelMeta Meta: ModelMeta
def dict(): # noqa A003 def dict(self): # noqa A003
raise NotImplementedError # pragma no cover raise NotImplementedError # pragma no cover
def _extract_own_model_fields(self) -> dict: def _extract_own_model_fields(self) -> Dict:
related_names = self.extract_related_names() related_names = self.extract_related_names()
self_fields = {k: v for k, v in self.dict().items() if k not in related_names} self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
return self_fields return self_fields
@ -34,7 +35,7 @@ class ModelTableProxy:
return self_fields return self_fields
@classmethod @classmethod
def substitute_models_with_pks(cls, model_dict: dict) -> dict: def substitute_models_with_pks(cls, model_dict: Dict) -> Dict:
for field in cls.extract_related_names(): for field in cls.extract_related_names():
field_value = model_dict.get(field, None) field_value = model_dict.get(field, None)
if field_value is not None: if field_value is not None:
@ -80,7 +81,7 @@ class ModelTableProxy:
related_names.add(name) related_names.add(name)
return related_names return related_names
def _extract_model_db_fields(self) -> dict: def _extract_model_db_fields(self) -> Dict:
self_fields = self._extract_own_model_fields() self_fields = self._extract_own_model_fields()
self_fields = { self_fields = {
k: v for k, v in self_fields.items() if k in self.Meta.table.columns k: v for k, v in self_fields.items() if k in self.Meta.table.columns
@ -92,7 +93,9 @@ class ModelTableProxy:
return self_fields return self_fields
@staticmethod @staticmethod
def resolve_relation_name(item: "Model", related: "Model") -> Optional[str]: def resolve_relation_name(
item: Union["NewBaseModel", Type["NewBaseModel"]], related: Union["NewBaseModel", Type["NewBaseModel"]]
) -> str:
for name, field in item.Meta.model_fields.items(): for name, field in item.Meta.model_fields.items():
if issubclass(field, ForeignKeyField): if issubclass(field, ForeignKeyField):
# fastapi is creating clones of response model # fastapi is creating clones of response model
@ -100,11 +103,14 @@ class ModelTableProxy:
# so we need to compare Meta too as this one is copied as is # so we need to compare Meta too as this one is copied as is
if field.to == related.__class__ or field.to.Meta == related.Meta: if field.to == related.__class__ or field.to.Meta == related.Meta:
return name return name
raise ValueError(
f"No relation between {item.get_name()} and {related.get_name()}"
) # pragma nocover
@staticmethod @staticmethod
def resolve_relation_field( def resolve_relation_field(
item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]] item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]]
) -> Type[Field]: ) -> Union[Type[BaseField], Type[ForeignKeyField]]:
name = ModelTableProxy.resolve_relation_name(item, related) name = ModelTableProxy.resolve_relation_name(item, related)
to_field = item.Meta.model_fields.get(name) to_field = item.Meta.model_fields.get(name)
if not to_field: # pragma no cover if not to_field: # pragma no cover
@ -116,7 +122,7 @@ class ModelTableProxy:
@classmethod @classmethod
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
merged_rows = [] merged_rows: List["Model"] = []
for index, model in enumerate(result_rows): for index, model in enumerate(result_rows):
if index > 0 and model.pk == merged_rows[-1].pk: if index > 0 and model.pk == merged_rows[-1].pk:
merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])

View File

@ -3,13 +3,13 @@ import uuid
from typing import ( from typing import (
AbstractSet, AbstractSet,
Any, Any,
Callable,
Dict, Dict,
List, List,
Mapping, Mapping,
Optional, Optional,
TYPE_CHECKING, TYPE_CHECKING,
Type, Type,
TypeVar,
Union, Union,
) )
@ -39,7 +39,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
__slots__ = ("_orm_id", "_orm_saved", "_orm") __slots__ = ("_orm_id", "_orm_saved", "_orm")
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, TypeVar[BaseField]] __model_fields__: Dict[str, Type[BaseField]]
__table__: sqlalchemy.Table __table__: sqlalchemy.Table
__fields__: Dict[str, pydantic.fields.ModelField] __fields__: Dict[str, pydantic.fields.ModelField]
__pydantic_model__: Type[BaseModel] __pydantic_model__: Type[BaseModel]
@ -84,7 +84,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
for k, v in kwargs.items() for k, v in kwargs.items()
} }
values, fields_set, validation_error = pydantic.validate_model(self, kwargs) values, fields_set, validation_error = pydantic.validate_model(self, kwargs) # type: ignore
if validation_error and not pk_only: if validation_error and not pk_only:
raise validation_error raise validation_error
@ -134,13 +134,14 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
) -> Optional[Union["Model", List["Model"]]]: ) -> Optional[Union["Model", List["Model"]]]:
if item in self._orm: if item in self._orm:
return self._orm.get(item) return self._orm.get(item)
return None
def __eq__(self, other: "Model") -> bool: def __eq__(self, other: object) -> bool:
if isinstance(other, NewBaseModel): if isinstance(other, NewBaseModel):
return self.__same__(other) return self.__same__(other)
return super().__eq__(other) # pragma no cover return super().__eq__(other) # pragma no cover
def __same__(self, other: "Model") -> bool: def __same__(self, other: "NewBaseModel") -> bool:
return ( return (
self._orm_id == other._orm_id self._orm_id == other._orm_id
or self.dict() == other.dict() or self.dict() == other.dict()
@ -205,19 +206,19 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
dict_instance[field] = None dict_instance[field] = None
return dict_instance return dict_instance
def from_dict(self, value_dict: Dict) -> "Model": def from_dict(self, value_dict: Dict) -> "NewBaseModel":
for key, value in value_dict.items(): for key, value in value_dict.items():
setattr(self, key, value) setattr(self, key, value)
return self return self
def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]: def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, Dict]:
if not self._is_conversion_to_json_needed(column_name): if not self._is_conversion_to_json_needed(column_name):
return value return value
condition = ( condition = (
isinstance(value, str) if op == "loads" else not isinstance(value, str) isinstance(value, str) if op == "loads" else not isinstance(value, str)
) )
operand = json.loads if op == "loads" else json.dumps operand: Callable[[Any], Any] = json.loads if op == "loads" else json.dumps
if condition: if condition:
try: try:
@ -227,4 +228,4 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
return value return value
def _is_conversion_to_json_needed(self, column_name: str) -> bool: def _is_conversion_to_json_needed(self, column_name: str) -> bool:
return self.Meta.model_fields.get(column_name).__type__ == pydantic.Json return self.Meta.model_fields[column_name].__type__ == pydantic.Json

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
@ -118,7 +118,7 @@ class QueryClause:
def _determine_filter_target_table( def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str] self, related_parts: List[str], select_related: List[str]
) -> Tuple[List[str], str, "Model"]: ) -> Tuple[List[str], str, Type["Model"]]:
table_prefix = "" table_prefix = ""
model_cls = self.model_cls model_cls = self.model_cls
@ -168,9 +168,7 @@ class QueryClause:
return clause return clause
@staticmethod @staticmethod
def _escape_characters_in_clause( def _escape_characters_in_clause(op: str, value: Any) -> Tuple[Any, bool]:
op: str, value: Union[str, "Model"]
) -> Tuple[str, bool]:
has_escaped_character = False has_escaped_character = False
if op not in [ if op not in [

View File

@ -22,8 +22,8 @@ class SqlJoin:
self, self,
used_aliases: List, used_aliases: List,
select_from: sqlalchemy.sql.select, select_from: sqlalchemy.sql.select,
order_bys: List, order_bys: List[sqlalchemy.sql.elements.TextClause],
columns: List, columns: List[sqlalchemy.Column],
) -> None: ) -> None:
self.used_aliases = used_aliases self.used_aliases = used_aliases
self.select_from = select_from self.select_from = select_from

View File

@ -1,8 +1,10 @@
from typing import Optional
import sqlalchemy import sqlalchemy
class LimitQuery: class LimitQuery:
def __init__(self, limit_count: int) -> None: def __init__(self, limit_count: Optional[int]) -> None:
self.limit_count = limit_count self.limit_count = limit_count
def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select:

View File

@ -1,8 +1,10 @@
from typing import Optional
import sqlalchemy import sqlalchemy
class OffsetQuery: class OffsetQuery:
def __init__(self, query_offset: int) -> None: def __init__(self, query_offset: Optional[int]) -> None:
self.query_offset = query_offset self.query_offset = query_offset
def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select:

View File

@ -1,4 +1,4 @@
from typing import List, TYPE_CHECKING, Tuple, Type from typing import List, TYPE_CHECKING, Tuple, Type, Optional
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
@ -18,8 +18,8 @@ class Query:
filter_clauses: List, filter_clauses: List,
exclude_clauses: List, exclude_clauses: List,
select_related: List, select_related: List,
limit_count: int, limit_count: Optional[int],
offset: int, offset: Optional[int],
) -> None: ) -> None:
self.query_offset = offset self.query_offset = offset
self.limit_count = limit_count self.limit_count = limit_count
@ -30,11 +30,11 @@ class Query:
self.model_cls = model_cls self.model_cls = model_cls
self.table = self.model_cls.Meta.table self.table = self.model_cls.Meta.table
self.used_aliases = [] self.used_aliases: List[str] = []
self.select_from = None self.select_from: List[str] = []
self.columns = None self.columns = [sqlalchemy.Column]
self.order_bys = None self.order_bys: List[sqlalchemy.sql.elements.TextClause] = []
@property @property
def prefixed_pk_name(self) -> str: def prefixed_pk_name(self) -> str:
@ -89,7 +89,7 @@ class Query:
return expr return expr
def _reset_query_parameters(self) -> None: def _reset_query_parameters(self) -> None:
self.select_from = None self.select_from = []
self.columns = None self.columns = []
self.order_bys = None self.order_bys = []
self.used_aliases = [] self.used_aliases = []

View File

@ -1,4 +1,4 @@
from typing import Any, List, Mapping, TYPE_CHECKING, Tuple, Type, Union from typing import Any, List, Mapping, TYPE_CHECKING, Type, Union, Optional
import databases import databases
import sqlalchemy import sqlalchemy
@ -13,6 +13,7 @@ from ormar.queryset.query import Query
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
from ormar.models.metaclass import ModelMeta
class QuerySet: class QuerySet:
@ -36,23 +37,36 @@ class QuerySet:
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
return self.__class__(model_cls=owner) return self.__class__(model_cls=owner)
def _process_query_result_rows(self, rows: List[Mapping]) -> List["Model"]: @property
def model_meta(self) -> "ModelMeta":
if not self.model_cls: # pragma nocover
raise ValueError("Model class of QuerySet is not initialized")
return self.model_cls.Meta
@property
def model(self) -> Type["Model"]:
if not self.model_cls: # pragma nocover
raise ValueError("Model class of QuerySet is not initialized")
return self.model_cls
def _process_query_result_rows(self, rows: List) -> List[Optional["Model"]]:
result_rows = [ result_rows = [
self.model_cls.from_row(row, select_related=self._select_related) self.model.from_row(row, select_related=self._select_related)
for row in rows for row in rows
] ]
rows = self.model_cls.merge_instances_list(result_rows) if result_rows:
return rows return self.model.merge_instances_list(result_rows) # type: ignore
return result_rows
def _populate_default_values(self, new_kwargs: dict) -> dict: def _populate_default_values(self, new_kwargs: dict) -> dict:
for field_name, field in self.model_cls.Meta.model_fields.items(): for field_name, field in self.model_meta.model_fields.items():
if field_name not in new_kwargs and field.has_default(): if field_name not in new_kwargs and field.has_default():
new_kwargs[field_name] = field.get_default() new_kwargs[field_name] = field.get_default()
return new_kwargs return new_kwargs
def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict: def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict:
pkname = self.model_cls.Meta.pkname pkname = self.model_meta.pkname
pk = self.model_cls.Meta.model_fields[pkname] pk = self.model_meta.model_fields[pkname]
if new_kwargs.get(pkname, ormar.Undefined) is None and ( if new_kwargs.get(pkname, ormar.Undefined) is None and (
pk.nullable or pk.autoincrement pk.nullable or pk.autoincrement
): ):
@ -60,23 +74,23 @@ class QuerySet:
return new_kwargs return new_kwargs
@staticmethod @staticmethod
def check_single_result_rows_count(rows: List["Model"]) -> None: def check_single_result_rows_count(rows: List[Optional["Model"]]) -> None:
if not rows: if not rows or rows[0] is None:
raise NoMatch() raise NoMatch()
if len(rows) > 1: if len(rows) > 1:
raise MultipleMatches() raise MultipleMatches()
@property @property
def database(self) -> databases.Database: def database(self) -> databases.Database:
return self.model_cls.Meta.database return self.model_meta.database
@property @property
def table(self) -> sqlalchemy.Table: def table(self) -> sqlalchemy.Table:
return self.model_cls.Meta.table return self.model_meta.table
def build_select_expression(self) -> sqlalchemy.sql.select: def build_select_expression(self) -> sqlalchemy.sql.select:
qry = Query( qry = Query(
model_cls=self.model_cls, model_cls=self.model,
select_related=self._select_related, select_related=self._select_related,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses, exclude_clauses=self.exclude_clauses,
@ -89,7 +103,7 @@ class QuerySet:
def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003 def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003
qryclause = QueryClause( qryclause = QueryClause(
model_cls=self.model_cls, model_cls=self.model,
select_related=self._select_related, select_related=self._select_related,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
) )
@ -102,7 +116,7 @@ class QuerySet:
filter_clauses = filter_clauses filter_clauses = filter_clauses
return self.__class__( return self.__class__(
model_cls=self.model_cls, model_cls=self.model,
filter_clauses=filter_clauses, filter_clauses=filter_clauses,
exclude_clauses=exclude_clauses, exclude_clauses=exclude_clauses,
select_related=select_related, select_related=select_related,
@ -113,13 +127,13 @@ class QuerySet:
def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003 def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003
return self.filter(_exclude=True, **kwargs) return self.filter(_exclude=True, **kwargs)
def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet": def select_related(self, related: Union[List, str]) -> "QuerySet":
if not isinstance(related, (list, tuple)): if not isinstance(related, list):
related = [related] related = [related]
related = list(set(list(self._select_related) + related)) related = list(set(list(self._select_related) + related))
return self.__class__( return self.__class__(
model_cls=self.model_cls, model_cls=self.model,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses, exclude_clauses=self.exclude_clauses,
select_related=related, select_related=related,
@ -138,7 +152,7 @@ class QuerySet:
return await self.database.fetch_val(expr) return await self.database.fetch_val(expr)
async def update(self, each: bool = False, **kwargs: Any) -> int: async def update(self, each: bool = False, **kwargs: Any) -> int:
self_fields = self.model_cls.extract_db_own_fields() self_fields = self.model.extract_db_own_fields()
updates = {k: v for k, v in kwargs.items() if k in self_fields} updates = {k: v for k, v in kwargs.items() if k in self_fields}
if not each and not self.filter_clauses: if not each and not self.filter_clauses:
raise QueryDefinitionError( raise QueryDefinitionError(
@ -165,7 +179,7 @@ class QuerySet:
def limit(self, limit_count: int) -> "QuerySet": def limit(self, limit_count: int) -> "QuerySet":
return self.__class__( return self.__class__(
model_cls=self.model_cls, model_cls=self.model,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses, exclude_clauses=self.exclude_clauses,
select_related=self._select_related, select_related=self._select_related,
@ -175,7 +189,7 @@ class QuerySet:
def offset(self, offset: int) -> "QuerySet": def offset(self, offset: int) -> "QuerySet":
return self.__class__( return self.__class__(
model_cls=self.model_cls, model_cls=self.model,
filter_clauses=self.filter_clauses, filter_clauses=self.filter_clauses,
exclude_clauses=self.exclude_clauses, exclude_clauses=self.exclude_clauses,
select_related=self._select_related, select_related=self._select_related,
@ -189,7 +203,7 @@ class QuerySet:
rows = await self.limit(1).all() rows = await self.limit(1).all()
self.check_single_result_rows_count(rows) self.check_single_result_rows_count(rows)
return rows[0] return rows[0] # type: ignore
async def get(self, **kwargs: Any) -> "Model": async def get(self, **kwargs: Any) -> "Model":
if kwargs: if kwargs:
@ -200,9 +214,9 @@ class QuerySet:
expr = expr.limit(2) expr = expr.limit(2)
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
rows = self._process_query_result_rows(rows) processed_rows = self._process_query_result_rows(rows)
self.check_single_result_rows_count(rows) self.check_single_result_rows_count(processed_rows)
return rows[0] return processed_rows[0] # type: ignore
async def get_or_create(self, **kwargs: Any) -> "Model": async def get_or_create(self, **kwargs: Any) -> "Model":
try: try:
@ -211,7 +225,7 @@ class QuerySet:
return await self.create(**kwargs) return await self.create(**kwargs)
async def update_or_create(self, **kwargs: Any) -> "Model": async def update_or_create(self, **kwargs: Any) -> "Model":
pk_name = self.model_cls.Meta.pkname pk_name = self.model_meta.pkname
if "pk" in kwargs: if "pk" in kwargs:
kwargs[pk_name] = kwargs.pop("pk") kwargs[pk_name] = kwargs.pop("pk")
if pk_name not in kwargs or kwargs.get(pk_name) is None: if pk_name not in kwargs or kwargs.get(pk_name) is None:
@ -219,7 +233,7 @@ class QuerySet:
model = await self.get(pk=kwargs[pk_name]) model = await self.get(pk=kwargs[pk_name])
return await model.update(**kwargs) return await model.update(**kwargs)
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003
if kwargs: if kwargs:
return await self.filter(**kwargs).all() return await self.filter(**kwargs).all()
@ -233,20 +247,20 @@ class QuerySet:
new_kwargs = dict(**kwargs) new_kwargs = dict(**kwargs)
new_kwargs = self._remove_pk_from_kwargs(new_kwargs) new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs) new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs) new_kwargs = self._populate_default_values(new_kwargs)
expr = self.table.insert() expr = self.table.insert()
expr = expr.values(**new_kwargs) expr = expr.values(**new_kwargs)
# Execute the insert, and return a new model instance. # Execute the insert, and return a new model instance.
instance = self.model_cls(**kwargs) instance = self.model(**kwargs)
pk = await self.database.execute(expr) pk = await self.database.execute(expr)
pk_name = self.model_cls.Meta.pkname pk_name = self.model_meta.pkname
if pk_name not in kwargs and pk_name in new_kwargs: if pk_name not in kwargs and pk_name in new_kwargs:
instance.pk = new_kwargs[self.model_cls.Meta.pkname] instance.pk = new_kwargs[self.model_meta.pkname]
if pk and isinstance(pk, self.model_cls.pk_type()): if pk and isinstance(pk, self.model.pk_type()):
setattr(instance, self.model_cls.Meta.pkname, pk) setattr(instance, self.model_meta.pkname, pk)
return instance return instance
async def bulk_create(self, objects: List["Model"]) -> None: async def bulk_create(self, objects: List["Model"]) -> None:
@ -254,7 +268,7 @@ class QuerySet:
for objt in objects: for objt in objects:
new_kwargs = objt.dict() new_kwargs = objt.dict()
new_kwargs = self._remove_pk_from_kwargs(new_kwargs) new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs) new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs) new_kwargs = self._populate_default_values(new_kwargs)
ready_objects.append(new_kwargs) ready_objects.append(new_kwargs)
@ -265,10 +279,12 @@ class QuerySet:
self, objects: List["Model"], columns: List[str] = None self, objects: List["Model"], columns: List[str] = None
) -> None: ) -> None:
ready_objects = [] ready_objects = []
pk_name = self.model_cls.Meta.pkname pk_name = self.model_meta.pkname
if not columns: if not columns:
columns = self.model_cls.extract_db_own_fields().union( columns = list(
self.model_cls.extract_related_names() self.model.extract_db_own_fields().union(
self.model.extract_related_names()
)
) )
if pk_name not in columns: if pk_name not in columns:
@ -279,13 +295,13 @@ class QuerySet:
if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None: if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None:
raise QueryDefinitionError( raise QueryDefinitionError(
"You cannot update unsaved objects. " "You cannot update unsaved objects. "
f"{self.model_cls.__name__} has to have {pk_name} filled." f"{self.model.__name__} has to have {pk_name} filled."
) )
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs) new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns} new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
ready_objects.append(new_kwargs) ready_objects.append(new_kwargs)
pk_column = self.model_cls.Meta.table.c.get(pk_name) pk_column = self.model_meta.table.c.get(pk_name)
expr = self.table.update().where(pk_column == bindparam("new_" + pk_name)) expr = self.table.update().where(pk_column == bindparam("new_" + pk_name))
expr = expr.values( expr = expr.values(
**{k: bindparam("new_" + k) for k in columns if k != pk_name} **{k: bindparam("new_" + k) for k in columns if k != pk_name}

View File

@ -1,7 +1,7 @@
import string import string
import uuid import uuid
from random import choices from random import choices
from typing import List from typing import List, Dict
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
@ -14,7 +14,7 @@ def get_table_alias() -> str:
class AliasManager: class AliasManager:
def __init__(self) -> None: def __init__(self) -> None:
self._aliases = dict() self._aliases: Dict[str, str] = dict()
@staticmethod @staticmethod
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:

View File

@ -13,8 +13,8 @@ class QuerysetProxy:
relation: "Relation" relation: "Relation"
def __init__(self, relation: "Relation") -> None: def __init__(self, relation: "Relation") -> None:
self.relation = relation self.relation: Relation = relation
self.queryset = None self.queryset: "QuerySet"
def _assign_child_to_parent(self, child: "Model") -> None: def _assign_child_to_parent(self, child: "Model") -> None:
owner = self.relation._owner owner = self.relation._owner

View File

@ -9,6 +9,7 @@ from ormar.relations.relation_proxy import RelationProxy
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
from ormar.relations import RelationsManager from ormar.relations import RelationsManager
from ormar.models import NewBaseModel
class RelationType(Enum): class RelationType(Enum):
@ -26,17 +27,19 @@ class Relation:
through: Type["Model"] = None, through: Type["Model"] = None,
) -> None: ) -> None:
self.manager = manager self.manager = manager
self._owner = manager.owner self._owner: "Model" = manager.owner
self._type = type_ self._type: RelationType = type_
self.to = to self.to: Type["Model"] = to
self.through = through self.through: Optional[Type["Model"]] = through
self.related_models = ( self.related_models: Optional[Union[RelationProxy, "Model"]] = (
RelationProxy(relation=self) RelationProxy(relation=self)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None else None
) )
def _find_existing(self, child: "Model") -> Optional[int]: def _find_existing(self, child: "Model") -> Optional[int]:
if not isinstance(self.related_models, RelationProxy): # pragma nocover
raise ValueError("Cannot find existing models in parent relation type")
for ind, relation_child in enumerate(self.related_models[:]): for ind, relation_child in enumerate(self.related_models[:]):
try: try:
if relation_child == child: if relation_child == child:
@ -52,7 +55,7 @@ class Relation:
self._owner.__dict__[relation_name] = child self._owner.__dict__[relation_name] = child
else: else:
if self._find_existing(child) is None: if self._find_existing(child) is None:
self.related_models.append(child) self.related_models.append(child) # type: ignore
rel = self._owner.__dict__.get(relation_name, []) rel = self._owner.__dict__.get(relation_name, [])
rel = rel or [] rel = rel or []
if not isinstance(rel, list): if not isinstance(rel, list):
@ -60,19 +63,19 @@ class Relation:
rel.append(child) rel.append(child)
self._owner.__dict__[relation_name] = rel self._owner.__dict__[relation_name] = rel
def remove(self, child: "Model") -> None: def remove(self, child: Union["NewBaseModel", Type["NewBaseModel"]]) -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child) relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY: if self._type == RelationType.PRIMARY:
if self.related_models.__same__(child): if self.related_models == child:
self.related_models = None self.related_models = None
del self._owner.__dict__[relation_name] del self._owner.__dict__[relation_name]
else: else:
position = self._find_existing(child) position = self._find_existing(child)
if position is not None: if position is not None:
self.related_models.pop(position) self.related_models.pop(position) # type: ignore
del self._owner.__dict__[relation_name][position] del self._owner.__dict__[relation_name][position]
def get(self) -> Union[List["Model"], "Model"]: def get(self) -> Optional[Union[List["Model"], "Model"]]:
return self.related_models return self.related_models
def __repr__(self) -> str: # pragma no cover def __repr__(self) -> str: # pragma no cover

View File

@ -1,6 +1,7 @@
from typing import List, Optional, TYPE_CHECKING, Type, Union from typing import List, Optional, TYPE_CHECKING, Type, Union, Dict
from weakref import proxy from weakref import proxy
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.fields.many_to_many import ManyToManyField from ormar.fields.many_to_many import ManyToManyField
from ormar.relations.relation import Relation, RelationType from ormar.relations.relation import Relation, RelationType
@ -11,25 +12,28 @@ from ormar.relations.utils import (
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
from ormar.models import NewBaseModel
class RelationsManager: class RelationsManager:
def __init__( def __init__(
self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None self,
related_fields: List[Type[ForeignKeyField]] = None,
owner: "NewBaseModel" = None,
) -> None: ) -> None:
self.owner = proxy(owner) self.owner = proxy(owner)
self._related_fields = related_fields or [] self._related_fields = related_fields or []
self._related_names = [field.name for field in self._related_fields] self._related_names = [field.name for field in self._related_fields]
self._relations = dict() self._relations: Dict[str, Relation] = dict()
for field in self._related_fields: for field in self._related_fields:
self._add_relation(field) self._add_relation(field)
def _get_relation_type(self, field: Type[ForeignKeyField]) -> RelationType: def _get_relation_type(self, field: Type[BaseField]) -> RelationType:
if issubclass(field, ManyToManyField): if issubclass(field, ManyToManyField):
return RelationType.MULTIPLE return RelationType.MULTIPLE
return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE
def _add_relation(self, field: Type[ForeignKeyField]) -> None: def _add_relation(self, field: Type[BaseField]) -> None:
self._relations[field.name] = Relation( self._relations[field.name] = Relation(
manager=self, manager=self,
type_=self._get_relation_type(field), type_=self._get_relation_type(field),
@ -44,15 +48,17 @@ class RelationsManager:
relation = self._relations.get(name, None) relation = self._relations.get(name, None)
if relation is not None: if relation is not None:
return relation.get() return relation.get()
return None # pragma nocover
def _get(self, name: str) -> Optional[Relation]: def _get(self, name: str) -> Optional[Relation]:
relation = self._relations.get(name, None) relation = self._relations.get(name, None)
if relation is not None: if relation is not None:
return relation return relation
return None
@staticmethod @staticmethod
def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None: def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None:
to_field = child.resolve_relation_field(child, parent) to_field: Type[BaseField] = child.resolve_relation_field(child, parent)
(parent, child, child_name, to_name,) = get_relations_sides_and_names( (parent, child, child_name, to_name,) = get_relations_sides_and_names(
to_field, parent, child, child_name, virtual to_field, parent, child, child_name, virtual
@ -61,18 +67,22 @@ class RelationsManager:
parent_relation = parent._orm._get(child_name) parent_relation = parent._orm._get(child_name)
if not parent_relation: if not parent_relation:
parent_relation = register_missing_relation(parent, child, child_name) parent_relation = register_missing_relation(parent, child, child_name)
parent_relation.add(child) parent_relation.add(child) # type: ignore
child._orm._get(to_name).add(parent)
def remove(self, name: str, child: "Model") -> None: child_relation = child._orm._get(to_name)
if child_relation:
child_relation.add(parent)
def remove(self, name: str, child: Union["NewBaseModel", Type["NewBaseModel"]]) -> None:
relation = self._get(name) relation = self._get(name)
if relation:
relation.remove(child) relation.remove(child)
@staticmethod @staticmethod
def remove_parent(item: "Model", name: Union[str, "Model"]) -> None: def remove_parent(item: Union["NewBaseModel", Type["NewBaseModel"]], name: "Model") -> None:
related_model = name related_model = name
name = item.resolve_relation_name(item, related_model) rel_name = item.resolve_relation_name(item, related_model)
if name in item._orm: if rel_name in item._orm:
relation_name = item.resolve_relation_name(related_model, item) relation_name = item.resolve_relation_name(related_model, item)
item._orm.remove(name, related_model) item._orm.remove(rel_name, related_model)
related_model._orm.remove(relation_name, item) related_model._orm.remove(relation_name, item)

View File

@ -13,22 +13,30 @@ if TYPE_CHECKING: # pragma no cover
class RelationProxy(list): class RelationProxy(list):
def __init__(self, relation: "Relation") -> None: def __init__(self, relation: "Relation") -> None:
super(RelationProxy, self).__init__() super(RelationProxy, self).__init__()
self.relation = relation self.relation: Relation = relation
self._owner = self.relation.manager.owner self._owner: "Model" = self.relation.manager.owner
self.queryset_proxy = QuerysetProxy(relation=self.relation) self.queryset_proxy = QuerysetProxy(relation=self.relation)
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:
if item in ["count", "clear"]: if item in ["count", "clear"]:
if not self.queryset_proxy.queryset: self._initialize_queryset()
self.queryset_proxy.queryset = self._set_queryset()
return getattr(self.queryset_proxy, item) return getattr(self.queryset_proxy, item)
return super().__getattribute__(item) return super().__getattribute__(item)
def __getattr__(self, item: str) -> Any: def __getattr__(self, item: str) -> Any:
if not self.queryset_proxy.queryset: self._initialize_queryset()
self.queryset_proxy.queryset = self._set_queryset()
return getattr(self.queryset_proxy, item) return getattr(self.queryset_proxy, item)
def _initialize_queryset(self) -> None:
if not self._check_if_queryset_is_initialized():
self.queryset_proxy.queryset = self._set_queryset()
def _check_if_queryset_is_initialized(self) -> bool:
return (
hasattr(self.queryset_proxy, "queryset")
and self.queryset_proxy.queryset is not None
)
def _set_queryset(self) -> "QuerySet": def _set_queryset(self) -> "QuerySet":
owner_table = self.relation._owner.Meta.tablename owner_table = self.relation._owner.Meta.tablename
pkname = self.relation._owner.Meta.pkname pkname = self.relation._owner.Meta.pkname
@ -45,10 +53,15 @@ class RelationProxy(list):
) )
return queryset return queryset
async def remove(self, item: "Model") -> None: async def remove(self, item: "Model") -> None: # type: ignore
super().remove(item) super().remove(item)
rel_name = item.resolve_relation_name(item, self._owner) rel_name = item.resolve_relation_name(item, self._owner)
item._orm._get(rel_name).remove(self._owner) relation = item._orm._get(rel_name)
if relation is None: # pragma nocover
raise ValueError(
f"{self._owner.get_name()} does not have relation {rel_name}"
)
relation.remove(self._owner)
if self.relation._type == ormar.RelationType.MULTIPLE: if self.relation._type == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.delete_through_instance(item) await self.queryset_proxy.delete_through_instance(item)

View File

@ -1,8 +1,8 @@
from typing import TYPE_CHECKING, Tuple, Type from typing import TYPE_CHECKING, Tuple, Type, Optional
from weakref import proxy from weakref import proxy
import ormar import ormar
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields import BaseField
from ormar.fields.many_to_many import ManyToManyField from ormar.fields.many_to_many import ManyToManyField
from ormar.relations import Relation from ormar.relations import Relation
@ -12,7 +12,7 @@ if TYPE_CHECKING: # pragma no cover
def register_missing_relation( def register_missing_relation(
parent: "Model", child: "Model", child_name: str parent: "Model", child: "Model", child_name: str
) -> Relation: ) -> Optional[Relation]:
ormar.models.expand_reverse_relationships(child.__class__) ormar.models.expand_reverse_relationships(child.__class__)
name = parent.resolve_relation_name(parent, child) name = parent.resolve_relation_name(parent, child)
field = parent.Meta.model_fields[name] field = parent.Meta.model_fields[name]
@ -22,7 +22,7 @@ def register_missing_relation(
def get_relations_sides_and_names( def get_relations_sides_and_names(
to_field: Type[ForeignKeyField], to_field: Type[BaseField],
parent: "Model", parent: "Model",
child: "Model", child: "Model",
child_name: str, child_name: str,