allow change to build in type hints
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
|
||||
|
||||
import pydantic
|
||||
import sqlalchemy
|
||||
from pydantic import Field, typing
|
||||
from pydantic.fields import FieldInfo
|
||||
@ -11,8 +12,9 @@ if TYPE_CHECKING: # pragma no cover
|
||||
from ormar.models import NewBaseModel
|
||||
|
||||
|
||||
class BaseField:
|
||||
class BaseField(FieldInfo):
|
||||
__type__ = None
|
||||
__pydantic_type__ = None
|
||||
|
||||
column_type: sqlalchemy.Column
|
||||
constraints: List = []
|
||||
@ -32,6 +34,28 @@ class BaseField:
|
||||
default: Any
|
||||
server_default: Any
|
||||
|
||||
@classmethod
|
||||
def is_valid_field_info_field(cls, field_name: str) -> bool:
|
||||
return (
|
||||
field_name not in ["default", "default_factory"]
|
||||
and not field_name.startswith("__")
|
||||
and hasattr(cls, field_name)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def convert_to_pydantic_field_info(cls, allow_null: bool = False) -> FieldInfo:
|
||||
base = cls.default_value()
|
||||
if base is None:
|
||||
base = (
|
||||
FieldInfo(default=None)
|
||||
if (cls.nullable or allow_null)
|
||||
else FieldInfo(default=pydantic.fields.Undefined)
|
||||
)
|
||||
for attr_name in FieldInfo.__dict__.keys():
|
||||
if cls.is_valid_field_info_field(attr_name):
|
||||
setattr(base, attr_name, cls.__dict__.get(attr_name))
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def default_value(cls, use_server: bool = False) -> Optional[FieldInfo]:
|
||||
if cls.is_auto_primary_key():
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, Generator, List, Optional, TYPE_CHECKING, Type, Union
|
||||
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import UniqueConstraint
|
||||
@ -39,8 +39,15 @@ def ForeignKey( # noqa CFQ002
|
||||
ondelete: str = None,
|
||||
) -> Type["ForeignKeyField"]:
|
||||
fk_string = to.Meta.tablename + "." + to.get_column_alias(to.Meta.pkname)
|
||||
to_field = to.__fields__[to.Meta.pkname]
|
||||
to_field = to.Meta.model_fields[to.Meta.pkname]
|
||||
__type__ = (
|
||||
Union[to_field.__type__, to]
|
||||
if not nullable
|
||||
else Optional[Union[to_field.__type__, to]]
|
||||
)
|
||||
namespace = dict(
|
||||
__type__=__type__,
|
||||
__pydantic_type__=__type__,
|
||||
to=to,
|
||||
name=name,
|
||||
nullable=nullable,
|
||||
@ -50,7 +57,7 @@ def ForeignKey( # noqa CFQ002
|
||||
)
|
||||
],
|
||||
unique=unique,
|
||||
column_type=to_field.type_.column_type,
|
||||
column_type=to_field.column_type,
|
||||
related_name=related_name,
|
||||
virtual=virtual,
|
||||
primary_key=False,
|
||||
@ -58,7 +65,6 @@ def ForeignKey( # noqa CFQ002
|
||||
pydantic_only=False,
|
||||
default=None,
|
||||
server_default=None,
|
||||
__pydantic_model__=to,
|
||||
)
|
||||
|
||||
return type("ForeignKey", (ForeignKeyField, BaseField), namespace)
|
||||
@ -70,14 +76,6 @@ class ForeignKeyField(BaseField):
|
||||
related_name: str
|
||||
virtual: bool
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Generator:
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: Any) -> Any:
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _extract_model_from_sequence(
|
||||
cls, value: List, child: "Model", to_register: bool
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Dict, TYPE_CHECKING, Type
|
||||
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
|
||||
|
||||
from ormar.fields import BaseField
|
||||
from ormar.fields.foreign_key import ForeignKeyField
|
||||
@ -15,17 +15,26 @@ def ManyToMany(
|
||||
*,
|
||||
name: str = None,
|
||||
unique: bool = False,
|
||||
related_name: str = None,
|
||||
virtual: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Type["ManyToManyField"]:
|
||||
to_field = to.__fields__[to.Meta.pkname]
|
||||
to_field = to.Meta.model_fields[to.Meta.pkname]
|
||||
related_name = kwargs.pop("related_name", None)
|
||||
nullable = kwargs.pop("nullable", True)
|
||||
__type__ = (
|
||||
Union[to_field.__type__, to, List[to]] # type: ignore
|
||||
if not nullable
|
||||
else Optional[Union[to_field.__type__, to, List[to]]] # type: ignore
|
||||
)
|
||||
namespace = dict(
|
||||
__type__=__type__,
|
||||
__pydantic_type__=__type__,
|
||||
to=to,
|
||||
through=through,
|
||||
name=name,
|
||||
nullable=True,
|
||||
unique=unique,
|
||||
column_type=to_field.type_.column_type,
|
||||
column_type=to_field.column_type,
|
||||
related_name=related_name,
|
||||
virtual=virtual,
|
||||
primary_key=False,
|
||||
@ -33,9 +42,6 @@ def ManyToMany(
|
||||
pydantic_only=False,
|
||||
default=None,
|
||||
server_default=None,
|
||||
__pydantic_model__=to,
|
||||
# __origin__=List,
|
||||
# __args__=[Optional[to]]
|
||||
)
|
||||
|
||||
return type("ManyToMany", (ManyToManyField, BaseField), namespace)
|
||||
@ -43,10 +49,3 @@ def ManyToMany(
|
||||
|
||||
class ManyToManyField(ForeignKeyField):
|
||||
through: Type["Model"]
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict) -> None:
|
||||
field_schema["type"] = "array"
|
||||
field_schema["title"] = cls.name.title()
|
||||
field_schema["definitions"] = {f"{cls.to.__name__}": cls.to.schema()}
|
||||
field_schema["items"] = {"$ref": f"{REF_PREFIX}{cls.to.__name__}"}
|
||||
|
||||
@ -20,8 +20,9 @@ def is_field_nullable(
|
||||
|
||||
|
||||
class ModelFieldFactory:
|
||||
_bases: Any = BaseField
|
||||
_bases: Any = (BaseField,)
|
||||
_type: Any = None
|
||||
_pydantic_type: Any = None
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: # type: ignore
|
||||
cls.validate(**kwargs)
|
||||
@ -32,6 +33,7 @@ class ModelFieldFactory:
|
||||
|
||||
namespace = dict(
|
||||
__type__=cls._type,
|
||||
__pydantic_type__=cls._pydantic_type,
|
||||
name=kwargs.pop("name", None),
|
||||
primary_key=kwargs.pop("primary_key", False),
|
||||
default=default,
|
||||
@ -57,8 +59,8 @@ class ModelFieldFactory:
|
||||
|
||||
|
||||
class String(ModelFieldFactory):
|
||||
_bases = (pydantic.ConstrainedStr, BaseField)
|
||||
_type = str
|
||||
_pydantic_type = pydantic.ConstrainedStr
|
||||
|
||||
def __new__( # type: ignore # noqa CFQ002
|
||||
cls,
|
||||
@ -96,8 +98,8 @@ class String(ModelFieldFactory):
|
||||
|
||||
|
||||
class Integer(ModelFieldFactory):
|
||||
_bases = (pydantic.ConstrainedInt, BaseField)
|
||||
_type = int
|
||||
_pydantic_type = pydantic.ConstrainedInt
|
||||
|
||||
def __new__( # type: ignore
|
||||
cls,
|
||||
@ -131,8 +133,8 @@ class Integer(ModelFieldFactory):
|
||||
|
||||
|
||||
class Text(ModelFieldFactory):
|
||||
_bases = (pydantic.ConstrainedStr, BaseField)
|
||||
_type = str
|
||||
_pydantic_type = pydantic.ConstrainedStr
|
||||
|
||||
def __new__( # type: ignore
|
||||
cls, *, allow_blank: bool = True, strip_whitespace: bool = False, **kwargs: Any
|
||||
@ -154,8 +156,8 @@ class Text(ModelFieldFactory):
|
||||
|
||||
|
||||
class Float(ModelFieldFactory):
|
||||
_bases = (pydantic.ConstrainedFloat, BaseField)
|
||||
_type = float
|
||||
_pydantic_type = pydantic.ConstrainedFloat
|
||||
|
||||
def __new__( # type: ignore
|
||||
cls,
|
||||
@ -183,8 +185,8 @@ class Float(ModelFieldFactory):
|
||||
|
||||
|
||||
class Boolean(ModelFieldFactory):
|
||||
_bases = (int, BaseField)
|
||||
_type = bool
|
||||
_pydantic_type = bool
|
||||
|
||||
@classmethod
|
||||
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||
@ -192,8 +194,8 @@ class Boolean(ModelFieldFactory):
|
||||
|
||||
|
||||
class DateTime(ModelFieldFactory):
|
||||
_bases = (datetime.datetime, BaseField)
|
||||
_type = datetime.datetime
|
||||
_pydantic_type = datetime.datetime
|
||||
|
||||
@classmethod
|
||||
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||
@ -201,8 +203,8 @@ class DateTime(ModelFieldFactory):
|
||||
|
||||
|
||||
class Date(ModelFieldFactory):
|
||||
_bases = (datetime.date, BaseField)
|
||||
_type = datetime.date
|
||||
_pydantic_type = datetime.date
|
||||
|
||||
@classmethod
|
||||
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||
@ -210,8 +212,8 @@ class Date(ModelFieldFactory):
|
||||
|
||||
|
||||
class Time(ModelFieldFactory):
|
||||
_bases = (datetime.time, BaseField)
|
||||
_type = datetime.time
|
||||
_pydantic_type = datetime.time
|
||||
|
||||
@classmethod
|
||||
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||
@ -219,8 +221,8 @@ class Time(ModelFieldFactory):
|
||||
|
||||
|
||||
class JSON(ModelFieldFactory):
|
||||
_bases = (pydantic.Json, BaseField)
|
||||
_type = pydantic.Json
|
||||
_pydantic_type = pydantic.Json
|
||||
|
||||
@classmethod
|
||||
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||
@ -228,8 +230,8 @@ class JSON(ModelFieldFactory):
|
||||
|
||||
|
||||
class BigInteger(Integer):
|
||||
_bases = (pydantic.ConstrainedInt, BaseField)
|
||||
_type = int
|
||||
_pydantic_type = pydantic.ConstrainedInt
|
||||
|
||||
def __new__( # type: ignore
|
||||
cls,
|
||||
@ -263,8 +265,8 @@ class BigInteger(Integer):
|
||||
|
||||
|
||||
class Decimal(ModelFieldFactory):
|
||||
_bases = (pydantic.ConstrainedDecimal, BaseField)
|
||||
_type = decimal.Decimal
|
||||
_pydantic_type = pydantic.ConstrainedDecimal
|
||||
|
||||
def __new__( # type: ignore # noqa CFQ002
|
||||
cls,
|
||||
@ -318,8 +320,8 @@ class Decimal(ModelFieldFactory):
|
||||
|
||||
|
||||
class UUID(ModelFieldFactory):
|
||||
_bases = (uuid.UUID, BaseField)
|
||||
_type = uuid.UUID
|
||||
_pydantic_type = uuid.UUID
|
||||
|
||||
@classmethod
|
||||
def get_column_type(cls, **kwargs: Any) -> Any:
|
||||
|
||||
@ -5,7 +5,8 @@ import databases
|
||||
import pydantic
|
||||
import sqlalchemy
|
||||
from pydantic import BaseConfig
|
||||
from pydantic.fields import FieldInfo, ModelField
|
||||
from pydantic.fields import ModelField
|
||||
from pydantic.utils import lenient_issubclass
|
||||
from sqlalchemy.sql.schema import ColumnCollectionConstraint
|
||||
|
||||
import ormar # noqa I100
|
||||
@ -179,44 +180,58 @@ def register_relation_in_alias_manager(
|
||||
|
||||
|
||||
def populate_default_pydantic_field_value(
|
||||
type_: Type[BaseField], field: str, attrs: dict
|
||||
ormar_field: Type[BaseField], field_name: str, attrs: dict
|
||||
) -> dict:
|
||||
def_value = type_.default_value()
|
||||
curr_def_value = attrs.get(field, "NONE")
|
||||
if curr_def_value == "NONE" and isinstance(def_value, FieldInfo):
|
||||
attrs[field] = def_value
|
||||
elif curr_def_value == "NONE" and type_.nullable:
|
||||
attrs[field] = FieldInfo(default=None)
|
||||
curr_def_value = attrs.get(field_name, ormar.Undefined)
|
||||
if lenient_issubclass(curr_def_value, ormar.fields.BaseField):
|
||||
curr_def_value = ormar.Undefined
|
||||
if curr_def_value is None:
|
||||
attrs[field_name] = ormar_field.convert_to_pydantic_field_info(allow_null=True)
|
||||
else:
|
||||
attrs[field_name] = ormar_field.convert_to_pydantic_field_info()
|
||||
return attrs
|
||||
|
||||
|
||||
def populate_pydantic_default_values(attrs: Dict) -> Dict:
|
||||
for field, type_ in attrs["__annotations__"].items():
|
||||
if issubclass(type_, BaseField):
|
||||
if type_.name is None:
|
||||
type_.name = field
|
||||
attrs = populate_default_pydantic_field_value(type_, field, attrs)
|
||||
return attrs
|
||||
def check_if_field_annotation_or_value_is_ormar(
|
||||
field: Any, field_name: str, attrs: Dict
|
||||
) -> bool:
|
||||
return lenient_issubclass(field, BaseField) or issubclass(
|
||||
attrs.get(field_name, type), BaseField
|
||||
)
|
||||
|
||||
|
||||
def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict:
|
||||
def extract_field_from_annotation_or_value(
|
||||
field: Any, field_name: str, attrs: Dict
|
||||
) -> Type[ormar.fields.BaseField]:
|
||||
return field if lenient_issubclass(field, BaseField) else attrs.get(field_name)
|
||||
|
||||
|
||||
def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]:
|
||||
model_fields = {}
|
||||
for field_name, field in attrs["__annotations__"].items():
|
||||
# ormar fields can be used as annotation or as default value
|
||||
if check_if_field_annotation_or_value_is_ormar(field, field_name, attrs):
|
||||
ormar_field = extract_field_from_annotation_or_value(
|
||||
field, field_name, attrs
|
||||
)
|
||||
if ormar_field.name is None:
|
||||
ormar_field.name = field_name
|
||||
attrs = populate_default_pydantic_field_value(
|
||||
ormar_field, field_name, attrs
|
||||
)
|
||||
model_fields[field_name] = ormar_field
|
||||
attrs["__annotations__"][field_name] = ormar_field.__type__
|
||||
return attrs, model_fields
|
||||
|
||||
|
||||
def extract_annotations_and_default_vals(
|
||||
attrs: dict, bases: Tuple
|
||||
) -> Tuple[Dict, Dict]:
|
||||
attrs["__annotations__"] = attrs.get("__annotations__") or bases[0].__dict__.get(
|
||||
"__annotations__", {}
|
||||
)
|
||||
attrs = populate_pydantic_default_values(attrs)
|
||||
return attrs
|
||||
|
||||
|
||||
def populate_meta_orm_model_fields(
|
||||
attrs: dict, new_model: Type["Model"]
|
||||
) -> Type["Model"]:
|
||||
model_fields = {
|
||||
field_name: field
|
||||
for field_name, field in attrs["__annotations__"].items()
|
||||
if issubclass(field, BaseField)
|
||||
}
|
||||
new_model.Meta.model_fields = model_fields
|
||||
return new_model
|
||||
attrs, model_fields = populate_pydantic_default_values(attrs)
|
||||
return attrs, model_fields
|
||||
|
||||
|
||||
def populate_meta_tablename_columns_and_pk(
|
||||
@ -305,7 +320,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
) -> "ModelMetaclass":
|
||||
attrs["Config"] = get_pydantic_base_orm_config()
|
||||
attrs["__name__"] = name
|
||||
attrs = extract_annotations_and_default_vals(attrs, bases)
|
||||
attrs, model_fields = extract_annotations_and_default_vals(attrs, bases)
|
||||
new_model = super().__new__( # type: ignore
|
||||
mcs, name, bases, attrs
|
||||
)
|
||||
@ -313,7 +328,8 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
if hasattr(new_model, "Meta"):
|
||||
if not hasattr(new_model.Meta, "constraints"):
|
||||
new_model.Meta.constraints = []
|
||||
new_model = populate_meta_orm_model_fields(attrs, new_model)
|
||||
if not hasattr(new_model.Meta, "model_fields"):
|
||||
new_model.Meta.model_fields = model_fields
|
||||
new_model = populate_meta_tablename_columns_and_pk(name, new_model)
|
||||
new_model = populate_meta_sqlalchemy_table_if_required(new_model)
|
||||
expand_reverse_relationships(new_model)
|
||||
@ -322,7 +338,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
if new_model.Meta.pkname not in attrs["__annotations__"]:
|
||||
field_name = new_model.Meta.pkname
|
||||
field = Integer(name=field_name, primary_key=True)
|
||||
attrs["__annotations__"][field_name] = field
|
||||
attrs["__annotations__"][field_name] = Optional[int] # type: ignore
|
||||
populate_default_pydantic_field_value(
|
||||
field, field_name, attrs # type: ignore
|
||||
)
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import itertools
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
import ormar.queryset # noqa I100
|
||||
from ormar.fields.many_to_many import ManyToManyField
|
||||
from ormar.models import NewBaseModel # noqa I100
|
||||
from ormar.models.metaclass import ModelMeta
|
||||
|
||||
|
||||
def group_related_list(list_: List) -> Dict:
|
||||
@ -23,18 +24,30 @@ def group_related_list(list_: List) -> Dict:
|
||||
return test_dict
|
||||
|
||||
|
||||
T = TypeVar("T", bound="Model")
|
||||
|
||||
|
||||
class Model(NewBaseModel):
|
||||
__abstract__ = False
|
||||
if TYPE_CHECKING: # pragma nocover
|
||||
Meta: ModelMeta
|
||||
|
||||
def __repr__(self) -> str: # pragma nocover
|
||||
attrs_to_include = ["tablename", "columns", "pkname"]
|
||||
_repr = {k: v for k, v in self.Meta.model_fields.items()}
|
||||
for atr in attrs_to_include:
|
||||
_repr[atr] = getattr(self.Meta, atr)
|
||||
return f"{self.__class__.__name__}({str(_repr)})"
|
||||
|
||||
@classmethod
|
||||
def from_row( # noqa CCR001
|
||||
cls,
|
||||
cls: Type[T],
|
||||
row: sqlalchemy.engine.ResultProxy,
|
||||
select_related: List = None,
|
||||
related_models: Any = None,
|
||||
previous_table: str = None,
|
||||
fields: List = None,
|
||||
) -> Optional["Model"]:
|
||||
) -> Optional[T]:
|
||||
|
||||
item: Dict[str, Any] = {}
|
||||
select_related = select_related or []
|
||||
@ -66,7 +79,9 @@ class Model(NewBaseModel):
|
||||
item, row, table_prefix, fields, nested=table_prefix != ""
|
||||
)
|
||||
|
||||
instance = cls(**item) if item.get(cls.Meta.pkname, None) is not None else None
|
||||
instance: Optional[T] = cls(**item) if item.get(
|
||||
cls.Meta.pkname, None
|
||||
) is not None else None
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
@ -124,7 +139,7 @@ class Model(NewBaseModel):
|
||||
|
||||
return item
|
||||
|
||||
async def save(self) -> "Model":
|
||||
async def save(self: T) -> T:
|
||||
self_fields = self._extract_model_db_fields()
|
||||
|
||||
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
|
||||
@ -137,7 +152,7 @@ class Model(NewBaseModel):
|
||||
setattr(self, self.Meta.pkname, item_id)
|
||||
return self
|
||||
|
||||
async def update(self, **kwargs: Any) -> "Model":
|
||||
async def update(self: T, **kwargs: Any) -> T:
|
||||
if kwargs:
|
||||
new_values = {**self.dict(), **kwargs}
|
||||
self.from_dict(new_values)
|
||||
@ -151,13 +166,13 @@ class Model(NewBaseModel):
|
||||
await self.Meta.database.execute(expr)
|
||||
return self
|
||||
|
||||
async def delete(self) -> int:
|
||||
async def delete(self: T) -> int:
|
||||
expr = self.Meta.table.delete()
|
||||
expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname)))
|
||||
result = await self.Meta.database.execute(expr)
|
||||
return result
|
||||
|
||||
async def load(self) -> "Model":
|
||||
async def load(self: T) -> T:
|
||||
expr = self.Meta.table.select().where(self.pk_column == self.pk)
|
||||
row = await self.Meta.database.fetch_one(expr)
|
||||
if not row: # pragma nocover
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import inspect
|
||||
from typing import Dict, List, Set, TYPE_CHECKING, Type, TypeVar, Union
|
||||
from typing import Dict, List, Sequence, Set, TYPE_CHECKING, Type, TypeVar, Union
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import RelationshipInstanceError
|
||||
@ -11,6 +11,8 @@ if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
from ormar.models import NewBaseModel
|
||||
|
||||
T = TypeVar("T", bound=Model)
|
||||
|
||||
Field = TypeVar("Field", bound=BaseField)
|
||||
|
||||
|
||||
@ -135,7 +137,7 @@ class ModelTableProxy:
|
||||
if field.to == related.__class__ or field.to.Meta == related.Meta:
|
||||
return name
|
||||
# fallback for not registered relation
|
||||
if register_missing:
|
||||
if register_missing: # pragma nocover
|
||||
expand_reverse_relationships(related.__class__) # type: ignore
|
||||
return ModelTableProxy.resolve_relation_name(
|
||||
item, related, register_missing=False
|
||||
@ -177,7 +179,7 @@ class ModelTableProxy:
|
||||
return new_kwargs
|
||||
|
||||
@classmethod
|
||||
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
|
||||
def merge_instances_list(cls, result_rows: Sequence["Model"]) -> Sequence["Model"]:
|
||||
merged_rows: List["Model"] = []
|
||||
for index, model in enumerate(result_rows):
|
||||
if index > 0 and model is not None and model.pk == merged_rows[-1].pk:
|
||||
|
||||
@ -5,11 +5,12 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
TYPE_CHECKING,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -27,7 +28,9 @@ from ormar.relations.alias_manager import AliasManager
|
||||
from ormar.relations.relation_manager import RelationsManager
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar.models.model import Model
|
||||
from ormar import Model
|
||||
|
||||
T = TypeVar("T", bound=Model)
|
||||
|
||||
IntStr = Union[int, str]
|
||||
DictStrAny = Dict[str, Any]
|
||||
@ -52,7 +55,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
Meta: ModelMeta
|
||||
|
||||
# noinspection PyMissingConstructor
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore
|
||||
|
||||
object.__setattr__(self, "_orm_id", uuid.uuid4().hex)
|
||||
object.__setattr__(self, "_orm_saved", False)
|
||||
@ -73,7 +76,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
if "pk" in kwargs:
|
||||
kwargs[self.Meta.pkname] = kwargs.pop("pk")
|
||||
# build the models to set them and validate but don't register
|
||||
kwargs = {
|
||||
new_kwargs = {
|
||||
k: self._convert_json(
|
||||
k,
|
||||
self.Meta.model_fields[k].expand_relationship(
|
||||
@ -85,7 +88,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
}
|
||||
|
||||
values, fields_set, validation_error = pydantic.validate_model(
|
||||
self, kwargs # type: ignore
|
||||
self, new_kwargs # type: ignore
|
||||
)
|
||||
if validation_error and not pk_only:
|
||||
raise validation_error
|
||||
@ -96,7 +99,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
# register the columns models after initialization
|
||||
for related in self.extract_related_names():
|
||||
self.Meta.model_fields[related].expand_relationship(
|
||||
kwargs.get(related), self, to_register=True
|
||||
new_kwargs.get(related), self, to_register=True
|
||||
)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
|
||||
@ -133,7 +136,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
|
||||
def _extract_related_model_instead_of_field(
|
||||
self, item: str
|
||||
) -> Optional[Union["Model", List["Model"]]]:
|
||||
) -> Optional[Union[T, Sequence[T]]]:
|
||||
alias = self.get_column_alias(item)
|
||||
if alias in self._orm:
|
||||
return self._orm.get(alias)
|
||||
@ -170,7 +173,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
def db_backend_name(cls) -> str:
|
||||
return cls.Meta.database._backend._dialect.name
|
||||
|
||||
def remove(self, name: "Model") -> None:
|
||||
def remove(self, name: T) -> None:
|
||||
self._orm.remove_parent(self, name)
|
||||
|
||||
def dict( # noqa A003
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
|
||||
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Type, Union
|
||||
|
||||
import databases
|
||||
import sqlalchemy
|
||||
@ -59,7 +59,7 @@ class QuerySet:
|
||||
raise ValueError("Model class of QuerySet is not initialized")
|
||||
return self.model_cls
|
||||
|
||||
def _process_query_result_rows(self, rows: List) -> List[Optional["Model"]]:
|
||||
def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]:
|
||||
result_rows = [
|
||||
self.model.from_row(
|
||||
row, select_related=self._select_related, fields=self._columns
|
||||
@ -87,7 +87,7 @@ class QuerySet:
|
||||
return new_kwargs
|
||||
|
||||
@staticmethod
|
||||
def check_single_result_rows_count(rows: List[Optional["Model"]]) -> None:
|
||||
def check_single_result_rows_count(rows: Sequence[Optional["Model"]]) -> None:
|
||||
if not rows or rows[0] is None:
|
||||
raise NoMatch()
|
||||
if len(rows) > 1:
|
||||
@ -267,7 +267,7 @@ class QuerySet:
|
||||
model = await self.get(pk=kwargs[pk_name])
|
||||
return await model.update(**kwargs)
|
||||
|
||||
async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003
|
||||
async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003
|
||||
if kwargs:
|
||||
return await self.filter(**kwargs).all()
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
import ormar
|
||||
|
||||
@ -7,6 +7,8 @@ if TYPE_CHECKING: # pragma no cover
|
||||
from ormar.models import Model
|
||||
from ormar.queryset import QuerySet
|
||||
|
||||
T = TypeVar("T", bound=Model)
|
||||
|
||||
|
||||
class QuerysetProxy:
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
@ -26,27 +28,28 @@ class QuerysetProxy:
|
||||
def queryset(self, value: "QuerySet") -> None:
|
||||
self._queryset = value
|
||||
|
||||
def _assign_child_to_parent(self, child: Optional["Model"]) -> None:
|
||||
def _assign_child_to_parent(self, child: Optional[T]) -> None:
|
||||
if child:
|
||||
owner = self.relation._owner
|
||||
rel_name = owner.resolve_relation_name(owner, child)
|
||||
setattr(owner, rel_name, child)
|
||||
|
||||
def _register_related(self, child: Union["Model", List[Optional["Model"]]]) -> None:
|
||||
def _register_related(self, child: Union[T, Sequence[Optional[T]]]) -> None:
|
||||
if isinstance(child, list):
|
||||
for subchild in child:
|
||||
self._assign_child_to_parent(subchild)
|
||||
else:
|
||||
assert isinstance(child, Model)
|
||||
self._assign_child_to_parent(child)
|
||||
|
||||
async def create_through_instance(self, child: "Model") -> None:
|
||||
async def create_through_instance(self, child: T) -> None:
|
||||
queryset = ormar.QuerySet(model_cls=self.relation.through)
|
||||
owner_column = self.relation._owner.get_name()
|
||||
child_column = child.get_name()
|
||||
kwargs = {owner_column: self.relation._owner, child_column: child}
|
||||
await queryset.create(**kwargs)
|
||||
|
||||
async def delete_through_instance(self, child: "Model") -> None:
|
||||
async def delete_through_instance(self, child: T) -> None:
|
||||
queryset = ormar.QuerySet(model_cls=self.relation.through)
|
||||
owner_column = self.relation._owner.get_name()
|
||||
child_column = child.get_name()
|
||||
@ -88,7 +91,7 @@ class QuerysetProxy:
|
||||
self._register_related(get)
|
||||
return get
|
||||
|
||||
async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003
|
||||
async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003
|
||||
all_items = await self.queryset.all(**kwargs)
|
||||
self._register_related(all_items)
|
||||
return all_items
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional, TYPE_CHECKING, Type, Union
|
||||
from typing import List, Optional, TYPE_CHECKING, Type, TypeVar, Union
|
||||
|
||||
import ormar # noqa I100
|
||||
from ormar.exceptions import RelationshipInstanceError # noqa I100
|
||||
@ -11,6 +11,8 @@ if TYPE_CHECKING: # pragma no cover
|
||||
from ormar.relations import RelationsManager
|
||||
from ormar.models import NewBaseModel
|
||||
|
||||
T = TypeVar("T", bound=Model)
|
||||
|
||||
|
||||
class RelationType(Enum):
|
||||
PRIMARY = 1
|
||||
@ -23,15 +25,15 @@ class Relation:
|
||||
self,
|
||||
manager: "RelationsManager",
|
||||
type_: RelationType,
|
||||
to: Type["Model"],
|
||||
through: Type["Model"] = None,
|
||||
to: Type[T],
|
||||
through: Type[T] = None,
|
||||
) -> None:
|
||||
self.manager = manager
|
||||
self._owner: "Model" = manager.owner
|
||||
self._type: RelationType = type_
|
||||
self.to: Type["Model"] = to
|
||||
self.through: Optional[Type["Model"]] = through
|
||||
self.related_models: Optional[Union[RelationProxy, "Model"]] = (
|
||||
self.to: Type[T] = to
|
||||
self.through: Optional[Type[T]] = through
|
||||
self.related_models: Optional[Union[RelationProxy, T]] = (
|
||||
RelationProxy(relation=self)
|
||||
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
|
||||
else None
|
||||
@ -50,7 +52,7 @@ class Relation:
|
||||
self.related_models.pop(ind)
|
||||
return None
|
||||
|
||||
def add(self, child: "Model") -> None:
|
||||
def add(self, child: T) -> None:
|
||||
relation_name = self._owner.resolve_relation_name(self._owner, child)
|
||||
if self._type == RelationType.PRIMARY:
|
||||
self.related_models = child
|
||||
@ -77,7 +79,7 @@ class Relation:
|
||||
self.related_models.pop(position) # type: ignore
|
||||
del self._owner.__dict__[relation_name][position]
|
||||
|
||||
def get(self) -> Optional[Union[List["Model"], "Model"]]:
|
||||
def get(self) -> Optional[Union[List[T], T]]:
|
||||
return self.related_models
|
||||
|
||||
def __repr__(self) -> str: # pragma no cover
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING, Type, Union
|
||||
from typing import Dict, List, Optional, Sequence, TYPE_CHECKING, Type, TypeVar, Union
|
||||
from weakref import proxy
|
||||
|
||||
from ormar.fields import BaseField
|
||||
@ -14,6 +14,8 @@ if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
from ormar.models import NewBaseModel
|
||||
|
||||
T = TypeVar("T", bound=Model)
|
||||
|
||||
|
||||
class RelationsManager:
|
||||
def __init__(
|
||||
@ -46,7 +48,7 @@ class RelationsManager:
|
||||
def __contains__(self, item: str) -> bool:
|
||||
return item in self._related_names
|
||||
|
||||
def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]:
|
||||
def get(self, name: str) -> Optional[Union[T, Sequence[T]]]:
|
||||
relation = self._relations.get(name, None)
|
||||
if relation is not None:
|
||||
return relation.get()
|
||||
|
||||
@ -72,6 +72,6 @@ class RelationProxy(list):
|
||||
if self.relation._type == ormar.RelationType.MULTIPLE:
|
||||
await self.queryset_proxy.create_through_instance(item)
|
||||
rel_name = item.resolve_relation_name(item, self._owner)
|
||||
if rel_name not in item._orm:
|
||||
if rel_name not in item._orm: # pragma nocover
|
||||
item._orm._add_relation(item.Meta.model_fields[rel_name])
|
||||
setattr(item, rel_name, self._owner)
|
||||
|
||||
Reference in New Issue
Block a user