allow change to build in type hints
This commit is contained in:
@ -5,7 +5,8 @@ import databases
|
||||
import pydantic
|
||||
import sqlalchemy
|
||||
from pydantic import BaseConfig
|
||||
from pydantic.fields import FieldInfo, ModelField
|
||||
from pydantic.fields import ModelField
|
||||
from pydantic.utils import lenient_issubclass
|
||||
from sqlalchemy.sql.schema import ColumnCollectionConstraint
|
||||
|
||||
import ormar # noqa I100
|
||||
@ -179,44 +180,58 @@ def register_relation_in_alias_manager(
|
||||
|
||||
|
||||
def populate_default_pydantic_field_value(
|
||||
type_: Type[BaseField], field: str, attrs: dict
|
||||
ormar_field: Type[BaseField], field_name: str, attrs: dict
|
||||
) -> dict:
|
||||
def_value = type_.default_value()
|
||||
curr_def_value = attrs.get(field, "NONE")
|
||||
if curr_def_value == "NONE" and isinstance(def_value, FieldInfo):
|
||||
attrs[field] = def_value
|
||||
elif curr_def_value == "NONE" and type_.nullable:
|
||||
attrs[field] = FieldInfo(default=None)
|
||||
curr_def_value = attrs.get(field_name, ormar.Undefined)
|
||||
if lenient_issubclass(curr_def_value, ormar.fields.BaseField):
|
||||
curr_def_value = ormar.Undefined
|
||||
if curr_def_value is None:
|
||||
attrs[field_name] = ormar_field.convert_to_pydantic_field_info(allow_null=True)
|
||||
else:
|
||||
attrs[field_name] = ormar_field.convert_to_pydantic_field_info()
|
||||
return attrs
|
||||
|
||||
|
||||
def populate_pydantic_default_values(attrs: Dict) -> Dict:
|
||||
for field, type_ in attrs["__annotations__"].items():
|
||||
if issubclass(type_, BaseField):
|
||||
if type_.name is None:
|
||||
type_.name = field
|
||||
attrs = populate_default_pydantic_field_value(type_, field, attrs)
|
||||
return attrs
|
||||
def check_if_field_annotation_or_value_is_ormar(
|
||||
field: Any, field_name: str, attrs: Dict
|
||||
) -> bool:
|
||||
return lenient_issubclass(field, BaseField) or issubclass(
|
||||
attrs.get(field_name, type), BaseField
|
||||
)
|
||||
|
||||
|
||||
def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict:
|
||||
def extract_field_from_annotation_or_value(
|
||||
field: Any, field_name: str, attrs: Dict
|
||||
) -> Type[ormar.fields.BaseField]:
|
||||
return field if lenient_issubclass(field, BaseField) else attrs.get(field_name)
|
||||
|
||||
|
||||
def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]:
|
||||
model_fields = {}
|
||||
for field_name, field in attrs["__annotations__"].items():
|
||||
# ormar fields can be used as annotation or as default value
|
||||
if check_if_field_annotation_or_value_is_ormar(field, field_name, attrs):
|
||||
ormar_field = extract_field_from_annotation_or_value(
|
||||
field, field_name, attrs
|
||||
)
|
||||
if ormar_field.name is None:
|
||||
ormar_field.name = field_name
|
||||
attrs = populate_default_pydantic_field_value(
|
||||
ormar_field, field_name, attrs
|
||||
)
|
||||
model_fields[field_name] = ormar_field
|
||||
attrs["__annotations__"][field_name] = ormar_field.__type__
|
||||
return attrs, model_fields
|
||||
|
||||
|
||||
def extract_annotations_and_default_vals(
|
||||
attrs: dict, bases: Tuple
|
||||
) -> Tuple[Dict, Dict]:
|
||||
attrs["__annotations__"] = attrs.get("__annotations__") or bases[0].__dict__.get(
|
||||
"__annotations__", {}
|
||||
)
|
||||
attrs = populate_pydantic_default_values(attrs)
|
||||
return attrs
|
||||
|
||||
|
||||
def populate_meta_orm_model_fields(
|
||||
attrs: dict, new_model: Type["Model"]
|
||||
) -> Type["Model"]:
|
||||
model_fields = {
|
||||
field_name: field
|
||||
for field_name, field in attrs["__annotations__"].items()
|
||||
if issubclass(field, BaseField)
|
||||
}
|
||||
new_model.Meta.model_fields = model_fields
|
||||
return new_model
|
||||
attrs, model_fields = populate_pydantic_default_values(attrs)
|
||||
return attrs, model_fields
|
||||
|
||||
|
||||
def populate_meta_tablename_columns_and_pk(
|
||||
@ -305,7 +320,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
) -> "ModelMetaclass":
|
||||
attrs["Config"] = get_pydantic_base_orm_config()
|
||||
attrs["__name__"] = name
|
||||
attrs = extract_annotations_and_default_vals(attrs, bases)
|
||||
attrs, model_fields = extract_annotations_and_default_vals(attrs, bases)
|
||||
new_model = super().__new__( # type: ignore
|
||||
mcs, name, bases, attrs
|
||||
)
|
||||
@ -313,7 +328,8 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
if hasattr(new_model, "Meta"):
|
||||
if not hasattr(new_model.Meta, "constraints"):
|
||||
new_model.Meta.constraints = []
|
||||
new_model = populate_meta_orm_model_fields(attrs, new_model)
|
||||
if not hasattr(new_model.Meta, "model_fields"):
|
||||
new_model.Meta.model_fields = model_fields
|
||||
new_model = populate_meta_tablename_columns_and_pk(name, new_model)
|
||||
new_model = populate_meta_sqlalchemy_table_if_required(new_model)
|
||||
expand_reverse_relationships(new_model)
|
||||
@ -322,7 +338,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
if new_model.Meta.pkname not in attrs["__annotations__"]:
|
||||
field_name = new_model.Meta.pkname
|
||||
field = Integer(name=field_name, primary_key=True)
|
||||
attrs["__annotations__"][field_name] = field
|
||||
attrs["__annotations__"][field_name] = Optional[int] # type: ignore
|
||||
populate_default_pydantic_field_value(
|
||||
field, field_name, attrs # type: ignore
|
||||
)
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import itertools
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
import ormar.queryset # noqa I100
|
||||
from ormar.fields.many_to_many import ManyToManyField
|
||||
from ormar.models import NewBaseModel # noqa I100
|
||||
from ormar.models.metaclass import ModelMeta
|
||||
|
||||
|
||||
def group_related_list(list_: List) -> Dict:
|
||||
@ -23,18 +24,30 @@ def group_related_list(list_: List) -> Dict:
|
||||
return test_dict
|
||||
|
||||
|
||||
T = TypeVar("T", bound="Model")
|
||||
|
||||
|
||||
class Model(NewBaseModel):
|
||||
__abstract__ = False
|
||||
if TYPE_CHECKING: # pragma nocover
|
||||
Meta: ModelMeta
|
||||
|
||||
def __repr__(self) -> str: # pragma nocover
|
||||
attrs_to_include = ["tablename", "columns", "pkname"]
|
||||
_repr = {k: v for k, v in self.Meta.model_fields.items()}
|
||||
for atr in attrs_to_include:
|
||||
_repr[atr] = getattr(self.Meta, atr)
|
||||
return f"{self.__class__.__name__}({str(_repr)})"
|
||||
|
||||
@classmethod
|
||||
def from_row( # noqa CCR001
|
||||
cls,
|
||||
cls: Type[T],
|
||||
row: sqlalchemy.engine.ResultProxy,
|
||||
select_related: List = None,
|
||||
related_models: Any = None,
|
||||
previous_table: str = None,
|
||||
fields: List = None,
|
||||
) -> Optional["Model"]:
|
||||
) -> Optional[T]:
|
||||
|
||||
item: Dict[str, Any] = {}
|
||||
select_related = select_related or []
|
||||
@ -66,7 +79,9 @@ class Model(NewBaseModel):
|
||||
item, row, table_prefix, fields, nested=table_prefix != ""
|
||||
)
|
||||
|
||||
instance = cls(**item) if item.get(cls.Meta.pkname, None) is not None else None
|
||||
instance: Optional[T] = cls(**item) if item.get(
|
||||
cls.Meta.pkname, None
|
||||
) is not None else None
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
@ -124,7 +139,7 @@ class Model(NewBaseModel):
|
||||
|
||||
return item
|
||||
|
||||
async def save(self) -> "Model":
|
||||
async def save(self: T) -> T:
|
||||
self_fields = self._extract_model_db_fields()
|
||||
|
||||
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
|
||||
@ -137,7 +152,7 @@ class Model(NewBaseModel):
|
||||
setattr(self, self.Meta.pkname, item_id)
|
||||
return self
|
||||
|
||||
async def update(self, **kwargs: Any) -> "Model":
|
||||
async def update(self: T, **kwargs: Any) -> T:
|
||||
if kwargs:
|
||||
new_values = {**self.dict(), **kwargs}
|
||||
self.from_dict(new_values)
|
||||
@ -151,13 +166,13 @@ class Model(NewBaseModel):
|
||||
await self.Meta.database.execute(expr)
|
||||
return self
|
||||
|
||||
async def delete(self) -> int:
|
||||
async def delete(self: T) -> int:
|
||||
expr = self.Meta.table.delete()
|
||||
expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname)))
|
||||
result = await self.Meta.database.execute(expr)
|
||||
return result
|
||||
|
||||
async def load(self) -> "Model":
|
||||
async def load(self: T) -> T:
|
||||
expr = self.Meta.table.select().where(self.pk_column == self.pk)
|
||||
row = await self.Meta.database.fetch_one(expr)
|
||||
if not row: # pragma nocover
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import inspect
|
||||
from typing import Dict, List, Set, TYPE_CHECKING, Type, TypeVar, Union
|
||||
from typing import Dict, List, Sequence, Set, TYPE_CHECKING, Type, TypeVar, Union
|
||||
|
||||
import ormar
|
||||
from ormar.exceptions import RelationshipInstanceError
|
||||
@ -11,6 +11,8 @@ if TYPE_CHECKING: # pragma no cover
|
||||
from ormar import Model
|
||||
from ormar.models import NewBaseModel
|
||||
|
||||
T = TypeVar("T", bound=Model)
|
||||
|
||||
Field = TypeVar("Field", bound=BaseField)
|
||||
|
||||
|
||||
@ -135,7 +137,7 @@ class ModelTableProxy:
|
||||
if field.to == related.__class__ or field.to.Meta == related.Meta:
|
||||
return name
|
||||
# fallback for not registered relation
|
||||
if register_missing:
|
||||
if register_missing: # pragma nocover
|
||||
expand_reverse_relationships(related.__class__) # type: ignore
|
||||
return ModelTableProxy.resolve_relation_name(
|
||||
item, related, register_missing=False
|
||||
@ -177,7 +179,7 @@ class ModelTableProxy:
|
||||
return new_kwargs
|
||||
|
||||
@classmethod
|
||||
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
|
||||
def merge_instances_list(cls, result_rows: Sequence["Model"]) -> Sequence["Model"]:
|
||||
merged_rows: List["Model"] = []
|
||||
for index, model in enumerate(result_rows):
|
||||
if index > 0 and model is not None and model.pk == merged_rows[-1].pk:
|
||||
|
||||
@ -5,11 +5,12 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
TYPE_CHECKING,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -27,7 +28,9 @@ from ormar.relations.alias_manager import AliasManager
|
||||
from ormar.relations.relation_manager import RelationsManager
|
||||
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from ormar.models.model import Model
|
||||
from ormar import Model
|
||||
|
||||
T = TypeVar("T", bound=Model)
|
||||
|
||||
IntStr = Union[int, str]
|
||||
DictStrAny = Dict[str, Any]
|
||||
@ -52,7 +55,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
Meta: ModelMeta
|
||||
|
||||
# noinspection PyMissingConstructor
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore
|
||||
|
||||
object.__setattr__(self, "_orm_id", uuid.uuid4().hex)
|
||||
object.__setattr__(self, "_orm_saved", False)
|
||||
@ -73,7 +76,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
if "pk" in kwargs:
|
||||
kwargs[self.Meta.pkname] = kwargs.pop("pk")
|
||||
# build the models to set them and validate but don't register
|
||||
kwargs = {
|
||||
new_kwargs = {
|
||||
k: self._convert_json(
|
||||
k,
|
||||
self.Meta.model_fields[k].expand_relationship(
|
||||
@ -85,7 +88,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
}
|
||||
|
||||
values, fields_set, validation_error = pydantic.validate_model(
|
||||
self, kwargs # type: ignore
|
||||
self, new_kwargs # type: ignore
|
||||
)
|
||||
if validation_error and not pk_only:
|
||||
raise validation_error
|
||||
@ -96,7 +99,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
# register the columns models after initialization
|
||||
for related in self.extract_related_names():
|
||||
self.Meta.model_fields[related].expand_relationship(
|
||||
kwargs.get(related), self, to_register=True
|
||||
new_kwargs.get(related), self, to_register=True
|
||||
)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
|
||||
@ -133,7 +136,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
|
||||
def _extract_related_model_instead_of_field(
|
||||
self, item: str
|
||||
) -> Optional[Union["Model", List["Model"]]]:
|
||||
) -> Optional[Union[T, Sequence[T]]]:
|
||||
alias = self.get_column_alias(item)
|
||||
if alias in self._orm:
|
||||
return self._orm.get(alias)
|
||||
@ -170,7 +173,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
def db_backend_name(cls) -> str:
|
||||
return cls.Meta.database._backend._dialect.name
|
||||
|
||||
def remove(self, name: "Model") -> None:
|
||||
def remove(self, name: T) -> None:
|
||||
self._orm.remove_parent(self, name)
|
||||
|
||||
def dict( # noqa A003
|
||||
|
||||
Reference in New Issue
Block a user