From 8fba94efa1d8524eef27a149e23e6b932ef78179 Mon Sep 17 00:00:00 2001 From: collerek Date: Sat, 31 Oct 2020 15:43:34 +0100 Subject: [PATCH] allow change to build in type hints --- ormar/fields/base.py | 26 +- ormar/fields/foreign_key.py | 22 +- ormar/fields/many_to_many.py | 27 +- ormar/fields/model_fields.py | 28 ++- ormar/models/metaclass.py | 82 +++--- ormar/models/model.py | 31 ++- ormar/models/modelproxy.py | 8 +- ormar/models/newbasemodel.py | 19 +- ormar/queryset/queryset.py | 8 +- ormar/relations/querysetproxy.py | 15 +- ormar/relations/relation.py | 18 +- ormar/relations/relation_manager.py | 6 +- ormar/relations/relation_proxy.py | 2 +- tests/test_fastapi_docs.py | 4 +- tests/test_models.py | 4 +- tests/test_new_annotation_style.py | 374 ++++++++++++++++++++++++++++ tests/test_server_default.py | 30 +-- tests/test_unique_constraints.py | 2 +- 18 files changed, 575 insertions(+), 131 deletions(-) create mode 100644 tests/test_new_annotation_style.py diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 0a47ae4..e57dc64 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -1,5 +1,6 @@ from typing import Any, List, Optional, TYPE_CHECKING, Type, Union +import pydantic import sqlalchemy from pydantic import Field, typing from pydantic.fields import FieldInfo @@ -11,8 +12,9 @@ if TYPE_CHECKING: # pragma no cover from ormar.models import NewBaseModel -class BaseField: +class BaseField(FieldInfo): __type__ = None + __pydantic_type__ = None column_type: sqlalchemy.Column constraints: List = [] @@ -32,6 +34,28 @@ class BaseField: default: Any server_default: Any + @classmethod + def is_valid_field_info_field(cls, field_name: str) -> bool: + return ( + field_name not in ["default", "default_factory"] + and not field_name.startswith("__") + and hasattr(cls, field_name) + ) + + @classmethod + def convert_to_pydantic_field_info(cls, allow_null: bool = False) -> FieldInfo: + base = cls.default_value() + if base is None: + base = ( + FieldInfo(default=None) + if (cls.nullable or allow_null) + else FieldInfo(default=pydantic.fields.Undefined) + ) + for attr_name in FieldInfo.__dict__.keys(): + if cls.is_valid_field_info_field(attr_name): + setattr(base, attr_name, cls.__dict__.get(attr_name)) + return base + @classmethod def default_value(cls, use_server: bool = False) -> Optional[FieldInfo]: if cls.is_auto_primary_key(): diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 854f83d..c2f9911 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -1,4 +1,4 @@ -from typing import Any, Generator, List, Optional, TYPE_CHECKING, Type, Union +from typing import Any, List, Optional, TYPE_CHECKING, Type, Union import sqlalchemy from sqlalchemy import UniqueConstraint @@ -39,8 +39,15 @@ def ForeignKey( # noqa CFQ002 ondelete: str = None, ) -> Type["ForeignKeyField"]: fk_string = to.Meta.tablename + "." + to.get_column_alias(to.Meta.pkname) - to_field = to.__fields__[to.Meta.pkname] + to_field = to.Meta.model_fields[to.Meta.pkname] + __type__ = ( + Union[to_field.__type__, to] + if not nullable + else Optional[Union[to_field.__type__, to]] + ) namespace = dict( + __type__=__type__, + __pydantic_type__=__type__, to=to, name=name, nullable=nullable, @@ -50,7 +57,7 @@ def ForeignKey( # noqa CFQ002 ) ], unique=unique, - column_type=to_field.type_.column_type, + column_type=to_field.column_type, related_name=related_name, virtual=virtual, primary_key=False, @@ -58,7 +65,6 @@ def ForeignKey( # noqa CFQ002 pydantic_only=False, default=None, server_default=None, - __pydantic_model__=to, ) return type("ForeignKey", (ForeignKeyField, BaseField), namespace) @@ -70,14 +76,6 @@ class ForeignKeyField(BaseField): related_name: str virtual: bool - @classmethod - def __get_validators__(cls) -> Generator: - yield cls.validate - - @classmethod - def validate(cls, value: Any) -> Any: - return value - @classmethod def _extract_model_from_sequence( cls, value: List, child: "Model", to_register: bool diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py index 1f73a0d..5c81182 100644 --- a/ormar/fields/many_to_many.py +++ b/ormar/fields/many_to_many.py @@ -1,4 +1,4 @@ -from typing import Dict, TYPE_CHECKING, Type +from typing import Any, List, Optional, TYPE_CHECKING, Type, Union from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField @@ -15,17 +15,26 @@ def ManyToMany( *, name: str = None, unique: bool = False, - related_name: str = None, virtual: bool = False, + **kwargs: Any ) -> Type["ManyToManyField"]: - to_field = to.__fields__[to.Meta.pkname] + to_field = to.Meta.model_fields[to.Meta.pkname] + related_name = kwargs.pop("related_name", None) + nullable = kwargs.pop("nullable", True) + __type__ = ( + Union[to_field.__type__, to, List[to]] # type: ignore + if not nullable + else Optional[Union[to_field.__type__, to, List[to]]] # type: ignore + ) namespace = dict( + __type__=__type__, + __pydantic_type__=__type__, to=to, through=through, name=name, nullable=True, unique=unique, - column_type=to_field.type_.column_type, + column_type=to_field.column_type, related_name=related_name, virtual=virtual, primary_key=False, @@ -33,9 +42,6 @@ def ManyToMany( pydantic_only=False, default=None, server_default=None, - __pydantic_model__=to, - # __origin__=List, - # __args__=[Optional[to]] ) return type("ManyToMany", (ManyToManyField, BaseField), namespace) @@ -43,10 +49,3 @@ def ManyToMany( class ManyToManyField(ForeignKeyField): through: Type["Model"] - - @classmethod - def __modify_schema__(cls, field_schema: Dict) -> None: - field_schema["type"] = "array" - field_schema["title"] = cls.name.title() - field_schema["definitions"] = {f"{cls.to.__name__}": cls.to.schema()} - field_schema["items"] = {"$ref": f"{REF_PREFIX}{cls.to.__name__}"} diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index 5462865..2e98018 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -20,8 +20,9 @@ def is_field_nullable( class ModelFieldFactory: - _bases: Any = BaseField + _bases: Any = (BaseField,) _type: Any = None + _pydantic_type: Any = None def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: # type: ignore cls.validate(**kwargs) @@ -32,6 +33,7 @@ class ModelFieldFactory: namespace = dict( __type__=cls._type, + __pydantic_type__=cls._pydantic_type, name=kwargs.pop("name", None), primary_key=kwargs.pop("primary_key", False), default=default, @@ -57,8 +59,8 @@ class ModelFieldFactory: class String(ModelFieldFactory): - _bases = (pydantic.ConstrainedStr, BaseField) _type = str + _pydantic_type = pydantic.ConstrainedStr def __new__( # type: ignore # noqa CFQ002 cls, @@ -96,8 +98,8 @@ class String(ModelFieldFactory): class Integer(ModelFieldFactory): - _bases = (pydantic.ConstrainedInt, BaseField) _type = int + _pydantic_type = pydantic.ConstrainedInt def __new__( # type: ignore cls, @@ -131,8 +133,8 @@ class Integer(ModelFieldFactory): class Text(ModelFieldFactory): - _bases = (pydantic.ConstrainedStr, BaseField) _type = str + _pydantic_type = pydantic.ConstrainedStr def __new__( # type: ignore cls, *, allow_blank: bool = True, strip_whitespace: bool = False, **kwargs: Any @@ -154,8 +156,8 @@ class Text(ModelFieldFactory): class Float(ModelFieldFactory): - _bases = (pydantic.ConstrainedFloat, BaseField) _type = float + _pydantic_type = pydantic.ConstrainedFloat def __new__( # type: ignore cls, @@ -183,8 +185,8 @@ class Float(ModelFieldFactory): class Boolean(ModelFieldFactory): - _bases = (int, BaseField) _type = bool + _pydantic_type = bool @classmethod def get_column_type(cls, **kwargs: Any) -> Any: @@ -192,8 +194,8 @@ class Boolean(ModelFieldFactory): class DateTime(ModelFieldFactory): - _bases = (datetime.datetime, BaseField) _type = datetime.datetime + _pydantic_type = datetime.datetime @classmethod def get_column_type(cls, **kwargs: Any) -> Any: @@ -201,8 +203,8 @@ class DateTime(ModelFieldFactory): class Date(ModelFieldFactory): - _bases = (datetime.date, BaseField) _type = datetime.date + _pydantic_type = datetime.date @classmethod def get_column_type(cls, **kwargs: Any) -> Any: @@ -210,8 +212,8 @@ class Date(ModelFieldFactory): class Time(ModelFieldFactory): - _bases = (datetime.time, BaseField) _type = datetime.time + _pydantic_type = datetime.time @classmethod def get_column_type(cls, **kwargs: Any) -> Any: @@ -219,8 +221,8 @@ class Time(ModelFieldFactory): class JSON(ModelFieldFactory): - _bases = (pydantic.Json, BaseField) _type = pydantic.Json + _pydantic_type = pydantic.Json @classmethod def get_column_type(cls, **kwargs: Any) -> Any: @@ -228,8 +230,8 @@ class JSON(ModelFieldFactory): class BigInteger(Integer): - _bases = (pydantic.ConstrainedInt, BaseField) _type = int + _pydantic_type = pydantic.ConstrainedInt def __new__( # type: ignore cls, @@ -263,8 +265,8 @@ class BigInteger(Integer): class Decimal(ModelFieldFactory): - _bases = (pydantic.ConstrainedDecimal, BaseField) _type = decimal.Decimal + _pydantic_type = pydantic.ConstrainedDecimal def __new__( # type: ignore # noqa CFQ002 cls, @@ -318,8 +320,8 @@ class Decimal(ModelFieldFactory): class UUID(ModelFieldFactory): - _bases = (uuid.UUID, BaseField) _type = uuid.UUID + _pydantic_type = uuid.UUID @classmethod def get_column_type(cls, **kwargs: Any) -> Any: diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index d7c57b9..798424b 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -5,7 +5,8 @@ import databases import pydantic import sqlalchemy from pydantic import BaseConfig -from pydantic.fields import FieldInfo, ModelField +from pydantic.fields import ModelField +from pydantic.utils import lenient_issubclass from sqlalchemy.sql.schema import ColumnCollectionConstraint import ormar # noqa I100 @@ -179,44 +180,58 @@ def register_relation_in_alias_manager( def populate_default_pydantic_field_value( - type_: Type[BaseField], field: str, attrs: dict + ormar_field: Type[BaseField], field_name: str, attrs: dict ) -> dict: - def_value = type_.default_value() - curr_def_value = attrs.get(field, "NONE") - if curr_def_value == "NONE" and isinstance(def_value, FieldInfo): - attrs[field] = def_value - elif curr_def_value == "NONE" and type_.nullable: - attrs[field] = FieldInfo(default=None) + curr_def_value = attrs.get(field_name, ormar.Undefined) + if lenient_issubclass(curr_def_value, ormar.fields.BaseField): + curr_def_value = ormar.Undefined + if curr_def_value is None: + attrs[field_name] = ormar_field.convert_to_pydantic_field_info(allow_null=True) + else: + attrs[field_name] = ormar_field.convert_to_pydantic_field_info() return attrs -def populate_pydantic_default_values(attrs: Dict) -> Dict: - for field, type_ in attrs["__annotations__"].items(): - if issubclass(type_, BaseField): - if type_.name is None: - type_.name = field - attrs = populate_default_pydantic_field_value(type_, field, attrs) - return attrs +def check_if_field_annotation_or_value_is_ormar( + field: Any, field_name: str, attrs: Dict +) -> bool: + return lenient_issubclass(field, BaseField) or issubclass( + attrs.get(field_name, type), BaseField + ) -def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict: +def extract_field_from_annotation_or_value( + field: Any, field_name: str, attrs: Dict +) -> Type[ormar.fields.BaseField]: + return field if lenient_issubclass(field, BaseField) else attrs.get(field_name) + + +def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]: + model_fields = {} + for field_name, field in attrs["__annotations__"].items(): + # ormar fields can be used as annotation or as default value + if check_if_field_annotation_or_value_is_ormar(field, field_name, attrs): + ormar_field = extract_field_from_annotation_or_value( + field, field_name, attrs + ) + if ormar_field.name is None: + ormar_field.name = field_name + attrs = populate_default_pydantic_field_value( + ormar_field, field_name, attrs + ) + model_fields[field_name] = ormar_field + attrs["__annotations__"][field_name] = ormar_field.__type__ + return attrs, model_fields + + +def extract_annotations_and_default_vals( + attrs: dict, bases: Tuple +) -> Tuple[Dict, Dict]: attrs["__annotations__"] = attrs.get("__annotations__") or bases[0].__dict__.get( "__annotations__", {} ) - attrs = populate_pydantic_default_values(attrs) - return attrs - - -def populate_meta_orm_model_fields( - attrs: dict, new_model: Type["Model"] -) -> Type["Model"]: - model_fields = { - field_name: field - for field_name, field in attrs["__annotations__"].items() - if issubclass(field, BaseField) - } - new_model.Meta.model_fields = model_fields - return new_model + attrs, model_fields = populate_pydantic_default_values(attrs) + return attrs, model_fields def populate_meta_tablename_columns_and_pk( @@ -305,7 +320,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): ) -> "ModelMetaclass": attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name - attrs = extract_annotations_and_default_vals(attrs, bases) + attrs, model_fields = extract_annotations_and_default_vals(attrs, bases) new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) @@ -313,7 +328,8 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): if hasattr(new_model, "Meta"): if not hasattr(new_model.Meta, "constraints"): new_model.Meta.constraints = [] - new_model = populate_meta_orm_model_fields(attrs, new_model) + if not hasattr(new_model.Meta, "model_fields"): + new_model.Meta.model_fields = model_fields new_model = populate_meta_tablename_columns_and_pk(name, new_model) new_model = populate_meta_sqlalchemy_table_if_required(new_model) expand_reverse_relationships(new_model) @@ -322,7 +338,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): if new_model.Meta.pkname not in attrs["__annotations__"]: field_name = new_model.Meta.pkname field = Integer(name=field_name, primary_key=True) - attrs["__annotations__"][field_name] = field + attrs["__annotations__"][field_name] = Optional[int] # type: ignore populate_default_pydantic_field_value( field, field_name, attrs # type: ignore ) diff --git a/ormar/models/model.py b/ormar/models/model.py index d0b07e4..f8884b1 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -1,11 +1,12 @@ import itertools -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar import sqlalchemy import ormar.queryset # noqa I100 from ormar.fields.many_to_many import ManyToManyField from ormar.models import NewBaseModel # noqa I100 +from ormar.models.metaclass import ModelMeta def group_related_list(list_: List) -> Dict: @@ -23,18 +24,30 @@ def group_related_list(list_: List) -> Dict: return test_dict +T = TypeVar("T", bound="Model") + + class Model(NewBaseModel): __abstract__ = False + if TYPE_CHECKING: # pragma nocover + Meta: ModelMeta + + def __repr__(self) -> str: # pragma nocover + attrs_to_include = ["tablename", "columns", "pkname"] + _repr = {k: v for k, v in self.Meta.model_fields.items()} + for atr in attrs_to_include: + _repr[atr] = getattr(self.Meta, atr) + return f"{self.__class__.__name__}({str(_repr)})" @classmethod def from_row( # noqa CCR001 - cls, + cls: Type[T], row: sqlalchemy.engine.ResultProxy, select_related: List = None, related_models: Any = None, previous_table: str = None, fields: List = None, - ) -> Optional["Model"]: + ) -> Optional[T]: item: Dict[str, Any] = {} select_related = select_related or [] @@ -66,7 +79,9 @@ class Model(NewBaseModel): item, row, table_prefix, fields, nested=table_prefix != "" ) - instance = cls(**item) if item.get(cls.Meta.pkname, None) is not None else None + instance: Optional[T] = cls(**item) if item.get( + cls.Meta.pkname, None + ) is not None else None return instance @classmethod @@ -124,7 +139,7 @@ class Model(NewBaseModel): return item - async def save(self) -> "Model": + async def save(self: T) -> T: self_fields = self._extract_model_db_fields() if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement: @@ -137,7 +152,7 @@ class Model(NewBaseModel): setattr(self, self.Meta.pkname, item_id) return self - async def update(self, **kwargs: Any) -> "Model": + async def update(self: T, **kwargs: Any) -> T: if kwargs: new_values = {**self.dict(), **kwargs} self.from_dict(new_values) @@ -151,13 +166,13 @@ class Model(NewBaseModel): await self.Meta.database.execute(expr) return self - async def delete(self) -> int: + async def delete(self: T) -> int: expr = self.Meta.table.delete() expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname))) result = await self.Meta.database.execute(expr) return result - async def load(self) -> "Model": + async def load(self: T) -> T: expr = self.Meta.table.select().where(self.pk_column == self.pk) row = await self.Meta.database.fetch_one(expr) if not row: # pragma nocover diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 56ee930..fb0e789 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -1,5 +1,5 @@ import inspect -from typing import Dict, List, Set, TYPE_CHECKING, Type, TypeVar, Union +from typing import Dict, List, Sequence, Set, TYPE_CHECKING, Type, TypeVar, Union import ormar from ormar.exceptions import RelationshipInstanceError @@ -11,6 +11,8 @@ if TYPE_CHECKING: # pragma no cover from ormar import Model from ormar.models import NewBaseModel + T = TypeVar("T", bound=Model) + Field = TypeVar("Field", bound=BaseField) @@ -135,7 +137,7 @@ class ModelTableProxy: if field.to == related.__class__ or field.to.Meta == related.Meta: return name # fallback for not registered relation - if register_missing: + if register_missing: # pragma nocover expand_reverse_relationships(related.__class__) # type: ignore return ModelTableProxy.resolve_relation_name( item, related, register_missing=False @@ -177,7 +179,7 @@ class ModelTableProxy: return new_kwargs @classmethod - def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: + def merge_instances_list(cls, result_rows: Sequence["Model"]) -> Sequence["Model"]: merged_rows: List["Model"] = [] for index, model in enumerate(result_rows): if index > 0 and model is not None and model.pk == merged_rows[-1].pk: diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index faab1c1..72083d3 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -5,11 +5,12 @@ from typing import ( Any, Callable, Dict, - List, Mapping, Optional, + Sequence, TYPE_CHECKING, Type, + TypeVar, Union, ) @@ -27,7 +28,9 @@ from ormar.relations.alias_manager import AliasManager from ormar.relations.relation_manager import RelationsManager if TYPE_CHECKING: # pragma no cover - from ormar.models.model import Model + from ormar import Model + + T = TypeVar("T", bound=Model) IntStr = Union[int, str] DictStrAny = Dict[str, Any] @@ -52,7 +55,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass Meta: ModelMeta # noinspection PyMissingConstructor - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore object.__setattr__(self, "_orm_id", uuid.uuid4().hex) object.__setattr__(self, "_orm_saved", False) @@ -73,7 +76,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass if "pk" in kwargs: kwargs[self.Meta.pkname] = kwargs.pop("pk") # build the models to set them and validate but don't register - kwargs = { + new_kwargs = { k: self._convert_json( k, self.Meta.model_fields[k].expand_relationship( @@ -85,7 +88,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass } values, fields_set, validation_error = pydantic.validate_model( - self, kwargs # type: ignore + self, new_kwargs # type: ignore ) if validation_error and not pk_only: raise validation_error @@ -96,7 +99,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass # register the columns models after initialization for related in self.extract_related_names(): self.Meta.model_fields[related].expand_relationship( - kwargs.get(related), self, to_register=True + new_kwargs.get(related), self, to_register=True ) def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 @@ -133,7 +136,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass def _extract_related_model_instead_of_field( self, item: str - ) -> Optional[Union["Model", List["Model"]]]: + ) -> Optional[Union[T, Sequence[T]]]: alias = self.get_column_alias(item) if alias in self._orm: return self._orm.get(alias) @@ -170,7 +173,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass def db_backend_name(cls) -> str: return cls.Meta.database._backend._dialect.name - def remove(self, name: "Model") -> None: + def remove(self, name: T) -> None: self._orm.remove_parent(self, name) def dict( # noqa A003 diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 1decda1..5b58c55 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, TYPE_CHECKING, Type, Union +from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Type, Union import databases import sqlalchemy @@ -59,7 +59,7 @@ class QuerySet: raise ValueError("Model class of QuerySet is not initialized") return self.model_cls - def _process_query_result_rows(self, rows: List) -> List[Optional["Model"]]: + def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]: result_rows = [ self.model.from_row( row, select_related=self._select_related, fields=self._columns @@ -87,7 +87,7 @@ class QuerySet: return new_kwargs @staticmethod - def check_single_result_rows_count(rows: List[Optional["Model"]]) -> None: + def check_single_result_rows_count(rows: Sequence[Optional["Model"]]) -> None: if not rows or rows[0] is None: raise NoMatch() if len(rows) > 1: @@ -267,7 +267,7 @@ class QuerySet: model = await self.get(pk=kwargs[pk_name]) return await model.update(**kwargs) - async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003 + async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003 if kwargs: return await self.filter(**kwargs).all() diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index f71de87..b9fb250 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, TYPE_CHECKING, Union +from typing import Any, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union import ormar @@ -7,6 +7,8 @@ if TYPE_CHECKING: # pragma no cover from ormar.models import Model from ormar.queryset import QuerySet + T = TypeVar("T", bound=Model) + class QuerysetProxy: if TYPE_CHECKING: # pragma no cover @@ -26,27 +28,28 @@ class QuerysetProxy: def queryset(self, value: "QuerySet") -> None: self._queryset = value - def _assign_child_to_parent(self, child: Optional["Model"]) -> None: + def _assign_child_to_parent(self, child: Optional[T]) -> None: if child: owner = self.relation._owner rel_name = owner.resolve_relation_name(owner, child) setattr(owner, rel_name, child) - def _register_related(self, child: Union["Model", List[Optional["Model"]]]) -> None: + def _register_related(self, child: Union[T, Sequence[Optional[T]]]) -> None: if isinstance(child, list): for subchild in child: self._assign_child_to_parent(subchild) else: + assert isinstance(child, Model) self._assign_child_to_parent(child) - async def create_through_instance(self, child: "Model") -> None: + async def create_through_instance(self, child: T) -> None: queryset = ormar.QuerySet(model_cls=self.relation.through) owner_column = self.relation._owner.get_name() child_column = child.get_name() kwargs = {owner_column: self.relation._owner, child_column: child} await queryset.create(**kwargs) - async def delete_through_instance(self, child: "Model") -> None: + async def delete_through_instance(self, child: T) -> None: queryset = ormar.QuerySet(model_cls=self.relation.through) owner_column = self.relation._owner.get_name() child_column = child.get_name() @@ -88,7 +91,7 @@ class QuerysetProxy: self._register_related(get) return get - async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003 + async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003 all_items = await self.queryset.all(**kwargs) self._register_related(all_items) return all_items diff --git a/ormar/relations/relation.py b/ormar/relations/relation.py index 91e00df..6520702 100644 --- a/ormar/relations/relation.py +++ b/ormar/relations/relation.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Optional, TYPE_CHECKING, Type, Union +from typing import List, Optional, TYPE_CHECKING, Type, TypeVar, Union import ormar # noqa I100 from ormar.exceptions import RelationshipInstanceError # noqa I100 @@ -11,6 +11,8 @@ if TYPE_CHECKING: # pragma no cover from ormar.relations import RelationsManager from ormar.models import NewBaseModel + T = TypeVar("T", bound=Model) + class RelationType(Enum): PRIMARY = 1 @@ -23,15 +25,15 @@ class Relation: self, manager: "RelationsManager", type_: RelationType, - to: Type["Model"], - through: Type["Model"] = None, + to: Type[T], + through: Type[T] = None, ) -> None: self.manager = manager self._owner: "Model" = manager.owner self._type: RelationType = type_ - self.to: Type["Model"] = to - self.through: Optional[Type["Model"]] = through - self.related_models: Optional[Union[RelationProxy, "Model"]] = ( + self.to: Type[T] = to + self.through: Optional[Type[T]] = through + self.related_models: Optional[Union[RelationProxy, T]] = ( RelationProxy(relation=self) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) else None @@ -50,7 +52,7 @@ class Relation: self.related_models.pop(ind) return None - def add(self, child: "Model") -> None: + def add(self, child: T) -> None: relation_name = self._owner.resolve_relation_name(self._owner, child) if self._type == RelationType.PRIMARY: self.related_models = child @@ -77,7 +79,7 @@ class Relation: self.related_models.pop(position) # type: ignore del self._owner.__dict__[relation_name][position] - def get(self) -> Optional[Union[List["Model"], "Model"]]: + def get(self) -> Optional[Union[List[T], T]]: return self.related_models def __repr__(self) -> str: # pragma no cover diff --git a/ormar/relations/relation_manager.py b/ormar/relations/relation_manager.py index 6e7eb24..82a6887 100644 --- a/ormar/relations/relation_manager.py +++ b/ormar/relations/relation_manager.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, TYPE_CHECKING, Type, Union +from typing import Dict, List, Optional, Sequence, TYPE_CHECKING, Type, TypeVar, Union from weakref import proxy from ormar.fields import BaseField @@ -14,6 +14,8 @@ if TYPE_CHECKING: # pragma no cover from ormar import Model from ormar.models import NewBaseModel + T = TypeVar("T", bound=Model) + class RelationsManager: def __init__( @@ -46,7 +48,7 @@ class RelationsManager: def __contains__(self, item: str) -> bool: return item in self._related_names - def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]: + def get(self, name: str) -> Optional[Union[T, Sequence[T]]]: relation = self._relations.get(name, None) if relation is not None: return relation.get() diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index 29e4b97..806a07d 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -72,6 +72,6 @@ class RelationProxy(list): if self.relation._type == ormar.RelationType.MULTIPLE: await self.queryset_proxy.create_through_instance(item) rel_name = item.resolve_relation_name(item, self._owner) - if rel_name not in item._orm: + if rel_name not in item._orm: # pragma nocover item._orm._add_relation(item.Meta.model_fields[rel_name]) setattr(item, rel_name, self._owner) diff --git a/tests/test_fastapi_docs.py b/tests/test_fastapi_docs.py index ed956c6..4cb1df3 100644 --- a/tests/test_fastapi_docs.py +++ b/tests/test_fastapi_docs.py @@ -125,7 +125,9 @@ def test_all_endpoints(): def test_schema_modification(): schema = Item.schema() - assert schema["properties"]["categories"]["type"] == "array" + assert any( + x.get("type") == "array" for x in schema["properties"]["categories"]["anyOf"] + ) assert schema["properties"]["categories"]["title"] == "Categories" diff --git a/tests/test_models.py b/tests/test_models.py index 47b5d50..cc783e3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -98,9 +98,9 @@ async def create_test_database(): def test_model_class(): assert list(User.Meta.model_fields.keys()) == ["id", "name"] - assert issubclass(User.Meta.model_fields["id"], pydantic.ConstrainedInt) + assert issubclass(User.Meta.model_fields["id"], pydantic.fields.FieldInfo) assert User.Meta.model_fields["id"].primary_key is True - assert issubclass(User.Meta.model_fields["name"], pydantic.ConstrainedStr) + assert issubclass(User.Meta.model_fields["name"], pydantic.fields.FieldInfo) assert User.Meta.model_fields["name"].max_length == 100 assert isinstance(User.Meta.table, sqlalchemy.Table) diff --git a/tests/test_new_annotation_style.py b/tests/test_new_annotation_style.py new file mode 100644 index 0000000..9a665cc --- /dev/null +++ b/tests/test_new_annotation_style.py @@ -0,0 +1,374 @@ +from typing import Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Album(ormar.Model): + class Meta: + tablename = "albums" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Track(ormar.Model): + class Meta: + tablename = "tracks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + album: Optional[Album] = ormar.ForeignKey(Album) + title: str = ormar.String(max_length=100) + position: int = ormar.Integer() + + +class Cover(ormar.Model): + class Meta: + tablename = "covers" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + album: Album = ormar.ForeignKey(Album, related_name="cover_pictures") + title: str = ormar.String(max_length=100) + + +class Organisation(ormar.Model): + class Meta: + tablename = "org" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"]) + + +class Team(ormar.Model): + class Meta: + tablename = "teams" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + org: Optional[Organisation] = ormar.ForeignKey(Organisation) + name: str = ormar.String(max_length=100) + + +class Member(ormar.Model): + class Meta: + tablename = "members" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + team: Optional[Team] = ormar.ForeignKey(Team) + email: str = ormar.String(max_length=100) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_wrong_query_foreign_key_type(): + async with database: + with pytest.raises(RelationshipInstanceError): + Track(title="The Error", album="wrong_pk_type") + + +@pytest.mark.asyncio +async def test_setting_explicitly_empty_relation(): + async with database: + track = Track(album=None, title="The Bird", position=1) + assert track.album is None + + +@pytest.mark.asyncio +async def test_related_name(): + async with database: + async with database.transaction(force_rollback=True): + album = await Album.objects.create(name="Vanilla") + await Cover.objects.create(album=album, title="The cover file") + assert len(album.cover_pictures) == 1 + + +@pytest.mark.asyncio +async def test_model_crud(): + async with database: + async with database.transaction(force_rollback=True): + album = Album(name="Jamaica") + await album.save() + track1 = Track(album=album, title="The Bird", position=1) + track2 = Track(album=album, title="Heart don't stand a chance", position=2) + track3 = Track(album=album, title="The Waters", position=3) + await track1.save() + await track2.save() + await track3.save() + + track = await Track.objects.get(title="The Bird") + assert track.album.pk == album.pk + assert isinstance(track.album, ormar.Model) + assert track.album.name is None + await track.album.load() + assert track.album.name == "Jamaica" + + assert len(album.tracks) == 3 + assert album.tracks[1].title == "Heart don't stand a chance" + + album1 = await Album.objects.get(name="Jamaica") + assert album1.pk == album.pk + assert album1.tracks == [] + + await Track.objects.create( + album={"id": track.album.pk}, title="The Bird2", position=4 + ) + + +@pytest.mark.asyncio +async def test_select_related(): + async with database: + async with database.transaction(force_rollback=True): + album = Album(name="Malibu") + await album.save() + track1 = Track(album=album, title="The Bird", position=1) + track2 = Track(album=album, title="Heart don't stand a chance", position=2) + track3 = Track(album=album, title="The Waters", position=3) + await track1.save() + await track2.save() + await track3.save() + + fantasies = Album(name="Fantasies") + await fantasies.save() + track4 = Track(album=fantasies, title="Help I'm Alive", position=1) + track5 = Track(album=fantasies, title="Sick Muse", position=2) + track6 = Track(album=fantasies, title="Satellite Mind", position=3) + await track4.save() + await track5.save() + await track6.save() + + track = await Track.objects.select_related("album").get(title="The Bird") + assert track.album.name == "Malibu" + + tracks = await Track.objects.select_related("album").all() + assert len(tracks) == 6 + + +@pytest.mark.asyncio +async def test_model_removal_from_relations(): + async with database: + async with database.transaction(force_rollback=True): + album = Album(name="Chichi") + await album.save() + track1 = Track(album=album, title="The Birdman", position=1) + track2 = Track(album=album, title="Superman", position=2) + track3 = Track(album=album, title="Wonder Woman", position=3) + await track1.save() + await track2.save() + await track3.save() + + assert len(album.tracks) == 3 + await album.tracks.remove(track1) + assert len(album.tracks) == 2 + assert track1.album is None + + await track1.update() + track1 = await Track.objects.get(title="The Birdman") + assert track1.album is None + + await album.tracks.add(track1) + assert len(album.tracks) == 3 + assert track1.album == album + + await track1.update() + track1 = await Track.objects.select_related("album__tracks").get( + title="The Birdman" + ) + album = await Album.objects.select_related("tracks").get(name="Chichi") + assert track1.album == album + + track1.remove(album) + assert track1.album is None + assert len(album.tracks) == 2 + + track2.remove(album) + assert track2.album is None + assert len(album.tracks) == 1 + + +@pytest.mark.asyncio +async def test_fk_filter(): + async with database: + async with database.transaction(force_rollback=True): + malibu = Album(name="Malibu%") + await malibu.save() + await Track.objects.create(album=malibu, title="The Bird", position=1) + await Track.objects.create( + album=malibu, title="Heart don't stand a chance", position=2 + ) + await Track.objects.create(album=malibu, title="The Waters", position=3) + + fantasies = await Album.objects.create(name="Fantasies") + await Track.objects.create( + album=fantasies, title="Help I'm Alive", position=1 + ) + await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + await Track.objects.create( + album=fantasies, title="Satellite Mind", position=3 + ) + + tracks = ( + await Track.objects.select_related("album") + .filter(album__name="Fantasies") + .all() + ) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = ( + await Track.objects.select_related("album") + .filter(album__name__icontains="fan") + .all() + ) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.objects.filter(album__name__contains="Fan").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.objects.filter(album__name__contains="Malibu%").all() + assert len(tracks) == 3 + + tracks = ( + await Track.objects.filter(album=malibu).select_related("album").all() + ) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Malibu%" + + tracks = await Track.objects.select_related("album").all(album=malibu) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Malibu%" + + +@pytest.mark.asyncio +async def test_multiple_fk(): + async with database: + async with database.transaction(force_rollback=True): + acme = await Organisation.objects.create(ident="ACME Ltd") + red_team = await Team.objects.create(org=acme, name="Red Team") + blue_team = await Team.objects.create(org=acme, name="Blue Team") + await Member.objects.create(team=red_team, email="a@example.org") + await Member.objects.create(team=red_team, email="b@example.org") + await Member.objects.create(team=blue_team, email="c@example.org") + await Member.objects.create(team=blue_team, email="d@example.org") + + other = await Organisation.objects.create(ident="Other ltd") + team = await Team.objects.create(org=other, name="Green Team") + await Member.objects.create(team=team, email="e@example.org") + + members = ( + await Member.objects.select_related("team__org") + .filter(team__org__ident="ACME Ltd") + .all() + ) + assert len(members) == 4 + for member in members: + assert member.team.org.ident == "ACME Ltd" + + +@pytest.mark.asyncio +async def test_wrong_choices(): + async with database: + async with database.transaction(force_rollback=True): + with pytest.raises(ValueError): + await Organisation.objects.create(ident="Test 1") + + +@pytest.mark.asyncio +async def test_pk_filter(): + async with database: + async with database.transaction(force_rollback=True): + fantasies = await Album.objects.create(name="Test") + track = await Track.objects.create( + album=fantasies, title="Test1", position=1 + ) + await Track.objects.create(album=fantasies, title="Test2", position=2) + await Track.objects.create(album=fantasies, title="Test3", position=3) + tracks = ( + await Track.objects.select_related("album").filter(pk=track.pk).all() + ) + assert len(tracks) == 1 + + tracks = ( + await Track.objects.select_related("album") + .filter(position=2, album__name="Test") + .all() + ) + assert len(tracks) == 1 + + +@pytest.mark.asyncio +async def test_limit_and_offset(): + async with database: + async with database.transaction(force_rollback=True): + fantasies = await Album.objects.create(name="Limitless") + await Track.objects.create( + id=None, album=fantasies, title="Sample", position=1 + ) + await Track.objects.create(album=fantasies, title="Sample2", position=2) + await Track.objects.create(album=fantasies, title="Sample3", position=3) + + tracks = await Track.objects.limit(1).all() + assert len(tracks) == 1 + assert tracks[0].title == "Sample" + + tracks = await Track.objects.limit(1).offset(1).all() + assert len(tracks) == 1 + assert tracks[0].title == "Sample2" + + +@pytest.mark.asyncio +async def test_get_exceptions(): + async with database: + async with database.transaction(force_rollback=True): + fantasies = await Album.objects.create(name="Test") + + with pytest.raises(NoMatch): + await Album.objects.get(name="Test2") + + await Track.objects.create(album=fantasies, title="Test1", position=1) + await Track.objects.create(album=fantasies, title="Test2", position=2) + await Track.objects.create(album=fantasies, title="Test3", position=3) + with pytest.raises(MultipleMatches): + await Track.objects.select_related("album").get(album=fantasies) + + +@pytest.mark.asyncio +async def test_wrong_model_passed_as_fk(): + async with database: + async with database.transaction(force_rollback=True): + with pytest.raises(RelationshipInstanceError): + org = await Organisation.objects.create(ident="ACME Ltd") + await Track.objects.create(album=org, title="Test1", position=1) diff --git a/tests/test_server_default.py b/tests/test_server_default.py index b9106d0..4c4db66 100644 --- a/tests/test_server_default.py +++ b/tests/test_server_default.py @@ -22,7 +22,7 @@ class Product(ormar.Model): id: ormar.Integer(primary_key=True) name: ormar.String(max_length=100) - company: ormar.String(max_length=200, server_default='Acme') + company: ormar.String(max_length=200, server_default="Acme") sort_order: ormar.Integer(server_default=text("10")) created: ormar.DateTime(server_default=func.now()) @@ -44,42 +44,44 @@ async def create_test_database(): def test_table_defined_properly(): - assert Product.Meta.model_fields['created'].nullable - assert not Product.__fields__['created'].required - assert Product.Meta.table.columns['created'].server_default.arg.name == 'now' + assert Product.Meta.model_fields["created"].nullable + assert not Product.__fields__["created"].required + assert Product.Meta.table.columns["created"].server_default.arg.name == "now" @pytest.mark.asyncio async def test_model_creation(): async with database: async with database.transaction(force_rollback=True): - p1 = Product(name='Test') + p1 = Product(name="Test") assert p1.created is None await p1.save() await p1.load() assert p1.created is not None - assert p1.company == 'Acme' + assert p1.company == "Acme" assert p1.sort_order == 10 - date = datetime.strptime('2020-10-27 11:30', '%Y-%m-%d %H:%M') - p3 = await Product.objects.create(name='Test2', created=date, company='Roadrunner', sort_order=1) + date = datetime.strptime("2020-10-27 11:30", "%Y-%m-%d %H:%M") + p3 = await Product.objects.create( + name="Test2", created=date, company="Roadrunner", sort_order=1 + ) assert p3.created is not None assert p3.created == date assert p1.created != p3.created - assert p3.company == 'Roadrunner' + assert p3.company == "Roadrunner" assert p3.sort_order == 1 - p3 = await Product.objects.get(name='Test2') - assert p3.company == 'Roadrunner' + p3 = await Product.objects.get(name="Test2") + assert p3.company == "Roadrunner" assert p3.sort_order == 1 time.sleep(1) - p2 = await Product.objects.create(name='Test3') + p2 = await Product.objects.create(name="Test3") assert p2.created is not None - assert p2.company == 'Acme' + assert p2.company == "Acme" assert p2.sort_order == 10 - if Product.db_backend_name() != 'postgresql': + if Product.db_backend_name() != "postgresql": # postgres use transaction timestamp so it will remain the same assert p1.created != p2.created # pragma nocover diff --git a/tests/test_unique_constraints.py b/tests/test_unique_constraints.py index ffa561c..93dd9ba 100644 --- a/tests/test_unique_constraints.py +++ b/tests/test_unique_constraints.py @@ -1,7 +1,7 @@ import asyncio import sqlite3 -import asyncpg # type: ignore +import asyncpg # type: ignore import databases import pymysql import pytest