added mypy checks and some typehint changes to conform
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,6 +1,7 @@
|
|||||||
p38venv
|
p38venv
|
||||||
.idea
|
.idea
|
||||||
.pytest_cache
|
.pytest_cache
|
||||||
|
.mypy_cache
|
||||||
*.pyc
|
*.pyc
|
||||||
*.log
|
*.log
|
||||||
test.db
|
test.db
|
||||||
|
|||||||
5
mypy.ini
Normal file
5
mypy.ini
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
[mypy]
|
||||||
|
python_version = 3.8
|
||||||
|
|
||||||
|
[mypy-sqlalchemy.*]
|
||||||
|
ignore_missing_imports = True
|
||||||
@ -1,7 +1,8 @@
|
|||||||
from typing import Any, List, Optional, TYPE_CHECKING, Union
|
from typing import Any, List, Optional, TYPE_CHECKING, Union, Type
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from pydantic import Field, typing
|
from pydantic import Field, typing
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
|
|
||||||
from ormar import ModelDefinitionError # noqa I101
|
from ormar import ModelDefinitionError # noqa I101
|
||||||
|
|
||||||
@ -15,6 +16,7 @@ class BaseField:
|
|||||||
|
|
||||||
column_type: sqlalchemy.Column
|
column_type: sqlalchemy.Column
|
||||||
constraints: List = []
|
constraints: List = []
|
||||||
|
name: str
|
||||||
|
|
||||||
primary_key: bool
|
primary_key: bool
|
||||||
autoincrement: bool
|
autoincrement: bool
|
||||||
@ -24,12 +26,14 @@ class BaseField:
|
|||||||
pydantic_only: bool
|
pydantic_only: bool
|
||||||
virtual: bool = False
|
virtual: bool = False
|
||||||
choices: typing.Sequence
|
choices: typing.Sequence
|
||||||
|
to: Type["Model"]
|
||||||
|
through: Type["Model"]
|
||||||
|
|
||||||
default: Any
|
default: Any
|
||||||
server_default: Any
|
server_default: Any
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_value(cls) -> Optional[Field]:
|
def default_value(cls) -> Optional[FieldInfo]:
|
||||||
if cls.is_auto_primary_key():
|
if cls.is_auto_primary_key():
|
||||||
return Field(default=None)
|
return Field(default=None)
|
||||||
if cls.has_default():
|
if cls.has_default():
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Type, Union
|
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union, Generator
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
@ -7,7 +7,7 @@ from ormar.exceptions import RelationshipInstanceError
|
|||||||
from ormar.fields.base import BaseField
|
from ormar.fields.base import BaseField
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar.models import Model
|
from ormar.models import Model, NewBaseModel
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
|
def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
|
||||||
@ -23,16 +23,16 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
|
|||||||
|
|
||||||
|
|
||||||
def ForeignKey( # noqa CFQ002
|
def ForeignKey( # noqa CFQ002
|
||||||
to: Type["Model"],
|
to: Type["Model"],
|
||||||
*,
|
*,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
unique: bool = False,
|
unique: bool = False,
|
||||||
nullable: bool = True,
|
nullable: bool = True,
|
||||||
related_name: str = None,
|
related_name: str = None,
|
||||||
virtual: bool = False,
|
virtual: bool = False,
|
||||||
onupdate: str = None,
|
onupdate: str = None,
|
||||||
ondelete: str = None,
|
ondelete: str = None,
|
||||||
) -> Type[object]:
|
) -> Type["ForeignKeyField"]:
|
||||||
fk_string = to.Meta.tablename + "." + to.Meta.pkname
|
fk_string = to.Meta.tablename + "." + to.Meta.pkname
|
||||||
to_field = to.__fields__[to.Meta.pkname]
|
to_field = to.__fields__[to.Meta.pkname]
|
||||||
namespace = dict(
|
namespace = dict(
|
||||||
@ -65,7 +65,7 @@ class ForeignKeyField(BaseField):
|
|||||||
virtual: bool
|
virtual: bool
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_validators__(cls) -> Callable:
|
def __get_validators__(cls) -> Generator:
|
||||||
yield cls.validate
|
yield cls.validate
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -74,13 +74,13 @@ class ForeignKeyField(BaseField):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_model_from_sequence(
|
def _extract_model_from_sequence(
|
||||||
cls, value: List, child: "Model", to_register: bool
|
cls, value: List, child: "Model", to_register: bool
|
||||||
) -> Union["Model", List["Model"]]:
|
) -> List["Model"]:
|
||||||
return [cls.expand_relationship(val, child, to_register) for val in value]
|
return [cls.expand_relationship(val, child, to_register) for val in value] # type: ignore
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _register_existing_model(
|
def _register_existing_model(
|
||||||
cls, value: "Model", child: "Model", to_register: bool
|
cls, value: "Model", child: "Model", to_register: bool
|
||||||
) -> "Model":
|
) -> "Model":
|
||||||
if to_register:
|
if to_register:
|
||||||
cls.register_relation(value, child)
|
cls.register_relation(value, child)
|
||||||
@ -88,7 +88,7 @@ class ForeignKeyField(BaseField):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _construct_model_from_dict(
|
def _construct_model_from_dict(
|
||||||
cls, value: dict, child: "Model", to_register: bool
|
cls, value: dict, child: "Model", to_register: bool
|
||||||
) -> "Model":
|
) -> "Model":
|
||||||
if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname:
|
if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname:
|
||||||
value["__pk_only__"] = True
|
value["__pk_only__"] = True
|
||||||
@ -99,7 +99,7 @@ class ForeignKeyField(BaseField):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _construct_model_from_pk(
|
def _construct_model_from_pk(
|
||||||
cls, value: Any, child: "Model", to_register: bool
|
cls, value: Any, child: "Model", to_register: bool
|
||||||
) -> "Model":
|
) -> "Model":
|
||||||
if not isinstance(value, cls.to.pk_type()):
|
if not isinstance(value, cls.to.pk_type()):
|
||||||
raise RelationshipInstanceError(
|
raise RelationshipInstanceError(
|
||||||
@ -120,7 +120,7 @@ class ForeignKeyField(BaseField):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def expand_relationship(
|
def expand_relationship(
|
||||||
cls, value: Any, child: "Model", to_register: bool = True
|
cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True
|
||||||
) -> Optional[Union["Model", List["Model"]]]:
|
) -> Optional[Union["Model", List["Model"]]]:
|
||||||
if value is None:
|
if value is None:
|
||||||
return None if not cls.virtual else []
|
return None if not cls.virtual else []
|
||||||
@ -131,7 +131,7 @@ class ForeignKeyField(BaseField):
|
|||||||
"list": cls._extract_model_from_sequence,
|
"list": cls._extract_model_from_sequence,
|
||||||
}
|
}
|
||||||
|
|
||||||
model = constructors.get(
|
model = constructors.get( # type: ignore
|
||||||
value.__class__.__name__, cls._construct_model_from_pk
|
value.__class__.__name__, cls._construct_model_from_pk
|
||||||
)(value, child, to_register)
|
)(value, child, to_register)
|
||||||
return model
|
return model
|
||||||
|
|||||||
@ -15,7 +15,7 @@ def ManyToMany(
|
|||||||
unique: bool = False,
|
unique: bool = False,
|
||||||
related_name: str = None,
|
related_name: str = None,
|
||||||
virtual: bool = False,
|
virtual: bool = False,
|
||||||
) -> Type[object]:
|
) -> Type["ManyToManyField"]:
|
||||||
to_field = to.__fields__[to.Meta.pkname]
|
to_field = to.__fields__[to.Meta.pkname]
|
||||||
namespace = dict(
|
namespace = dict(
|
||||||
to=to,
|
to=to,
|
||||||
|
|||||||
@ -18,10 +18,10 @@ def is_field_nullable(
|
|||||||
|
|
||||||
|
|
||||||
class ModelFieldFactory:
|
class ModelFieldFactory:
|
||||||
_bases = BaseField
|
_bases: Any = BaseField
|
||||||
_type = None
|
_type: Any = None
|
||||||
|
|
||||||
def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]:
|
def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: # type: ignore
|
||||||
cls.validate(**kwargs)
|
cls.validate(**kwargs)
|
||||||
|
|
||||||
default = kwargs.pop("default", None)
|
default = kwargs.pop("default", None)
|
||||||
@ -58,7 +58,7 @@ class String(ModelFieldFactory):
|
|||||||
_bases = (pydantic.ConstrainedStr, BaseField)
|
_bases = (pydantic.ConstrainedStr, BaseField)
|
||||||
_type = str
|
_type = str
|
||||||
|
|
||||||
def __new__( # noqa CFQ002
|
def __new__( # type: ignore # noqa CFQ002
|
||||||
cls,
|
cls,
|
||||||
*,
|
*,
|
||||||
allow_blank: bool = False,
|
allow_blank: bool = False,
|
||||||
@ -68,7 +68,7 @@ class String(ModelFieldFactory):
|
|||||||
curtail_length: int = None,
|
curtail_length: int = None,
|
||||||
regex: str = None,
|
regex: str = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> Type[str]:
|
) -> Type[BaseField]: # type: ignore
|
||||||
kwargs = {
|
kwargs = {
|
||||||
**kwargs,
|
**kwargs,
|
||||||
**{
|
**{
|
||||||
@ -96,14 +96,14 @@ class Integer(ModelFieldFactory):
|
|||||||
_bases = (pydantic.ConstrainedInt, BaseField)
|
_bases = (pydantic.ConstrainedInt, BaseField)
|
||||||
_type = int
|
_type = int
|
||||||
|
|
||||||
def __new__(
|
def __new__( # type: ignore
|
||||||
cls,
|
cls,
|
||||||
*,
|
*,
|
||||||
minimum: int = None,
|
minimum: int = None,
|
||||||
maximum: int = None,
|
maximum: int = None,
|
||||||
multiple_of: int = None,
|
multiple_of: int = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> Type[int]:
|
) -> Type[BaseField]:
|
||||||
autoincrement = kwargs.pop("autoincrement", None)
|
autoincrement = kwargs.pop("autoincrement", None)
|
||||||
autoincrement = (
|
autoincrement = (
|
||||||
autoincrement
|
autoincrement
|
||||||
@ -131,9 +131,9 @@ class Text(ModelFieldFactory):
|
|||||||
_bases = (pydantic.ConstrainedStr, BaseField)
|
_bases = (pydantic.ConstrainedStr, BaseField)
|
||||||
_type = str
|
_type = str
|
||||||
|
|
||||||
def __new__(
|
def __new__( # type: ignore
|
||||||
cls, *, allow_blank: bool = False, strip_whitespace: bool = False, **kwargs: Any
|
cls, *, allow_blank: bool = False, strip_whitespace: bool = False, **kwargs: Any
|
||||||
) -> Type[str]:
|
) -> Type[BaseField]:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
**kwargs,
|
**kwargs,
|
||||||
**{
|
**{
|
||||||
@ -153,14 +153,14 @@ class Float(ModelFieldFactory):
|
|||||||
_bases = (pydantic.ConstrainedFloat, BaseField)
|
_bases = (pydantic.ConstrainedFloat, BaseField)
|
||||||
_type = float
|
_type = float
|
||||||
|
|
||||||
def __new__(
|
def __new__( # type: ignore
|
||||||
cls,
|
cls,
|
||||||
*,
|
*,
|
||||||
minimum: float = None,
|
minimum: float = None,
|
||||||
maximum: float = None,
|
maximum: float = None,
|
||||||
multiple_of: int = None,
|
multiple_of: int = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> Type[int]:
|
) -> Type[BaseField]:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
**kwargs,
|
**kwargs,
|
||||||
**{
|
**{
|
||||||
@ -236,7 +236,7 @@ class Decimal(ModelFieldFactory):
|
|||||||
_bases = (pydantic.ConstrainedDecimal, BaseField)
|
_bases = (pydantic.ConstrainedDecimal, BaseField)
|
||||||
_type = decimal.Decimal
|
_type = decimal.Decimal
|
||||||
|
|
||||||
def __new__( # noqa CFQ002
|
def __new__( # type: ignore # noqa CFQ002
|
||||||
cls,
|
cls,
|
||||||
*,
|
*,
|
||||||
minimum: float = None,
|
minimum: float = None,
|
||||||
@ -247,7 +247,7 @@ class Decimal(ModelFieldFactory):
|
|||||||
max_digits: int = None,
|
max_digits: int = None,
|
||||||
decimal_places: int = None,
|
decimal_places: int = None,
|
||||||
**kwargs: Any
|
**kwargs: Any
|
||||||
) -> Type[decimal.Decimal]:
|
) -> Type[BaseField]:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
**kwargs,
|
**kwargs,
|
||||||
**{
|
**{
|
||||||
|
|||||||
@ -28,15 +28,19 @@ class ModelMeta:
|
|||||||
database: databases.Database
|
database: databases.Database
|
||||||
columns: List[sqlalchemy.Column]
|
columns: List[sqlalchemy.Column]
|
||||||
pkname: str
|
pkname: str
|
||||||
model_fields: Dict[str, Union[BaseField, ForeignKey]]
|
model_fields: Dict[
|
||||||
|
str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]]
|
||||||
|
]
|
||||||
alias_manager: AliasManager
|
alias_manager: AliasManager
|
||||||
|
|
||||||
|
|
||||||
def register_relation_on_build(table_name: str, field: ForeignKey) -> None:
|
def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None:
|
||||||
alias_manager.add_relation_type(field.to.Meta.tablename, table_name)
|
alias_manager.add_relation_type(field.to.Meta.tablename, table_name)
|
||||||
|
|
||||||
|
|
||||||
def register_many_to_many_relation_on_build(table_name: str, field: ManyToMany) -> None:
|
def register_many_to_many_relation_on_build(
|
||||||
|
table_name: str, field: Type[ManyToManyField]
|
||||||
|
) -> None:
|
||||||
alias_manager.add_relation_type(field.through.Meta.tablename, table_name)
|
alias_manager.add_relation_type(field.through.Meta.tablename, table_name)
|
||||||
alias_manager.add_relation_type(
|
alias_manager.add_relation_type(
|
||||||
field.through.Meta.tablename, field.to.Meta.tablename
|
field.through.Meta.tablename, field.to.Meta.tablename
|
||||||
@ -106,7 +110,7 @@ def create_pydantic_field(
|
|||||||
) -> None:
|
) -> None:
|
||||||
model_field.through.__fields__[field_name] = ModelField(
|
model_field.through.__fields__[field_name] = ModelField(
|
||||||
name=field_name,
|
name=field_name,
|
||||||
type_=Optional[model],
|
type_=model,
|
||||||
model_config=model.__config__,
|
model_config=model.__config__,
|
||||||
required=False,
|
required=False,
|
||||||
class_validators={},
|
class_validators={},
|
||||||
@ -130,7 +134,7 @@ def create_and_append_m2m_fk(
|
|||||||
|
|
||||||
|
|
||||||
def check_pk_column_validity(
|
def check_pk_column_validity(
|
||||||
field_name: str, field: BaseField, pkname: str
|
field_name: str, field: BaseField, pkname: Optional[str]
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
if pkname is not None:
|
if pkname is not None:
|
||||||
raise ModelDefinitionError("Only one primary key column is allowed.")
|
raise ModelDefinitionError("Only one primary key column is allowed.")
|
||||||
@ -218,6 +222,7 @@ def populate_meta_tablename_columns_and_pk(
|
|||||||
) -> Type["Model"]:
|
) -> Type["Model"]:
|
||||||
tablename = name.lower() + "s"
|
tablename = name.lower() + "s"
|
||||||
new_model.Meta.tablename = new_model.Meta.tablename or tablename
|
new_model.Meta.tablename = new_model.Meta.tablename or tablename
|
||||||
|
pkname: Optional[str]
|
||||||
|
|
||||||
if hasattr(new_model.Meta, "columns"):
|
if hasattr(new_model.Meta, "columns"):
|
||||||
columns = new_model.Meta.table.columns
|
columns = new_model.Meta.table.columns
|
||||||
@ -226,12 +231,13 @@ def populate_meta_tablename_columns_and_pk(
|
|||||||
pkname, columns = sqlalchemy_columns_from_model_fields(
|
pkname, columns = sqlalchemy_columns_from_model_fields(
|
||||||
new_model.Meta.model_fields, new_model.Meta.tablename
|
new_model.Meta.model_fields, new_model.Meta.tablename
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if pkname is None:
|
||||||
|
raise ModelDefinitionError("Table has to have a primary key.")
|
||||||
|
|
||||||
new_model.Meta.columns = columns
|
new_model.Meta.columns = columns
|
||||||
new_model.Meta.pkname = pkname
|
new_model.Meta.pkname = pkname
|
||||||
|
|
||||||
if not new_model.Meta.pkname:
|
|
||||||
raise ModelDefinitionError("Table has to have a primary key.")
|
|
||||||
|
|
||||||
return new_model
|
return new_model
|
||||||
|
|
||||||
|
|
||||||
@ -253,8 +259,8 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
|
|||||||
return Config
|
return Config
|
||||||
|
|
||||||
|
|
||||||
def check_if_field_has_choices(field: BaseField) -> bool:
|
def check_if_field_has_choices(field: Type[BaseField]) -> bool:
|
||||||
return hasattr(field, "choices") and field.choices
|
return hasattr(field, "choices") and bool(field.choices)
|
||||||
|
|
||||||
|
|
||||||
def model_initialized_and_has_model_fields(model: Type["Model"]) -> bool:
|
def model_initialized_and_has_model_fields(model: Type["Model"]) -> bool:
|
||||||
@ -287,7 +293,7 @@ def populate_choices_validators( # noqa CCR001
|
|||||||
|
|
||||||
|
|
||||||
class ModelMetaclass(pydantic.main.ModelMetaclass):
|
class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||||
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
|
def __new__(mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict) -> "ModelMetaclass": # type: ignore
|
||||||
attrs["Config"] = get_pydantic_base_orm_config()
|
attrs["Config"] = get_pydantic_base_orm_config()
|
||||||
attrs["__name__"] = name
|
attrs["__name__"] = name
|
||||||
attrs = extract_annotations_and_default_vals(attrs, bases)
|
attrs = extract_annotations_and_default_vals(attrs, bases)
|
||||||
@ -306,7 +312,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
|||||||
field_name = new_model.Meta.pkname
|
field_name = new_model.Meta.pkname
|
||||||
field = Integer(name=field_name, primary_key=True)
|
field = Integer(name=field_name, primary_key=True)
|
||||||
attrs["__annotations__"][field_name] = field
|
attrs["__annotations__"][field_name] = field
|
||||||
populate_default_pydantic_field_value(field, field_name, attrs)
|
populate_default_pydantic_field_value(field, field_name, attrs) # type: ignore
|
||||||
|
|
||||||
new_model = super().__new__( # type: ignore
|
new_model = super().__new__( # type: ignore
|
||||||
mcs, name, bases, attrs
|
mcs, name, bases, attrs
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from typing import Any, List, Tuple, Union
|
from typing import Any, List, Dict, Optional
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from databases.backends.postgres import Record
|
from databases.backends.postgres import Record
|
||||||
@ -9,8 +9,8 @@ from ormar.fields.many_to_many import ManyToManyField
|
|||||||
from ormar.models import NewBaseModel # noqa I100
|
from ormar.models import NewBaseModel # noqa I100
|
||||||
|
|
||||||
|
|
||||||
def group_related_list(list_: List) -> dict:
|
def group_related_list(list_: List) -> Dict:
|
||||||
test_dict = dict()
|
test_dict: Dict[str, Any] = dict()
|
||||||
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
|
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
|
||||||
for key, group in grouped:
|
for key, group in grouped:
|
||||||
group_list = list(group)
|
group_list = list(group)
|
||||||
@ -29,14 +29,14 @@ class Model(NewBaseModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_row(
|
def from_row(
|
||||||
cls,
|
cls,
|
||||||
row: sqlalchemy.engine.ResultProxy,
|
row: sqlalchemy.engine.ResultProxy,
|
||||||
select_related: List = None,
|
select_related: List = None,
|
||||||
related_models: Any = None,
|
related_models: Any = None,
|
||||||
previous_table: str = None,
|
previous_table: str = None,
|
||||||
) -> Union["Model", Tuple["Model", dict]]:
|
) -> Optional["Model"]:
|
||||||
|
|
||||||
item = {}
|
item: Dict[str, Any] = {}
|
||||||
select_related = select_related or []
|
select_related = select_related or []
|
||||||
related_models = related_models or []
|
related_models = related_models or []
|
||||||
if select_related:
|
if select_related:
|
||||||
@ -44,17 +44,20 @@ class Model(NewBaseModel):
|
|||||||
|
|
||||||
# breakpoint()
|
# breakpoint()
|
||||||
if (
|
if (
|
||||||
previous_table
|
previous_table
|
||||||
and previous_table in cls.Meta.model_fields
|
and previous_table in cls.Meta.model_fields
|
||||||
and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField)
|
and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField)
|
||||||
):
|
):
|
||||||
previous_table = cls.Meta.model_fields[
|
previous_table = cls.Meta.model_fields[
|
||||||
previous_table
|
previous_table
|
||||||
].through.Meta.tablename
|
].through.Meta.tablename
|
||||||
|
|
||||||
table_prefix = cls.Meta.alias_manager.resolve_relation_join(
|
if previous_table:
|
||||||
previous_table, cls.Meta.table.name
|
table_prefix = cls.Meta.alias_manager.resolve_relation_join(
|
||||||
)
|
previous_table, cls.Meta.table.name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
table_prefix = ''
|
||||||
previous_table = cls.Meta.table.name
|
previous_table = cls.Meta.table.name
|
||||||
|
|
||||||
item = cls.populate_nested_models_from_row(
|
item = cls.populate_nested_models_from_row(
|
||||||
@ -67,11 +70,11 @@ class Model(NewBaseModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def populate_nested_models_from_row(
|
def populate_nested_models_from_row(
|
||||||
cls,
|
cls,
|
||||||
item: dict,
|
item: dict,
|
||||||
row: sqlalchemy.engine.ResultProxy,
|
row: sqlalchemy.engine.ResultProxy,
|
||||||
related_models: Any,
|
related_models: Any,
|
||||||
previous_table: sqlalchemy.Table,
|
previous_table: sqlalchemy.Table,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
for related in related_models:
|
for related in related_models:
|
||||||
if isinstance(related_models, dict) and related_models[related]:
|
if isinstance(related_models, dict) and related_models[related]:
|
||||||
@ -90,7 +93,7 @@ class Model(NewBaseModel):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_prefixed_table_columns( # noqa CCR001
|
def extract_prefixed_table_columns( # noqa CCR001
|
||||||
cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str
|
cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str
|
||||||
) -> dict:
|
) -> dict:
|
||||||
for column in cls.Meta.table.columns:
|
for column in cls.Meta.table.columns:
|
||||||
if column.name not in item:
|
if column.name not in item:
|
||||||
@ -106,7 +109,7 @@ class Model(NewBaseModel):
|
|||||||
async def save(self) -> "Model":
|
async def save(self) -> "Model":
|
||||||
self_fields = self._extract_model_db_fields()
|
self_fields = self._extract_model_db_fields()
|
||||||
|
|
||||||
if not self.pk and self.Meta.model_fields.get(self.Meta.pkname).autoincrement:
|
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
|
||||||
self_fields.pop(self.Meta.pkname, None)
|
self_fields.pop(self.Meta.pkname, None)
|
||||||
self_fields = self.objects._populate_default_values(self_fields)
|
self_fields = self.objects._populate_default_values(self_fields)
|
||||||
expr = self.Meta.table.insert()
|
expr = self.Meta.table.insert()
|
||||||
@ -138,5 +141,7 @@ class Model(NewBaseModel):
|
|||||||
async def load(self) -> "Model":
|
async def load(self) -> "Model":
|
||||||
expr = self.Meta.table.select().where(self.pk_column == self.pk)
|
expr = self.Meta.table.select().where(self.pk_column == self.pk)
|
||||||
row = await self.Meta.database.fetch_one(expr)
|
row = await self.Meta.database.fetch_one(expr)
|
||||||
|
if not row: # pragma nocover
|
||||||
|
raise ValueError('Instance was deleted from database and cannot be refreshed')
|
||||||
self.from_dict(dict(row))
|
self.from_dict(dict(row))
|
||||||
return self
|
return self
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import inspect
|
import inspect
|
||||||
from typing import List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union
|
from typing import List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union, Dict
|
||||||
|
|
||||||
import ormar
|
import ormar
|
||||||
from ormar.exceptions import RelationshipInstanceError
|
from ormar.exceptions import RelationshipInstanceError
|
||||||
@ -9,6 +9,7 @@ from ormar.models.metaclass import ModelMeta
|
|||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
|
from ormar.models import NewBaseModel
|
||||||
|
|
||||||
Field = TypeVar("Field", bound=BaseField)
|
Field = TypeVar("Field", bound=BaseField)
|
||||||
|
|
||||||
@ -17,10 +18,10 @@ class ModelTableProxy:
|
|||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
Meta: ModelMeta
|
Meta: ModelMeta
|
||||||
|
|
||||||
def dict(): # noqa A003
|
def dict(self): # noqa A003
|
||||||
raise NotImplementedError # pragma no cover
|
raise NotImplementedError # pragma no cover
|
||||||
|
|
||||||
def _extract_own_model_fields(self) -> dict:
|
def _extract_own_model_fields(self) -> Dict:
|
||||||
related_names = self.extract_related_names()
|
related_names = self.extract_related_names()
|
||||||
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
|
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
|
||||||
return self_fields
|
return self_fields
|
||||||
@ -34,7 +35,7 @@ class ModelTableProxy:
|
|||||||
return self_fields
|
return self_fields
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def substitute_models_with_pks(cls, model_dict: dict) -> dict:
|
def substitute_models_with_pks(cls, model_dict: Dict) -> Dict:
|
||||||
for field in cls.extract_related_names():
|
for field in cls.extract_related_names():
|
||||||
field_value = model_dict.get(field, None)
|
field_value = model_dict.get(field, None)
|
||||||
if field_value is not None:
|
if field_value is not None:
|
||||||
@ -80,7 +81,7 @@ class ModelTableProxy:
|
|||||||
related_names.add(name)
|
related_names.add(name)
|
||||||
return related_names
|
return related_names
|
||||||
|
|
||||||
def _extract_model_db_fields(self) -> dict:
|
def _extract_model_db_fields(self) -> Dict:
|
||||||
self_fields = self._extract_own_model_fields()
|
self_fields = self._extract_own_model_fields()
|
||||||
self_fields = {
|
self_fields = {
|
||||||
k: v for k, v in self_fields.items() if k in self.Meta.table.columns
|
k: v for k, v in self_fields.items() if k in self.Meta.table.columns
|
||||||
@ -92,7 +93,9 @@ class ModelTableProxy:
|
|||||||
return self_fields
|
return self_fields
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resolve_relation_name(item: "Model", related: "Model") -> Optional[str]:
|
def resolve_relation_name(
|
||||||
|
item: Union["NewBaseModel", Type["NewBaseModel"]], related: Union["NewBaseModel", Type["NewBaseModel"]]
|
||||||
|
) -> str:
|
||||||
for name, field in item.Meta.model_fields.items():
|
for name, field in item.Meta.model_fields.items():
|
||||||
if issubclass(field, ForeignKeyField):
|
if issubclass(field, ForeignKeyField):
|
||||||
# fastapi is creating clones of response model
|
# fastapi is creating clones of response model
|
||||||
@ -100,11 +103,14 @@ class ModelTableProxy:
|
|||||||
# so we need to compare Meta too as this one is copied as is
|
# so we need to compare Meta too as this one is copied as is
|
||||||
if field.to == related.__class__ or field.to.Meta == related.Meta:
|
if field.to == related.__class__ or field.to.Meta == related.Meta:
|
||||||
return name
|
return name
|
||||||
|
raise ValueError(
|
||||||
|
f"No relation between {item.get_name()} and {related.get_name()}"
|
||||||
|
) # pragma nocover
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resolve_relation_field(
|
def resolve_relation_field(
|
||||||
item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]]
|
item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]]
|
||||||
) -> Type[Field]:
|
) -> Union[Type[BaseField], Type[ForeignKeyField]]:
|
||||||
name = ModelTableProxy.resolve_relation_name(item, related)
|
name = ModelTableProxy.resolve_relation_name(item, related)
|
||||||
to_field = item.Meta.model_fields.get(name)
|
to_field = item.Meta.model_fields.get(name)
|
||||||
if not to_field: # pragma no cover
|
if not to_field: # pragma no cover
|
||||||
@ -116,7 +122,7 @@ class ModelTableProxy:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
|
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
|
||||||
merged_rows = []
|
merged_rows: List["Model"] = []
|
||||||
for index, model in enumerate(result_rows):
|
for index, model in enumerate(result_rows):
|
||||||
if index > 0 and model.pk == merged_rows[-1].pk:
|
if index > 0 and model.pk == merged_rows[-1].pk:
|
||||||
merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
|
merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
|
||||||
|
|||||||
@ -3,13 +3,13 @@ import uuid
|
|||||||
from typing import (
|
from typing import (
|
||||||
AbstractSet,
|
AbstractSet,
|
||||||
Any,
|
Any,
|
||||||
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
__slots__ = ("_orm_id", "_orm_saved", "_orm")
|
__slots__ = ("_orm_id", "_orm_saved", "_orm")
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
__model_fields__: Dict[str, TypeVar[BaseField]]
|
__model_fields__: Dict[str, Type[BaseField]]
|
||||||
__table__: sqlalchemy.Table
|
__table__: sqlalchemy.Table
|
||||||
__fields__: Dict[str, pydantic.fields.ModelField]
|
__fields__: Dict[str, pydantic.fields.ModelField]
|
||||||
__pydantic_model__: Type[BaseModel]
|
__pydantic_model__: Type[BaseModel]
|
||||||
@ -84,7 +84,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
for k, v in kwargs.items()
|
for k, v in kwargs.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
values, fields_set, validation_error = pydantic.validate_model(self, kwargs)
|
values, fields_set, validation_error = pydantic.validate_model(self, kwargs) # type: ignore
|
||||||
if validation_error and not pk_only:
|
if validation_error and not pk_only:
|
||||||
raise validation_error
|
raise validation_error
|
||||||
|
|
||||||
@ -134,13 +134,14 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
) -> Optional[Union["Model", List["Model"]]]:
|
) -> Optional[Union["Model", List["Model"]]]:
|
||||||
if item in self._orm:
|
if item in self._orm:
|
||||||
return self._orm.get(item)
|
return self._orm.get(item)
|
||||||
|
return None
|
||||||
|
|
||||||
def __eq__(self, other: "Model") -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if isinstance(other, NewBaseModel):
|
if isinstance(other, NewBaseModel):
|
||||||
return self.__same__(other)
|
return self.__same__(other)
|
||||||
return super().__eq__(other) # pragma no cover
|
return super().__eq__(other) # pragma no cover
|
||||||
|
|
||||||
def __same__(self, other: "Model") -> bool:
|
def __same__(self, other: "NewBaseModel") -> bool:
|
||||||
return (
|
return (
|
||||||
self._orm_id == other._orm_id
|
self._orm_id == other._orm_id
|
||||||
or self.dict() == other.dict()
|
or self.dict() == other.dict()
|
||||||
@ -205,19 +206,19 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
dict_instance[field] = None
|
dict_instance[field] = None
|
||||||
return dict_instance
|
return dict_instance
|
||||||
|
|
||||||
def from_dict(self, value_dict: Dict) -> "Model":
|
def from_dict(self, value_dict: Dict) -> "NewBaseModel":
|
||||||
for key, value in value_dict.items():
|
for key, value in value_dict.items():
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]:
|
def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, Dict]:
|
||||||
if not self._is_conversion_to_json_needed(column_name):
|
if not self._is_conversion_to_json_needed(column_name):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
condition = (
|
condition = (
|
||||||
isinstance(value, str) if op == "loads" else not isinstance(value, str)
|
isinstance(value, str) if op == "loads" else not isinstance(value, str)
|
||||||
)
|
)
|
||||||
operand = json.loads if op == "loads" else json.dumps
|
operand: Callable[[Any], Any] = json.loads if op == "loads" else json.dumps
|
||||||
|
|
||||||
if condition:
|
if condition:
|
||||||
try:
|
try:
|
||||||
@ -227,4 +228,4 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
def _is_conversion_to_json_needed(self, column_name: str) -> bool:
|
def _is_conversion_to_json_needed(self, column_name: str) -> bool:
|
||||||
return self.Meta.model_fields.get(column_name).__type__ == pydantic.Json
|
return self.Meta.model_fields[column_name].__type__ == pydantic.Json
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
|
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
@ -118,7 +118,7 @@ class QueryClause:
|
|||||||
|
|
||||||
def _determine_filter_target_table(
|
def _determine_filter_target_table(
|
||||||
self, related_parts: List[str], select_related: List[str]
|
self, related_parts: List[str], select_related: List[str]
|
||||||
) -> Tuple[List[str], str, "Model"]:
|
) -> Tuple[List[str], str, Type["Model"]]:
|
||||||
|
|
||||||
table_prefix = ""
|
table_prefix = ""
|
||||||
model_cls = self.model_cls
|
model_cls = self.model_cls
|
||||||
@ -168,9 +168,7 @@ class QueryClause:
|
|||||||
return clause
|
return clause
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _escape_characters_in_clause(
|
def _escape_characters_in_clause(op: str, value: Any) -> Tuple[Any, bool]:
|
||||||
op: str, value: Union[str, "Model"]
|
|
||||||
) -> Tuple[str, bool]:
|
|
||||||
has_escaped_character = False
|
has_escaped_character = False
|
||||||
|
|
||||||
if op not in [
|
if op not in [
|
||||||
|
|||||||
@ -22,8 +22,8 @@ class SqlJoin:
|
|||||||
self,
|
self,
|
||||||
used_aliases: List,
|
used_aliases: List,
|
||||||
select_from: sqlalchemy.sql.select,
|
select_from: sqlalchemy.sql.select,
|
||||||
order_bys: List,
|
order_bys: List[sqlalchemy.sql.elements.TextClause],
|
||||||
columns: List,
|
columns: List[sqlalchemy.Column],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.used_aliases = used_aliases
|
self.used_aliases = used_aliases
|
||||||
self.select_from = select_from
|
self.select_from = select_from
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
|
|
||||||
class LimitQuery:
|
class LimitQuery:
|
||||||
def __init__(self, limit_count: int) -> None:
|
def __init__(self, limit_count: Optional[int]) -> None:
|
||||||
self.limit_count = limit_count
|
self.limit_count = limit_count
|
||||||
|
|
||||||
def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select:
|
def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select:
|
||||||
|
|||||||
@ -1,8 +1,10 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
|
|
||||||
class OffsetQuery:
|
class OffsetQuery:
|
||||||
def __init__(self, query_offset: int) -> None:
|
def __init__(self, query_offset: Optional[int]) -> None:
|
||||||
self.query_offset = query_offset
|
self.query_offset = query_offset
|
||||||
|
|
||||||
def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select:
|
def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select:
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, TYPE_CHECKING, Tuple, Type
|
from typing import List, TYPE_CHECKING, Tuple, Type, Optional
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
@ -18,8 +18,8 @@ class Query:
|
|||||||
filter_clauses: List,
|
filter_clauses: List,
|
||||||
exclude_clauses: List,
|
exclude_clauses: List,
|
||||||
select_related: List,
|
select_related: List,
|
||||||
limit_count: int,
|
limit_count: Optional[int],
|
||||||
offset: int,
|
offset: Optional[int],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.query_offset = offset
|
self.query_offset = offset
|
||||||
self.limit_count = limit_count
|
self.limit_count = limit_count
|
||||||
@ -30,11 +30,11 @@ class Query:
|
|||||||
self.model_cls = model_cls
|
self.model_cls = model_cls
|
||||||
self.table = self.model_cls.Meta.table
|
self.table = self.model_cls.Meta.table
|
||||||
|
|
||||||
self.used_aliases = []
|
self.used_aliases: List[str] = []
|
||||||
|
|
||||||
self.select_from = None
|
self.select_from: List[str] = []
|
||||||
self.columns = None
|
self.columns = [sqlalchemy.Column]
|
||||||
self.order_bys = None
|
self.order_bys: List[sqlalchemy.sql.elements.TextClause] = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prefixed_pk_name(self) -> str:
|
def prefixed_pk_name(self) -> str:
|
||||||
@ -89,7 +89,7 @@ class Query:
|
|||||||
return expr
|
return expr
|
||||||
|
|
||||||
def _reset_query_parameters(self) -> None:
|
def _reset_query_parameters(self) -> None:
|
||||||
self.select_from = None
|
self.select_from = []
|
||||||
self.columns = None
|
self.columns = []
|
||||||
self.order_bys = None
|
self.order_bys = []
|
||||||
self.used_aliases = []
|
self.used_aliases = []
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List, Mapping, TYPE_CHECKING, Tuple, Type, Union
|
from typing import Any, List, Mapping, TYPE_CHECKING, Type, Union, Optional
|
||||||
|
|
||||||
import databases
|
import databases
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
@ -13,17 +13,18 @@ from ormar.queryset.query import Query
|
|||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
|
from ormar.models.metaclass import ModelMeta
|
||||||
|
|
||||||
|
|
||||||
class QuerySet:
|
class QuerySet:
|
||||||
def __init__( # noqa CFQ002
|
def __init__( # noqa CFQ002
|
||||||
self,
|
self,
|
||||||
model_cls: Type["Model"] = None,
|
model_cls: Type["Model"] = None,
|
||||||
filter_clauses: List = None,
|
filter_clauses: List = None,
|
||||||
exclude_clauses: List = None,
|
exclude_clauses: List = None,
|
||||||
select_related: List = None,
|
select_related: List = None,
|
||||||
limit_count: int = None,
|
limit_count: int = None,
|
||||||
offset: int = None,
|
offset: int = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_cls = model_cls
|
self.model_cls = model_cls
|
||||||
self.filter_clauses = [] if filter_clauses is None else filter_clauses
|
self.filter_clauses = [] if filter_clauses is None else filter_clauses
|
||||||
@ -36,47 +37,60 @@ class QuerySet:
|
|||||||
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
|
def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet":
|
||||||
return self.__class__(model_cls=owner)
|
return self.__class__(model_cls=owner)
|
||||||
|
|
||||||
def _process_query_result_rows(self, rows: List[Mapping]) -> List["Model"]:
|
@property
|
||||||
|
def model_meta(self) -> "ModelMeta":
|
||||||
|
if not self.model_cls: # pragma nocover
|
||||||
|
raise ValueError("Model class of QuerySet is not initialized")
|
||||||
|
return self.model_cls.Meta
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> Type["Model"]:
|
||||||
|
if not self.model_cls: # pragma nocover
|
||||||
|
raise ValueError("Model class of QuerySet is not initialized")
|
||||||
|
return self.model_cls
|
||||||
|
|
||||||
|
def _process_query_result_rows(self, rows: List) -> List[Optional["Model"]]:
|
||||||
result_rows = [
|
result_rows = [
|
||||||
self.model_cls.from_row(row, select_related=self._select_related)
|
self.model.from_row(row, select_related=self._select_related)
|
||||||
for row in rows
|
for row in rows
|
||||||
]
|
]
|
||||||
rows = self.model_cls.merge_instances_list(result_rows)
|
if result_rows:
|
||||||
return rows
|
return self.model.merge_instances_list(result_rows) # type: ignore
|
||||||
|
return result_rows
|
||||||
|
|
||||||
def _populate_default_values(self, new_kwargs: dict) -> dict:
|
def _populate_default_values(self, new_kwargs: dict) -> dict:
|
||||||
for field_name, field in self.model_cls.Meta.model_fields.items():
|
for field_name, field in self.model_meta.model_fields.items():
|
||||||
if field_name not in new_kwargs and field.has_default():
|
if field_name not in new_kwargs and field.has_default():
|
||||||
new_kwargs[field_name] = field.get_default()
|
new_kwargs[field_name] = field.get_default()
|
||||||
return new_kwargs
|
return new_kwargs
|
||||||
|
|
||||||
def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict:
|
def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict:
|
||||||
pkname = self.model_cls.Meta.pkname
|
pkname = self.model_meta.pkname
|
||||||
pk = self.model_cls.Meta.model_fields[pkname]
|
pk = self.model_meta.model_fields[pkname]
|
||||||
if new_kwargs.get(pkname, ormar.Undefined) is None and (
|
if new_kwargs.get(pkname, ormar.Undefined) is None and (
|
||||||
pk.nullable or pk.autoincrement
|
pk.nullable or pk.autoincrement
|
||||||
):
|
):
|
||||||
del new_kwargs[pkname]
|
del new_kwargs[pkname]
|
||||||
return new_kwargs
|
return new_kwargs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_single_result_rows_count(rows: List["Model"]) -> None:
|
def check_single_result_rows_count(rows: List[Optional["Model"]]) -> None:
|
||||||
if not rows:
|
if not rows or rows[0] is None:
|
||||||
raise NoMatch()
|
raise NoMatch()
|
||||||
if len(rows) > 1:
|
if len(rows) > 1:
|
||||||
raise MultipleMatches()
|
raise MultipleMatches()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def database(self) -> databases.Database:
|
def database(self) -> databases.Database:
|
||||||
return self.model_cls.Meta.database
|
return self.model_meta.database
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def table(self) -> sqlalchemy.Table:
|
def table(self) -> sqlalchemy.Table:
|
||||||
return self.model_cls.Meta.table
|
return self.model_meta.table
|
||||||
|
|
||||||
def build_select_expression(self) -> sqlalchemy.sql.select:
|
def build_select_expression(self) -> sqlalchemy.sql.select:
|
||||||
qry = Query(
|
qry = Query(
|
||||||
model_cls=self.model_cls,
|
model_cls=self.model,
|
||||||
select_related=self._select_related,
|
select_related=self._select_related,
|
||||||
filter_clauses=self.filter_clauses,
|
filter_clauses=self.filter_clauses,
|
||||||
exclude_clauses=self.exclude_clauses,
|
exclude_clauses=self.exclude_clauses,
|
||||||
@ -89,7 +103,7 @@ class QuerySet:
|
|||||||
|
|
||||||
def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003
|
def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003
|
||||||
qryclause = QueryClause(
|
qryclause = QueryClause(
|
||||||
model_cls=self.model_cls,
|
model_cls=self.model,
|
||||||
select_related=self._select_related,
|
select_related=self._select_related,
|
||||||
filter_clauses=self.filter_clauses,
|
filter_clauses=self.filter_clauses,
|
||||||
)
|
)
|
||||||
@ -102,7 +116,7 @@ class QuerySet:
|
|||||||
filter_clauses = filter_clauses
|
filter_clauses = filter_clauses
|
||||||
|
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
model_cls=self.model_cls,
|
model_cls=self.model,
|
||||||
filter_clauses=filter_clauses,
|
filter_clauses=filter_clauses,
|
||||||
exclude_clauses=exclude_clauses,
|
exclude_clauses=exclude_clauses,
|
||||||
select_related=select_related,
|
select_related=select_related,
|
||||||
@ -113,13 +127,13 @@ class QuerySet:
|
|||||||
def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003
|
def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003
|
||||||
return self.filter(_exclude=True, **kwargs)
|
return self.filter(_exclude=True, **kwargs)
|
||||||
|
|
||||||
def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet":
|
def select_related(self, related: Union[List, str]) -> "QuerySet":
|
||||||
if not isinstance(related, (list, tuple)):
|
if not isinstance(related, list):
|
||||||
related = [related]
|
related = [related]
|
||||||
|
|
||||||
related = list(set(list(self._select_related) + related))
|
related = list(set(list(self._select_related) + related))
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
model_cls=self.model_cls,
|
model_cls=self.model,
|
||||||
filter_clauses=self.filter_clauses,
|
filter_clauses=self.filter_clauses,
|
||||||
exclude_clauses=self.exclude_clauses,
|
exclude_clauses=self.exclude_clauses,
|
||||||
select_related=related,
|
select_related=related,
|
||||||
@ -138,7 +152,7 @@ class QuerySet:
|
|||||||
return await self.database.fetch_val(expr)
|
return await self.database.fetch_val(expr)
|
||||||
|
|
||||||
async def update(self, each: bool = False, **kwargs: Any) -> int:
|
async def update(self, each: bool = False, **kwargs: Any) -> int:
|
||||||
self_fields = self.model_cls.extract_db_own_fields()
|
self_fields = self.model.extract_db_own_fields()
|
||||||
updates = {k: v for k, v in kwargs.items() if k in self_fields}
|
updates = {k: v for k, v in kwargs.items() if k in self_fields}
|
||||||
if not each and not self.filter_clauses:
|
if not each and not self.filter_clauses:
|
||||||
raise QueryDefinitionError(
|
raise QueryDefinitionError(
|
||||||
@ -165,7 +179,7 @@ class QuerySet:
|
|||||||
|
|
||||||
def limit(self, limit_count: int) -> "QuerySet":
|
def limit(self, limit_count: int) -> "QuerySet":
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
model_cls=self.model_cls,
|
model_cls=self.model,
|
||||||
filter_clauses=self.filter_clauses,
|
filter_clauses=self.filter_clauses,
|
||||||
exclude_clauses=self.exclude_clauses,
|
exclude_clauses=self.exclude_clauses,
|
||||||
select_related=self._select_related,
|
select_related=self._select_related,
|
||||||
@ -175,7 +189,7 @@ class QuerySet:
|
|||||||
|
|
||||||
def offset(self, offset: int) -> "QuerySet":
|
def offset(self, offset: int) -> "QuerySet":
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
model_cls=self.model_cls,
|
model_cls=self.model,
|
||||||
filter_clauses=self.filter_clauses,
|
filter_clauses=self.filter_clauses,
|
||||||
exclude_clauses=self.exclude_clauses,
|
exclude_clauses=self.exclude_clauses,
|
||||||
select_related=self._select_related,
|
select_related=self._select_related,
|
||||||
@ -189,7 +203,7 @@ class QuerySet:
|
|||||||
|
|
||||||
rows = await self.limit(1).all()
|
rows = await self.limit(1).all()
|
||||||
self.check_single_result_rows_count(rows)
|
self.check_single_result_rows_count(rows)
|
||||||
return rows[0]
|
return rows[0] # type: ignore
|
||||||
|
|
||||||
async def get(self, **kwargs: Any) -> "Model":
|
async def get(self, **kwargs: Any) -> "Model":
|
||||||
if kwargs:
|
if kwargs:
|
||||||
@ -200,9 +214,9 @@ class QuerySet:
|
|||||||
expr = expr.limit(2)
|
expr = expr.limit(2)
|
||||||
|
|
||||||
rows = await self.database.fetch_all(expr)
|
rows = await self.database.fetch_all(expr)
|
||||||
rows = self._process_query_result_rows(rows)
|
processed_rows = self._process_query_result_rows(rows)
|
||||||
self.check_single_result_rows_count(rows)
|
self.check_single_result_rows_count(processed_rows)
|
||||||
return rows[0]
|
return processed_rows[0] # type: ignore
|
||||||
|
|
||||||
async def get_or_create(self, **kwargs: Any) -> "Model":
|
async def get_or_create(self, **kwargs: Any) -> "Model":
|
||||||
try:
|
try:
|
||||||
@ -211,7 +225,7 @@ class QuerySet:
|
|||||||
return await self.create(**kwargs)
|
return await self.create(**kwargs)
|
||||||
|
|
||||||
async def update_or_create(self, **kwargs: Any) -> "Model":
|
async def update_or_create(self, **kwargs: Any) -> "Model":
|
||||||
pk_name = self.model_cls.Meta.pkname
|
pk_name = self.model_meta.pkname
|
||||||
if "pk" in kwargs:
|
if "pk" in kwargs:
|
||||||
kwargs[pk_name] = kwargs.pop("pk")
|
kwargs[pk_name] = kwargs.pop("pk")
|
||||||
if pk_name not in kwargs or kwargs.get(pk_name) is None:
|
if pk_name not in kwargs or kwargs.get(pk_name) is None:
|
||||||
@ -219,7 +233,7 @@ class QuerySet:
|
|||||||
model = await self.get(pk=kwargs[pk_name])
|
model = await self.get(pk=kwargs[pk_name])
|
||||||
return await model.update(**kwargs)
|
return await model.update(**kwargs)
|
||||||
|
|
||||||
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
|
async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003
|
||||||
if kwargs:
|
if kwargs:
|
||||||
return await self.filter(**kwargs).all()
|
return await self.filter(**kwargs).all()
|
||||||
|
|
||||||
@ -233,20 +247,20 @@ class QuerySet:
|
|||||||
|
|
||||||
new_kwargs = dict(**kwargs)
|
new_kwargs = dict(**kwargs)
|
||||||
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
|
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
|
||||||
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
|
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
|
||||||
new_kwargs = self._populate_default_values(new_kwargs)
|
new_kwargs = self._populate_default_values(new_kwargs)
|
||||||
|
|
||||||
expr = self.table.insert()
|
expr = self.table.insert()
|
||||||
expr = expr.values(**new_kwargs)
|
expr = expr.values(**new_kwargs)
|
||||||
|
|
||||||
# Execute the insert, and return a new model instance.
|
# Execute the insert, and return a new model instance.
|
||||||
instance = self.model_cls(**kwargs)
|
instance = self.model(**kwargs)
|
||||||
pk = await self.database.execute(expr)
|
pk = await self.database.execute(expr)
|
||||||
pk_name = self.model_cls.Meta.pkname
|
pk_name = self.model_meta.pkname
|
||||||
if pk_name not in kwargs and pk_name in new_kwargs:
|
if pk_name not in kwargs and pk_name in new_kwargs:
|
||||||
instance.pk = new_kwargs[self.model_cls.Meta.pkname]
|
instance.pk = new_kwargs[self.model_meta.pkname]
|
||||||
if pk and isinstance(pk, self.model_cls.pk_type()):
|
if pk and isinstance(pk, self.model.pk_type()):
|
||||||
setattr(instance, self.model_cls.Meta.pkname, pk)
|
setattr(instance, self.model_meta.pkname, pk)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
async def bulk_create(self, objects: List["Model"]) -> None:
|
async def bulk_create(self, objects: List["Model"]) -> None:
|
||||||
@ -254,7 +268,7 @@ class QuerySet:
|
|||||||
for objt in objects:
|
for objt in objects:
|
||||||
new_kwargs = objt.dict()
|
new_kwargs = objt.dict()
|
||||||
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
|
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
|
||||||
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
|
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
|
||||||
new_kwargs = self._populate_default_values(new_kwargs)
|
new_kwargs = self._populate_default_values(new_kwargs)
|
||||||
ready_objects.append(new_kwargs)
|
ready_objects.append(new_kwargs)
|
||||||
|
|
||||||
@ -262,13 +276,15 @@ class QuerySet:
|
|||||||
await self.database.execute_many(expr, ready_objects)
|
await self.database.execute_many(expr, ready_objects)
|
||||||
|
|
||||||
async def bulk_update(
|
async def bulk_update(
|
||||||
self, objects: List["Model"], columns: List[str] = None
|
self, objects: List["Model"], columns: List[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
ready_objects = []
|
ready_objects = []
|
||||||
pk_name = self.model_cls.Meta.pkname
|
pk_name = self.model_meta.pkname
|
||||||
if not columns:
|
if not columns:
|
||||||
columns = self.model_cls.extract_db_own_fields().union(
|
columns = list(
|
||||||
self.model_cls.extract_related_names()
|
self.model.extract_db_own_fields().union(
|
||||||
|
self.model.extract_related_names()
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if pk_name not in columns:
|
if pk_name not in columns:
|
||||||
@ -279,13 +295,13 @@ class QuerySet:
|
|||||||
if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None:
|
if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None:
|
||||||
raise QueryDefinitionError(
|
raise QueryDefinitionError(
|
||||||
"You cannot update unsaved objects. "
|
"You cannot update unsaved objects. "
|
||||||
f"{self.model_cls.__name__} has to have {pk_name} filled."
|
f"{self.model.__name__} has to have {pk_name} filled."
|
||||||
)
|
)
|
||||||
new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
|
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
|
||||||
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
|
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
|
||||||
ready_objects.append(new_kwargs)
|
ready_objects.append(new_kwargs)
|
||||||
|
|
||||||
pk_column = self.model_cls.Meta.table.c.get(pk_name)
|
pk_column = self.model_meta.table.c.get(pk_name)
|
||||||
expr = self.table.update().where(pk_column == bindparam("new_" + pk_name))
|
expr = self.table.update().where(pk_column == bindparam("new_" + pk_name))
|
||||||
expr = expr.values(
|
expr = expr.values(
|
||||||
**{k: bindparam("new_" + k) for k in columns if k != pk_name}
|
**{k: bindparam("new_" + k) for k in columns if k != pk_name}
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import string
|
import string
|
||||||
import uuid
|
import uuid
|
||||||
from random import choices
|
from random import choices
|
||||||
from typing import List
|
from typing import List, Dict
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
@ -14,7 +14,7 @@ def get_table_alias() -> str:
|
|||||||
|
|
||||||
class AliasManager:
|
class AliasManager:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._aliases = dict()
|
self._aliases: Dict[str, str] = dict()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
|
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
|
||||||
|
|||||||
@ -13,8 +13,8 @@ class QuerysetProxy:
|
|||||||
relation: "Relation"
|
relation: "Relation"
|
||||||
|
|
||||||
def __init__(self, relation: "Relation") -> None:
|
def __init__(self, relation: "Relation") -> None:
|
||||||
self.relation = relation
|
self.relation: Relation = relation
|
||||||
self.queryset = None
|
self.queryset: "QuerySet"
|
||||||
|
|
||||||
def _assign_child_to_parent(self, child: "Model") -> None:
|
def _assign_child_to_parent(self, child: "Model") -> None:
|
||||||
owner = self.relation._owner
|
owner = self.relation._owner
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from ormar.relations.relation_proxy import RelationProxy
|
|||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
from ormar.relations import RelationsManager
|
from ormar.relations import RelationsManager
|
||||||
|
from ormar.models import NewBaseModel
|
||||||
|
|
||||||
|
|
||||||
class RelationType(Enum):
|
class RelationType(Enum):
|
||||||
@ -19,24 +20,26 @@ class RelationType(Enum):
|
|||||||
|
|
||||||
class Relation:
|
class Relation:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
manager: "RelationsManager",
|
manager: "RelationsManager",
|
||||||
type_: RelationType,
|
type_: RelationType,
|
||||||
to: Type["Model"],
|
to: Type["Model"],
|
||||||
through: Type["Model"] = None,
|
through: Type["Model"] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.manager = manager
|
self.manager = manager
|
||||||
self._owner = manager.owner
|
self._owner: "Model" = manager.owner
|
||||||
self._type = type_
|
self._type: RelationType = type_
|
||||||
self.to = to
|
self.to: Type["Model"] = to
|
||||||
self.through = through
|
self.through: Optional[Type["Model"]] = through
|
||||||
self.related_models = (
|
self.related_models: Optional[Union[RelationProxy, "Model"]] = (
|
||||||
RelationProxy(relation=self)
|
RelationProxy(relation=self)
|
||||||
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
|
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
def _find_existing(self, child: "Model") -> Optional[int]:
|
def _find_existing(self, child: "Model") -> Optional[int]:
|
||||||
|
if not isinstance(self.related_models, RelationProxy): # pragma nocover
|
||||||
|
raise ValueError("Cannot find existing models in parent relation type")
|
||||||
for ind, relation_child in enumerate(self.related_models[:]):
|
for ind, relation_child in enumerate(self.related_models[:]):
|
||||||
try:
|
try:
|
||||||
if relation_child == child:
|
if relation_child == child:
|
||||||
@ -52,7 +55,7 @@ class Relation:
|
|||||||
self._owner.__dict__[relation_name] = child
|
self._owner.__dict__[relation_name] = child
|
||||||
else:
|
else:
|
||||||
if self._find_existing(child) is None:
|
if self._find_existing(child) is None:
|
||||||
self.related_models.append(child)
|
self.related_models.append(child) # type: ignore
|
||||||
rel = self._owner.__dict__.get(relation_name, [])
|
rel = self._owner.__dict__.get(relation_name, [])
|
||||||
rel = rel or []
|
rel = rel or []
|
||||||
if not isinstance(rel, list):
|
if not isinstance(rel, list):
|
||||||
@ -60,19 +63,19 @@ class Relation:
|
|||||||
rel.append(child)
|
rel.append(child)
|
||||||
self._owner.__dict__[relation_name] = rel
|
self._owner.__dict__[relation_name] = rel
|
||||||
|
|
||||||
def remove(self, child: "Model") -> None:
|
def remove(self, child: Union["NewBaseModel", Type["NewBaseModel"]]) -> None:
|
||||||
relation_name = self._owner.resolve_relation_name(self._owner, child)
|
relation_name = self._owner.resolve_relation_name(self._owner, child)
|
||||||
if self._type == RelationType.PRIMARY:
|
if self._type == RelationType.PRIMARY:
|
||||||
if self.related_models.__same__(child):
|
if self.related_models == child:
|
||||||
self.related_models = None
|
self.related_models = None
|
||||||
del self._owner.__dict__[relation_name]
|
del self._owner.__dict__[relation_name]
|
||||||
else:
|
else:
|
||||||
position = self._find_existing(child)
|
position = self._find_existing(child)
|
||||||
if position is not None:
|
if position is not None:
|
||||||
self.related_models.pop(position)
|
self.related_models.pop(position) # type: ignore
|
||||||
del self._owner.__dict__[relation_name][position]
|
del self._owner.__dict__[relation_name][position]
|
||||||
|
|
||||||
def get(self) -> Union[List["Model"], "Model"]:
|
def get(self) -> Optional[Union[List["Model"], "Model"]]:
|
||||||
return self.related_models
|
return self.related_models
|
||||||
|
|
||||||
def __repr__(self) -> str: # pragma no cover
|
def __repr__(self) -> str: # pragma no cover
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from typing import List, Optional, TYPE_CHECKING, Type, Union
|
from typing import List, Optional, TYPE_CHECKING, Type, Union, Dict
|
||||||
from weakref import proxy
|
from weakref import proxy
|
||||||
|
|
||||||
|
from ormar.fields import BaseField
|
||||||
from ormar.fields.foreign_key import ForeignKeyField
|
from ormar.fields.foreign_key import ForeignKeyField
|
||||||
from ormar.fields.many_to_many import ManyToManyField
|
from ormar.fields.many_to_many import ManyToManyField
|
||||||
from ormar.relations.relation import Relation, RelationType
|
from ormar.relations.relation import Relation, RelationType
|
||||||
@ -11,25 +12,28 @@ from ormar.relations.utils import (
|
|||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
|
from ormar.models import NewBaseModel
|
||||||
|
|
||||||
|
|
||||||
class RelationsManager:
|
class RelationsManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None
|
self,
|
||||||
|
related_fields: List[Type[ForeignKeyField]] = None,
|
||||||
|
owner: "NewBaseModel" = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.owner = proxy(owner)
|
self.owner = proxy(owner)
|
||||||
self._related_fields = related_fields or []
|
self._related_fields = related_fields or []
|
||||||
self._related_names = [field.name for field in self._related_fields]
|
self._related_names = [field.name for field in self._related_fields]
|
||||||
self._relations = dict()
|
self._relations: Dict[str, Relation] = dict()
|
||||||
for field in self._related_fields:
|
for field in self._related_fields:
|
||||||
self._add_relation(field)
|
self._add_relation(field)
|
||||||
|
|
||||||
def _get_relation_type(self, field: Type[ForeignKeyField]) -> RelationType:
|
def _get_relation_type(self, field: Type[BaseField]) -> RelationType:
|
||||||
if issubclass(field, ManyToManyField):
|
if issubclass(field, ManyToManyField):
|
||||||
return RelationType.MULTIPLE
|
return RelationType.MULTIPLE
|
||||||
return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE
|
return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE
|
||||||
|
|
||||||
def _add_relation(self, field: Type[ForeignKeyField]) -> None:
|
def _add_relation(self, field: Type[BaseField]) -> None:
|
||||||
self._relations[field.name] = Relation(
|
self._relations[field.name] = Relation(
|
||||||
manager=self,
|
manager=self,
|
||||||
type_=self._get_relation_type(field),
|
type_=self._get_relation_type(field),
|
||||||
@ -44,15 +48,17 @@ class RelationsManager:
|
|||||||
relation = self._relations.get(name, None)
|
relation = self._relations.get(name, None)
|
||||||
if relation is not None:
|
if relation is not None:
|
||||||
return relation.get()
|
return relation.get()
|
||||||
|
return None # pragma nocover
|
||||||
|
|
||||||
def _get(self, name: str) -> Optional[Relation]:
|
def _get(self, name: str) -> Optional[Relation]:
|
||||||
relation = self._relations.get(name, None)
|
relation = self._relations.get(name, None)
|
||||||
if relation is not None:
|
if relation is not None:
|
||||||
return relation
|
return relation
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None:
|
def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None:
|
||||||
to_field = child.resolve_relation_field(child, parent)
|
to_field: Type[BaseField] = child.resolve_relation_field(child, parent)
|
||||||
|
|
||||||
(parent, child, child_name, to_name,) = get_relations_sides_and_names(
|
(parent, child, child_name, to_name,) = get_relations_sides_and_names(
|
||||||
to_field, parent, child, child_name, virtual
|
to_field, parent, child, child_name, virtual
|
||||||
@ -61,18 +67,22 @@ class RelationsManager:
|
|||||||
parent_relation = parent._orm._get(child_name)
|
parent_relation = parent._orm._get(child_name)
|
||||||
if not parent_relation:
|
if not parent_relation:
|
||||||
parent_relation = register_missing_relation(parent, child, child_name)
|
parent_relation = register_missing_relation(parent, child, child_name)
|
||||||
parent_relation.add(child)
|
parent_relation.add(child) # type: ignore
|
||||||
child._orm._get(to_name).add(parent)
|
|
||||||
|
|
||||||
def remove(self, name: str, child: "Model") -> None:
|
child_relation = child._orm._get(to_name)
|
||||||
|
if child_relation:
|
||||||
|
child_relation.add(parent)
|
||||||
|
|
||||||
|
def remove(self, name: str, child: Union["NewBaseModel", Type["NewBaseModel"]]) -> None:
|
||||||
relation = self._get(name)
|
relation = self._get(name)
|
||||||
relation.remove(child)
|
if relation:
|
||||||
|
relation.remove(child)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def remove_parent(item: "Model", name: Union[str, "Model"]) -> None:
|
def remove_parent(item: Union["NewBaseModel", Type["NewBaseModel"]], name: "Model") -> None:
|
||||||
related_model = name
|
related_model = name
|
||||||
name = item.resolve_relation_name(item, related_model)
|
rel_name = item.resolve_relation_name(item, related_model)
|
||||||
if name in item._orm:
|
if rel_name in item._orm:
|
||||||
relation_name = item.resolve_relation_name(related_model, item)
|
relation_name = item.resolve_relation_name(related_model, item)
|
||||||
item._orm.remove(name, related_model)
|
item._orm.remove(rel_name, related_model)
|
||||||
related_model._orm.remove(relation_name, item)
|
related_model._orm.remove(relation_name, item)
|
||||||
|
|||||||
@ -13,22 +13,30 @@ if TYPE_CHECKING: # pragma no cover
|
|||||||
class RelationProxy(list):
|
class RelationProxy(list):
|
||||||
def __init__(self, relation: "Relation") -> None:
|
def __init__(self, relation: "Relation") -> None:
|
||||||
super(RelationProxy, self).__init__()
|
super(RelationProxy, self).__init__()
|
||||||
self.relation = relation
|
self.relation: Relation = relation
|
||||||
self._owner = self.relation.manager.owner
|
self._owner: "Model" = self.relation.manager.owner
|
||||||
self.queryset_proxy = QuerysetProxy(relation=self.relation)
|
self.queryset_proxy = QuerysetProxy(relation=self.relation)
|
||||||
|
|
||||||
def __getattribute__(self, item: str) -> Any:
|
def __getattribute__(self, item: str) -> Any:
|
||||||
if item in ["count", "clear"]:
|
if item in ["count", "clear"]:
|
||||||
if not self.queryset_proxy.queryset:
|
self._initialize_queryset()
|
||||||
self.queryset_proxy.queryset = self._set_queryset()
|
|
||||||
return getattr(self.queryset_proxy, item)
|
return getattr(self.queryset_proxy, item)
|
||||||
return super().__getattribute__(item)
|
return super().__getattribute__(item)
|
||||||
|
|
||||||
def __getattr__(self, item: str) -> Any:
|
def __getattr__(self, item: str) -> Any:
|
||||||
if not self.queryset_proxy.queryset:
|
self._initialize_queryset()
|
||||||
self.queryset_proxy.queryset = self._set_queryset()
|
|
||||||
return getattr(self.queryset_proxy, item)
|
return getattr(self.queryset_proxy, item)
|
||||||
|
|
||||||
|
def _initialize_queryset(self) -> None:
|
||||||
|
if not self._check_if_queryset_is_initialized():
|
||||||
|
self.queryset_proxy.queryset = self._set_queryset()
|
||||||
|
|
||||||
|
def _check_if_queryset_is_initialized(self) -> bool:
|
||||||
|
return (
|
||||||
|
hasattr(self.queryset_proxy, "queryset")
|
||||||
|
and self.queryset_proxy.queryset is not None
|
||||||
|
)
|
||||||
|
|
||||||
def _set_queryset(self) -> "QuerySet":
|
def _set_queryset(self) -> "QuerySet":
|
||||||
owner_table = self.relation._owner.Meta.tablename
|
owner_table = self.relation._owner.Meta.tablename
|
||||||
pkname = self.relation._owner.Meta.pkname
|
pkname = self.relation._owner.Meta.pkname
|
||||||
@ -45,10 +53,15 @@ class RelationProxy(list):
|
|||||||
)
|
)
|
||||||
return queryset
|
return queryset
|
||||||
|
|
||||||
async def remove(self, item: "Model") -> None:
|
async def remove(self, item: "Model") -> None: # type: ignore
|
||||||
super().remove(item)
|
super().remove(item)
|
||||||
rel_name = item.resolve_relation_name(item, self._owner)
|
rel_name = item.resolve_relation_name(item, self._owner)
|
||||||
item._orm._get(rel_name).remove(self._owner)
|
relation = item._orm._get(rel_name)
|
||||||
|
if relation is None: # pragma nocover
|
||||||
|
raise ValueError(
|
||||||
|
f"{self._owner.get_name()} does not have relation {rel_name}"
|
||||||
|
)
|
||||||
|
relation.remove(self._owner)
|
||||||
if self.relation._type == ormar.RelationType.MULTIPLE:
|
if self.relation._type == ormar.RelationType.MULTIPLE:
|
||||||
await self.queryset_proxy.delete_through_instance(item)
|
await self.queryset_proxy.delete_through_instance(item)
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
from typing import TYPE_CHECKING, Tuple, Type
|
from typing import TYPE_CHECKING, Tuple, Type, Optional
|
||||||
from weakref import proxy
|
from weakref import proxy
|
||||||
|
|
||||||
import ormar
|
import ormar
|
||||||
from ormar.fields.foreign_key import ForeignKeyField
|
from ormar.fields import BaseField
|
||||||
from ormar.fields.many_to_many import ManyToManyField
|
from ormar.fields.many_to_many import ManyToManyField
|
||||||
from ormar.relations import Relation
|
from ormar.relations import Relation
|
||||||
|
|
||||||
@ -12,7 +12,7 @@ if TYPE_CHECKING: # pragma no cover
|
|||||||
|
|
||||||
def register_missing_relation(
|
def register_missing_relation(
|
||||||
parent: "Model", child: "Model", child_name: str
|
parent: "Model", child: "Model", child_name: str
|
||||||
) -> Relation:
|
) -> Optional[Relation]:
|
||||||
ormar.models.expand_reverse_relationships(child.__class__)
|
ormar.models.expand_reverse_relationships(child.__class__)
|
||||||
name = parent.resolve_relation_name(parent, child)
|
name = parent.resolve_relation_name(parent, child)
|
||||||
field = parent.Meta.model_fields[name]
|
field = parent.Meta.model_fields[name]
|
||||||
@ -22,7 +22,7 @@ def register_missing_relation(
|
|||||||
|
|
||||||
|
|
||||||
def get_relations_sides_and_names(
|
def get_relations_sides_and_names(
|
||||||
to_field: Type[ForeignKeyField],
|
to_field: Type[BaseField],
|
||||||
parent: "Model",
|
parent: "Model",
|
||||||
child: "Model",
|
child: "Model",
|
||||||
child_name: str,
|
child_name: str,
|
||||||
|
|||||||
Reference in New Issue
Block a user