allow change to build in type hints

This commit is contained in:
collerek
2020-10-31 15:43:34 +01:00
parent 320588a3c1
commit 8fba94efa1
18 changed files with 575 additions and 131 deletions

View File

@ -1,5 +1,6 @@
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
import pydantic
import sqlalchemy
from pydantic import Field, typing
from pydantic.fields import FieldInfo
@ -11,8 +12,9 @@ if TYPE_CHECKING: # pragma no cover
from ormar.models import NewBaseModel
class BaseField:
class BaseField(FieldInfo):
__type__ = None
__pydantic_type__ = None
column_type: sqlalchemy.Column
constraints: List = []
@ -32,6 +34,28 @@ class BaseField:
default: Any
server_default: Any
@classmethod
def is_valid_field_info_field(cls, field_name: str) -> bool:
return (
field_name not in ["default", "default_factory"]
and not field_name.startswith("__")
and hasattr(cls, field_name)
)
@classmethod
def convert_to_pydantic_field_info(cls, allow_null: bool = False) -> FieldInfo:
base = cls.default_value()
if base is None:
base = (
FieldInfo(default=None)
if (cls.nullable or allow_null)
else FieldInfo(default=pydantic.fields.Undefined)
)
for attr_name in FieldInfo.__dict__.keys():
if cls.is_valid_field_info_field(attr_name):
setattr(base, attr_name, cls.__dict__.get(attr_name))
return base
@classmethod
def default_value(cls, use_server: bool = False) -> Optional[FieldInfo]:
if cls.is_auto_primary_key():

View File

@ -1,4 +1,4 @@
from typing import Any, Generator, List, Optional, TYPE_CHECKING, Type, Union
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
import sqlalchemy
from sqlalchemy import UniqueConstraint
@ -39,8 +39,15 @@ def ForeignKey( # noqa CFQ002
ondelete: str = None,
) -> Type["ForeignKeyField"]:
fk_string = to.Meta.tablename + "." + to.get_column_alias(to.Meta.pkname)
to_field = to.__fields__[to.Meta.pkname]
to_field = to.Meta.model_fields[to.Meta.pkname]
__type__ = (
Union[to_field.__type__, to]
if not nullable
else Optional[Union[to_field.__type__, to]]
)
namespace = dict(
__type__=__type__,
__pydantic_type__=__type__,
to=to,
name=name,
nullable=nullable,
@ -50,7 +57,7 @@ def ForeignKey( # noqa CFQ002
)
],
unique=unique,
column_type=to_field.type_.column_type,
column_type=to_field.column_type,
related_name=related_name,
virtual=virtual,
primary_key=False,
@ -58,7 +65,6 @@ def ForeignKey( # noqa CFQ002
pydantic_only=False,
default=None,
server_default=None,
__pydantic_model__=to,
)
return type("ForeignKey", (ForeignKeyField, BaseField), namespace)
@ -70,14 +76,6 @@ class ForeignKeyField(BaseField):
related_name: str
virtual: bool
@classmethod
def __get_validators__(cls) -> Generator:
yield cls.validate
@classmethod
def validate(cls, value: Any) -> Any:
return value
@classmethod
def _extract_model_from_sequence(
cls, value: List, child: "Model", to_register: bool

View File

@ -1,4 +1,4 @@
from typing import Dict, TYPE_CHECKING, Type
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
@ -15,17 +15,26 @@ def ManyToMany(
*,
name: str = None,
unique: bool = False,
related_name: str = None,
virtual: bool = False,
**kwargs: Any
) -> Type["ManyToManyField"]:
to_field = to.__fields__[to.Meta.pkname]
to_field = to.Meta.model_fields[to.Meta.pkname]
related_name = kwargs.pop("related_name", None)
nullable = kwargs.pop("nullable", True)
__type__ = (
Union[to_field.__type__, to, List[to]] # type: ignore
if not nullable
else Optional[Union[to_field.__type__, to, List[to]]] # type: ignore
)
namespace = dict(
__type__=__type__,
__pydantic_type__=__type__,
to=to,
through=through,
name=name,
nullable=True,
unique=unique,
column_type=to_field.type_.column_type,
column_type=to_field.column_type,
related_name=related_name,
virtual=virtual,
primary_key=False,
@ -33,9 +42,6 @@ def ManyToMany(
pydantic_only=False,
default=None,
server_default=None,
__pydantic_model__=to,
# __origin__=List,
# __args__=[Optional[to]]
)
return type("ManyToMany", (ManyToManyField, BaseField), namespace)
@ -43,10 +49,3 @@ def ManyToMany(
class ManyToManyField(ForeignKeyField):
through: Type["Model"]
@classmethod
def __modify_schema__(cls, field_schema: Dict) -> None:
field_schema["type"] = "array"
field_schema["title"] = cls.name.title()
field_schema["definitions"] = {f"{cls.to.__name__}": cls.to.schema()}
field_schema["items"] = {"$ref": f"{REF_PREFIX}{cls.to.__name__}"}

View File

@ -20,8 +20,9 @@ def is_field_nullable(
class ModelFieldFactory:
_bases: Any = BaseField
_bases: Any = (BaseField,)
_type: Any = None
_pydantic_type: Any = None
def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: # type: ignore
cls.validate(**kwargs)
@ -32,6 +33,7 @@ class ModelFieldFactory:
namespace = dict(
__type__=cls._type,
__pydantic_type__=cls._pydantic_type,
name=kwargs.pop("name", None),
primary_key=kwargs.pop("primary_key", False),
default=default,
@ -57,8 +59,8 @@ class ModelFieldFactory:
class String(ModelFieldFactory):
_bases = (pydantic.ConstrainedStr, BaseField)
_type = str
_pydantic_type = pydantic.ConstrainedStr
def __new__( # type: ignore # noqa CFQ002
cls,
@ -96,8 +98,8 @@ class String(ModelFieldFactory):
class Integer(ModelFieldFactory):
_bases = (pydantic.ConstrainedInt, BaseField)
_type = int
_pydantic_type = pydantic.ConstrainedInt
def __new__( # type: ignore
cls,
@ -131,8 +133,8 @@ class Integer(ModelFieldFactory):
class Text(ModelFieldFactory):
_bases = (pydantic.ConstrainedStr, BaseField)
_type = str
_pydantic_type = pydantic.ConstrainedStr
def __new__( # type: ignore
cls, *, allow_blank: bool = True, strip_whitespace: bool = False, **kwargs: Any
@ -154,8 +156,8 @@ class Text(ModelFieldFactory):
class Float(ModelFieldFactory):
_bases = (pydantic.ConstrainedFloat, BaseField)
_type = float
_pydantic_type = pydantic.ConstrainedFloat
def __new__( # type: ignore
cls,
@ -183,8 +185,8 @@ class Float(ModelFieldFactory):
class Boolean(ModelFieldFactory):
_bases = (int, BaseField)
_type = bool
_pydantic_type = bool
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
@ -192,8 +194,8 @@ class Boolean(ModelFieldFactory):
class DateTime(ModelFieldFactory):
_bases = (datetime.datetime, BaseField)
_type = datetime.datetime
_pydantic_type = datetime.datetime
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
@ -201,8 +203,8 @@ class DateTime(ModelFieldFactory):
class Date(ModelFieldFactory):
_bases = (datetime.date, BaseField)
_type = datetime.date
_pydantic_type = datetime.date
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
@ -210,8 +212,8 @@ class Date(ModelFieldFactory):
class Time(ModelFieldFactory):
_bases = (datetime.time, BaseField)
_type = datetime.time
_pydantic_type = datetime.time
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
@ -219,8 +221,8 @@ class Time(ModelFieldFactory):
class JSON(ModelFieldFactory):
_bases = (pydantic.Json, BaseField)
_type = pydantic.Json
_pydantic_type = pydantic.Json
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
@ -228,8 +230,8 @@ class JSON(ModelFieldFactory):
class BigInteger(Integer):
_bases = (pydantic.ConstrainedInt, BaseField)
_type = int
_pydantic_type = pydantic.ConstrainedInt
def __new__( # type: ignore
cls,
@ -263,8 +265,8 @@ class BigInteger(Integer):
class Decimal(ModelFieldFactory):
_bases = (pydantic.ConstrainedDecimal, BaseField)
_type = decimal.Decimal
_pydantic_type = pydantic.ConstrainedDecimal
def __new__( # type: ignore # noqa CFQ002
cls,
@ -318,8 +320,8 @@ class Decimal(ModelFieldFactory):
class UUID(ModelFieldFactory):
_bases = (uuid.UUID, BaseField)
_type = uuid.UUID
_pydantic_type = uuid.UUID
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:

View File

@ -5,7 +5,8 @@ import databases
import pydantic
import sqlalchemy
from pydantic import BaseConfig
from pydantic.fields import FieldInfo, ModelField
from pydantic.fields import ModelField
from pydantic.utils import lenient_issubclass
from sqlalchemy.sql.schema import ColumnCollectionConstraint
import ormar # noqa I100
@ -179,44 +180,58 @@ def register_relation_in_alias_manager(
def populate_default_pydantic_field_value(
type_: Type[BaseField], field: str, attrs: dict
ormar_field: Type[BaseField], field_name: str, attrs: dict
) -> dict:
def_value = type_.default_value()
curr_def_value = attrs.get(field, "NONE")
if curr_def_value == "NONE" and isinstance(def_value, FieldInfo):
attrs[field] = def_value
elif curr_def_value == "NONE" and type_.nullable:
attrs[field] = FieldInfo(default=None)
curr_def_value = attrs.get(field_name, ormar.Undefined)
if lenient_issubclass(curr_def_value, ormar.fields.BaseField):
curr_def_value = ormar.Undefined
if curr_def_value is None:
attrs[field_name] = ormar_field.convert_to_pydantic_field_info(allow_null=True)
else:
attrs[field_name] = ormar_field.convert_to_pydantic_field_info()
return attrs
def populate_pydantic_default_values(attrs: Dict) -> Dict:
for field, type_ in attrs["__annotations__"].items():
if issubclass(type_, BaseField):
if type_.name is None:
type_.name = field
attrs = populate_default_pydantic_field_value(type_, field, attrs)
return attrs
def check_if_field_annotation_or_value_is_ormar(
field: Any, field_name: str, attrs: Dict
) -> bool:
return lenient_issubclass(field, BaseField) or issubclass(
attrs.get(field_name, type), BaseField
)
def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict:
def extract_field_from_annotation_or_value(
field: Any, field_name: str, attrs: Dict
) -> Type[ormar.fields.BaseField]:
return field if lenient_issubclass(field, BaseField) else attrs.get(field_name)
def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]:
model_fields = {}
for field_name, field in attrs["__annotations__"].items():
# ormar fields can be used as annotation or as default value
if check_if_field_annotation_or_value_is_ormar(field, field_name, attrs):
ormar_field = extract_field_from_annotation_or_value(
field, field_name, attrs
)
if ormar_field.name is None:
ormar_field.name = field_name
attrs = populate_default_pydantic_field_value(
ormar_field, field_name, attrs
)
model_fields[field_name] = ormar_field
attrs["__annotations__"][field_name] = ormar_field.__type__
return attrs, model_fields
def extract_annotations_and_default_vals(
attrs: dict, bases: Tuple
) -> Tuple[Dict, Dict]:
attrs["__annotations__"] = attrs.get("__annotations__") or bases[0].__dict__.get(
"__annotations__", {}
)
attrs = populate_pydantic_default_values(attrs)
return attrs
def populate_meta_orm_model_fields(
attrs: dict, new_model: Type["Model"]
) -> Type["Model"]:
model_fields = {
field_name: field
for field_name, field in attrs["__annotations__"].items()
if issubclass(field, BaseField)
}
new_model.Meta.model_fields = model_fields
return new_model
attrs, model_fields = populate_pydantic_default_values(attrs)
return attrs, model_fields
def populate_meta_tablename_columns_and_pk(
@ -305,7 +320,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
) -> "ModelMetaclass":
attrs["Config"] = get_pydantic_base_orm_config()
attrs["__name__"] = name
attrs = extract_annotations_and_default_vals(attrs, bases)
attrs, model_fields = extract_annotations_and_default_vals(attrs, bases)
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs
)
@ -313,7 +328,8 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
if hasattr(new_model, "Meta"):
if not hasattr(new_model.Meta, "constraints"):
new_model.Meta.constraints = []
new_model = populate_meta_orm_model_fields(attrs, new_model)
if not hasattr(new_model.Meta, "model_fields"):
new_model.Meta.model_fields = model_fields
new_model = populate_meta_tablename_columns_and_pk(name, new_model)
new_model = populate_meta_sqlalchemy_table_if_required(new_model)
expand_reverse_relationships(new_model)
@ -322,7 +338,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
if new_model.Meta.pkname not in attrs["__annotations__"]:
field_name = new_model.Meta.pkname
field = Integer(name=field_name, primary_key=True)
attrs["__annotations__"][field_name] = field
attrs["__annotations__"][field_name] = Optional[int] # type: ignore
populate_default_pydantic_field_value(
field, field_name, attrs # type: ignore
)

View File

@ -1,11 +1,12 @@
import itertools
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar
import sqlalchemy
import ormar.queryset # noqa I100
from ormar.fields.many_to_many import ManyToManyField
from ormar.models import NewBaseModel # noqa I100
from ormar.models.metaclass import ModelMeta
def group_related_list(list_: List) -> Dict:
@ -23,18 +24,30 @@ def group_related_list(list_: List) -> Dict:
return test_dict
T = TypeVar("T", bound="Model")
class Model(NewBaseModel):
__abstract__ = False
if TYPE_CHECKING: # pragma nocover
Meta: ModelMeta
def __repr__(self) -> str: # pragma nocover
attrs_to_include = ["tablename", "columns", "pkname"]
_repr = {k: v for k, v in self.Meta.model_fields.items()}
for atr in attrs_to_include:
_repr[atr] = getattr(self.Meta, atr)
return f"{self.__class__.__name__}({str(_repr)})"
@classmethod
def from_row( # noqa CCR001
cls,
cls: Type[T],
row: sqlalchemy.engine.ResultProxy,
select_related: List = None,
related_models: Any = None,
previous_table: str = None,
fields: List = None,
) -> Optional["Model"]:
) -> Optional[T]:
item: Dict[str, Any] = {}
select_related = select_related or []
@ -66,7 +79,9 @@ class Model(NewBaseModel):
item, row, table_prefix, fields, nested=table_prefix != ""
)
instance = cls(**item) if item.get(cls.Meta.pkname, None) is not None else None
instance: Optional[T] = cls(**item) if item.get(
cls.Meta.pkname, None
) is not None else None
return instance
@classmethod
@ -124,7 +139,7 @@ class Model(NewBaseModel):
return item
async def save(self) -> "Model":
async def save(self: T) -> T:
self_fields = self._extract_model_db_fields()
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
@ -137,7 +152,7 @@ class Model(NewBaseModel):
setattr(self, self.Meta.pkname, item_id)
return self
async def update(self, **kwargs: Any) -> "Model":
async def update(self: T, **kwargs: Any) -> T:
if kwargs:
new_values = {**self.dict(), **kwargs}
self.from_dict(new_values)
@ -151,13 +166,13 @@ class Model(NewBaseModel):
await self.Meta.database.execute(expr)
return self
async def delete(self) -> int:
async def delete(self: T) -> int:
expr = self.Meta.table.delete()
expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname)))
result = await self.Meta.database.execute(expr)
return result
async def load(self) -> "Model":
async def load(self: T) -> T:
expr = self.Meta.table.select().where(self.pk_column == self.pk)
row = await self.Meta.database.fetch_one(expr)
if not row: # pragma nocover

View File

@ -1,5 +1,5 @@
import inspect
from typing import Dict, List, Set, TYPE_CHECKING, Type, TypeVar, Union
from typing import Dict, List, Sequence, Set, TYPE_CHECKING, Type, TypeVar, Union
import ormar
from ormar.exceptions import RelationshipInstanceError
@ -11,6 +11,8 @@ if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.models import NewBaseModel
T = TypeVar("T", bound=Model)
Field = TypeVar("Field", bound=BaseField)
@ -135,7 +137,7 @@ class ModelTableProxy:
if field.to == related.__class__ or field.to.Meta == related.Meta:
return name
# fallback for not registered relation
if register_missing:
if register_missing: # pragma nocover
expand_reverse_relationships(related.__class__) # type: ignore
return ModelTableProxy.resolve_relation_name(
item, related, register_missing=False
@ -177,7 +179,7 @@ class ModelTableProxy:
return new_kwargs
@classmethod
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
def merge_instances_list(cls, result_rows: Sequence["Model"]) -> Sequence["Model"]:
merged_rows: List["Model"] = []
for index, model in enumerate(result_rows):
if index > 0 and model is not None and model.pk == merged_rows[-1].pk:

View File

@ -5,11 +5,12 @@ from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Sequence,
TYPE_CHECKING,
Type,
TypeVar,
Union,
)
@ -27,7 +28,9 @@ from ormar.relations.alias_manager import AliasManager
from ormar.relations.relation_manager import RelationsManager
if TYPE_CHECKING: # pragma no cover
from ormar.models.model import Model
from ormar import Model
T = TypeVar("T", bound=Model)
IntStr = Union[int, str]
DictStrAny = Dict[str, Any]
@ -52,7 +55,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
Meta: ModelMeta
# noinspection PyMissingConstructor
def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore
object.__setattr__(self, "_orm_id", uuid.uuid4().hex)
object.__setattr__(self, "_orm_saved", False)
@ -73,7 +76,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if "pk" in kwargs:
kwargs[self.Meta.pkname] = kwargs.pop("pk")
# build the models to set them and validate but don't register
kwargs = {
new_kwargs = {
k: self._convert_json(
k,
self.Meta.model_fields[k].expand_relationship(
@ -85,7 +88,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
}
values, fields_set, validation_error = pydantic.validate_model(
self, kwargs # type: ignore
self, new_kwargs # type: ignore
)
if validation_error and not pk_only:
raise validation_error
@ -96,7 +99,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
# register the columns models after initialization
for related in self.extract_related_names():
self.Meta.model_fields[related].expand_relationship(
kwargs.get(related), self, to_register=True
new_kwargs.get(related), self, to_register=True
)
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
@ -133,7 +136,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
def _extract_related_model_instead_of_field(
self, item: str
) -> Optional[Union["Model", List["Model"]]]:
) -> Optional[Union[T, Sequence[T]]]:
alias = self.get_column_alias(item)
if alias in self._orm:
return self._orm.get(alias)
@ -170,7 +173,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
def db_backend_name(cls) -> str:
return cls.Meta.database._backend._dialect.name
def remove(self, name: "Model") -> None:
def remove(self, name: T) -> None:
self._orm.remove_parent(self, name)
def dict( # noqa A003

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Type, Union
import databases
import sqlalchemy
@ -59,7 +59,7 @@ class QuerySet:
raise ValueError("Model class of QuerySet is not initialized")
return self.model_cls
def _process_query_result_rows(self, rows: List) -> List[Optional["Model"]]:
def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]:
result_rows = [
self.model.from_row(
row, select_related=self._select_related, fields=self._columns
@ -87,7 +87,7 @@ class QuerySet:
return new_kwargs
@staticmethod
def check_single_result_rows_count(rows: List[Optional["Model"]]) -> None:
def check_single_result_rows_count(rows: Sequence[Optional["Model"]]) -> None:
if not rows or rows[0] is None:
raise NoMatch()
if len(rows) > 1:
@ -267,7 +267,7 @@ class QuerySet:
model = await self.get(pk=kwargs[pk_name])
return await model.update(**kwargs)
async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003
async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003
if kwargs:
return await self.filter(**kwargs).all()

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, TYPE_CHECKING, Union
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union
import ormar
@ -7,6 +7,8 @@ if TYPE_CHECKING: # pragma no cover
from ormar.models import Model
from ormar.queryset import QuerySet
T = TypeVar("T", bound=Model)
class QuerysetProxy:
if TYPE_CHECKING: # pragma no cover
@ -26,27 +28,28 @@ class QuerysetProxy:
def queryset(self, value: "QuerySet") -> None:
self._queryset = value
def _assign_child_to_parent(self, child: Optional["Model"]) -> None:
def _assign_child_to_parent(self, child: Optional[T]) -> None:
if child:
owner = self.relation._owner
rel_name = owner.resolve_relation_name(owner, child)
setattr(owner, rel_name, child)
def _register_related(self, child: Union["Model", List[Optional["Model"]]]) -> None:
def _register_related(self, child: Union[T, Sequence[Optional[T]]]) -> None:
if isinstance(child, list):
for subchild in child:
self._assign_child_to_parent(subchild)
else:
assert isinstance(child, Model)
self._assign_child_to_parent(child)
async def create_through_instance(self, child: "Model") -> None:
async def create_through_instance(self, child: T) -> None:
queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name()
child_column = child.get_name()
kwargs = {owner_column: self.relation._owner, child_column: child}
await queryset.create(**kwargs)
async def delete_through_instance(self, child: "Model") -> None:
async def delete_through_instance(self, child: T) -> None:
queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name()
child_column = child.get_name()
@ -88,7 +91,7 @@ class QuerysetProxy:
self._register_related(get)
return get
async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003
async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003
all_items = await self.queryset.all(**kwargs)
self._register_related(all_items)
return all_items

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Optional, TYPE_CHECKING, Type, Union
from typing import List, Optional, TYPE_CHECKING, Type, TypeVar, Union
import ormar # noqa I100
from ormar.exceptions import RelationshipInstanceError # noqa I100
@ -11,6 +11,8 @@ if TYPE_CHECKING: # pragma no cover
from ormar.relations import RelationsManager
from ormar.models import NewBaseModel
T = TypeVar("T", bound=Model)
class RelationType(Enum):
PRIMARY = 1
@ -23,15 +25,15 @@ class Relation:
self,
manager: "RelationsManager",
type_: RelationType,
to: Type["Model"],
through: Type["Model"] = None,
to: Type[T],
through: Type[T] = None,
) -> None:
self.manager = manager
self._owner: "Model" = manager.owner
self._type: RelationType = type_
self.to: Type["Model"] = to
self.through: Optional[Type["Model"]] = through
self.related_models: Optional[Union[RelationProxy, "Model"]] = (
self.to: Type[T] = to
self.through: Optional[Type[T]] = through
self.related_models: Optional[Union[RelationProxy, T]] = (
RelationProxy(relation=self)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None
@ -50,7 +52,7 @@ class Relation:
self.related_models.pop(ind)
return None
def add(self, child: "Model") -> None:
def add(self, child: T) -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY:
self.related_models = child
@ -77,7 +79,7 @@ class Relation:
self.related_models.pop(position) # type: ignore
del self._owner.__dict__[relation_name][position]
def get(self) -> Optional[Union[List["Model"], "Model"]]:
def get(self) -> Optional[Union[List[T], T]]:
return self.related_models
def __repr__(self) -> str: # pragma no cover

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, TYPE_CHECKING, Type, Union
from typing import Dict, List, Optional, Sequence, TYPE_CHECKING, Type, TypeVar, Union
from weakref import proxy
from ormar.fields import BaseField
@ -14,6 +14,8 @@ if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.models import NewBaseModel
T = TypeVar("T", bound=Model)
class RelationsManager:
def __init__(
@ -46,7 +48,7 @@ class RelationsManager:
def __contains__(self, item: str) -> bool:
return item in self._related_names
def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]:
def get(self, name: str) -> Optional[Union[T, Sequence[T]]]:
relation = self._relations.get(name, None)
if relation is not None:
return relation.get()

View File

@ -72,6 +72,6 @@ class RelationProxy(list):
if self.relation._type == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner)
if rel_name not in item._orm:
if rel_name not in item._orm: # pragma nocover
item._orm._add_relation(item.Meta.model_fields[rel_name])
setattr(item, rel_name, self._owner)

View File

@ -125,7 +125,9 @@ def test_all_endpoints():
def test_schema_modification():
schema = Item.schema()
assert schema["properties"]["categories"]["type"] == "array"
assert any(
x.get("type") == "array" for x in schema["properties"]["categories"]["anyOf"]
)
assert schema["properties"]["categories"]["title"] == "Categories"

View File

@ -98,9 +98,9 @@ async def create_test_database():
def test_model_class():
assert list(User.Meta.model_fields.keys()) == ["id", "name"]
assert issubclass(User.Meta.model_fields["id"], pydantic.ConstrainedInt)
assert issubclass(User.Meta.model_fields["id"], pydantic.fields.FieldInfo)
assert User.Meta.model_fields["id"].primary_key is True
assert issubclass(User.Meta.model_fields["name"], pydantic.ConstrainedStr)
assert issubclass(User.Meta.model_fields["name"], pydantic.fields.FieldInfo)
assert User.Meta.model_fields["name"].max_length == 100
assert isinstance(User.Meta.table, sqlalchemy.Table)

View File

@ -0,0 +1,374 @@
from typing import Optional
import databases
import pytest
import sqlalchemy
import ormar
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Album(ormar.Model):
class Meta:
tablename = "albums"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Track(ormar.Model):
class Meta:
tablename = "tracks"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album)
title: str = ormar.String(max_length=100)
position: int = ormar.Integer()
class Cover(ormar.Model):
class Meta:
tablename = "covers"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
album: Album = ormar.ForeignKey(Album, related_name="cover_pictures")
title: str = ormar.String(max_length=100)
class Organisation(ormar.Model):
class Meta:
tablename = "org"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"])
class Team(ormar.Model):
class Meta:
tablename = "teams"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
org: Optional[Organisation] = ormar.ForeignKey(Organisation)
name: str = ormar.String(max_length=100)
class Member(ormar.Model):
class Meta:
tablename = "members"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
team: Optional[Team] = ormar.ForeignKey(Team)
email: str = ormar.String(max_length=100)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_wrong_query_foreign_key_type():
async with database:
with pytest.raises(RelationshipInstanceError):
Track(title="The Error", album="wrong_pk_type")
@pytest.mark.asyncio
async def test_setting_explicitly_empty_relation():
async with database:
track = Track(album=None, title="The Bird", position=1)
assert track.album is None
@pytest.mark.asyncio
async def test_related_name():
async with database:
async with database.transaction(force_rollback=True):
album = await Album.objects.create(name="Vanilla")
await Cover.objects.create(album=album, title="The cover file")
assert len(album.cover_pictures) == 1
@pytest.mark.asyncio
async def test_model_crud():
async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Jamaica")
await album.save()
track1 = Track(album=album, title="The Bird", position=1)
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
track3 = Track(album=album, title="The Waters", position=3)
await track1.save()
await track2.save()
await track3.save()
track = await Track.objects.get(title="The Bird")
assert track.album.pk == album.pk
assert isinstance(track.album, ormar.Model)
assert track.album.name is None
await track.album.load()
assert track.album.name == "Jamaica"
assert len(album.tracks) == 3
assert album.tracks[1].title == "Heart don't stand a chance"
album1 = await Album.objects.get(name="Jamaica")
assert album1.pk == album.pk
assert album1.tracks == []
await Track.objects.create(
album={"id": track.album.pk}, title="The Bird2", position=4
)
@pytest.mark.asyncio
async def test_select_related():
async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Malibu")
await album.save()
track1 = Track(album=album, title="The Bird", position=1)
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
track3 = Track(album=album, title="The Waters", position=3)
await track1.save()
await track2.save()
await track3.save()
fantasies = Album(name="Fantasies")
await fantasies.save()
track4 = Track(album=fantasies, title="Help I'm Alive", position=1)
track5 = Track(album=fantasies, title="Sick Muse", position=2)
track6 = Track(album=fantasies, title="Satellite Mind", position=3)
await track4.save()
await track5.save()
await track6.save()
track = await Track.objects.select_related("album").get(title="The Bird")
assert track.album.name == "Malibu"
tracks = await Track.objects.select_related("album").all()
assert len(tracks) == 6
@pytest.mark.asyncio
async def test_model_removal_from_relations():
async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Chichi")
await album.save()
track1 = Track(album=album, title="The Birdman", position=1)
track2 = Track(album=album, title="Superman", position=2)
track3 = Track(album=album, title="Wonder Woman", position=3)
await track1.save()
await track2.save()
await track3.save()
assert len(album.tracks) == 3
await album.tracks.remove(track1)
assert len(album.tracks) == 2
assert track1.album is None
await track1.update()
track1 = await Track.objects.get(title="The Birdman")
assert track1.album is None
await album.tracks.add(track1)
assert len(album.tracks) == 3
assert track1.album == album
await track1.update()
track1 = await Track.objects.select_related("album__tracks").get(
title="The Birdman"
)
album = await Album.objects.select_related("tracks").get(name="Chichi")
assert track1.album == album
track1.remove(album)
assert track1.album is None
assert len(album.tracks) == 2
track2.remove(album)
assert track2.album is None
assert len(album.tracks) == 1
@pytest.mark.asyncio
async def test_fk_filter():
async with database:
async with database.transaction(force_rollback=True):
malibu = Album(name="Malibu%")
await malibu.save()
await Track.objects.create(album=malibu, title="The Bird", position=1)
await Track.objects.create(
album=malibu, title="Heart don't stand a chance", position=2
)
await Track.objects.create(album=malibu, title="The Waters", position=3)
fantasies = await Album.objects.create(name="Fantasies")
await Track.objects.create(
album=fantasies, title="Help I'm Alive", position=1
)
await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
await Track.objects.create(
album=fantasies, title="Satellite Mind", position=3
)
tracks = (
await Track.objects.select_related("album")
.filter(album__name="Fantasies")
.all()
)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = (
await Track.objects.select_related("album")
.filter(album__name__icontains="fan")
.all()
)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = await Track.objects.filter(album__name__contains="Fan").all()
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = await Track.objects.filter(album__name__contains="Malibu%").all()
assert len(tracks) == 3
tracks = (
await Track.objects.filter(album=malibu).select_related("album").all()
)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Malibu%"
tracks = await Track.objects.select_related("album").all(album=malibu)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Malibu%"
@pytest.mark.asyncio
async def test_multiple_fk():
async with database:
async with database.transaction(force_rollback=True):
acme = await Organisation.objects.create(ident="ACME Ltd")
red_team = await Team.objects.create(org=acme, name="Red Team")
blue_team = await Team.objects.create(org=acme, name="Blue Team")
await Member.objects.create(team=red_team, email="a@example.org")
await Member.objects.create(team=red_team, email="b@example.org")
await Member.objects.create(team=blue_team, email="c@example.org")
await Member.objects.create(team=blue_team, email="d@example.org")
other = await Organisation.objects.create(ident="Other ltd")
team = await Team.objects.create(org=other, name="Green Team")
await Member.objects.create(team=team, email="e@example.org")
members = (
await Member.objects.select_related("team__org")
.filter(team__org__ident="ACME Ltd")
.all()
)
assert len(members) == 4
for member in members:
assert member.team.org.ident == "ACME Ltd"
@pytest.mark.asyncio
async def test_wrong_choices():
async with database:
async with database.transaction(force_rollback=True):
with pytest.raises(ValueError):
await Organisation.objects.create(ident="Test 1")
@pytest.mark.asyncio
async def test_pk_filter():
async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Test")
track = await Track.objects.create(
album=fantasies, title="Test1", position=1
)
await Track.objects.create(album=fantasies, title="Test2", position=2)
await Track.objects.create(album=fantasies, title="Test3", position=3)
tracks = (
await Track.objects.select_related("album").filter(pk=track.pk).all()
)
assert len(tracks) == 1
tracks = (
await Track.objects.select_related("album")
.filter(position=2, album__name="Test")
.all()
)
assert len(tracks) == 1
@pytest.mark.asyncio
async def test_limit_and_offset():
async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Limitless")
await Track.objects.create(
id=None, album=fantasies, title="Sample", position=1
)
await Track.objects.create(album=fantasies, title="Sample2", position=2)
await Track.objects.create(album=fantasies, title="Sample3", position=3)
tracks = await Track.objects.limit(1).all()
assert len(tracks) == 1
assert tracks[0].title == "Sample"
tracks = await Track.objects.limit(1).offset(1).all()
assert len(tracks) == 1
assert tracks[0].title == "Sample2"
@pytest.mark.asyncio
async def test_get_exceptions():
async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Test")
with pytest.raises(NoMatch):
await Album.objects.get(name="Test2")
await Track.objects.create(album=fantasies, title="Test1", position=1)
await Track.objects.create(album=fantasies, title="Test2", position=2)
await Track.objects.create(album=fantasies, title="Test3", position=3)
with pytest.raises(MultipleMatches):
await Track.objects.select_related("album").get(album=fantasies)
@pytest.mark.asyncio
async def test_wrong_model_passed_as_fk():
async with database:
async with database.transaction(force_rollback=True):
with pytest.raises(RelationshipInstanceError):
org = await Organisation.objects.create(ident="ACME Ltd")
await Track.objects.create(album=org, title="Test1", position=1)

View File

@ -22,7 +22,7 @@ class Product(ormar.Model):
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
company: ormar.String(max_length=200, server_default='Acme')
company: ormar.String(max_length=200, server_default="Acme")
sort_order: ormar.Integer(server_default=text("10"))
created: ormar.DateTime(server_default=func.now())
@ -44,42 +44,44 @@ async def create_test_database():
def test_table_defined_properly():
assert Product.Meta.model_fields['created'].nullable
assert not Product.__fields__['created'].required
assert Product.Meta.table.columns['created'].server_default.arg.name == 'now'
assert Product.Meta.model_fields["created"].nullable
assert not Product.__fields__["created"].required
assert Product.Meta.table.columns["created"].server_default.arg.name == "now"
@pytest.mark.asyncio
async def test_model_creation():
async with database:
async with database.transaction(force_rollback=True):
p1 = Product(name='Test')
p1 = Product(name="Test")
assert p1.created is None
await p1.save()
await p1.load()
assert p1.created is not None
assert p1.company == 'Acme'
assert p1.company == "Acme"
assert p1.sort_order == 10
date = datetime.strptime('2020-10-27 11:30', '%Y-%m-%d %H:%M')
p3 = await Product.objects.create(name='Test2', created=date, company='Roadrunner', sort_order=1)
date = datetime.strptime("2020-10-27 11:30", "%Y-%m-%d %H:%M")
p3 = await Product.objects.create(
name="Test2", created=date, company="Roadrunner", sort_order=1
)
assert p3.created is not None
assert p3.created == date
assert p1.created != p3.created
assert p3.company == 'Roadrunner'
assert p3.company == "Roadrunner"
assert p3.sort_order == 1
p3 = await Product.objects.get(name='Test2')
assert p3.company == 'Roadrunner'
p3 = await Product.objects.get(name="Test2")
assert p3.company == "Roadrunner"
assert p3.sort_order == 1
time.sleep(1)
p2 = await Product.objects.create(name='Test3')
p2 = await Product.objects.create(name="Test3")
assert p2.created is not None
assert p2.company == 'Acme'
assert p2.company == "Acme"
assert p2.sort_order == 10
if Product.db_backend_name() != 'postgresql':
if Product.db_backend_name() != "postgresql":
# postgres use transaction timestamp so it will remain the same
assert p1.created != p2.created # pragma nocover

View File

@ -1,7 +1,7 @@
import asyncio
import sqlite3
import asyncpg # type: ignore
import asyncpg # type: ignore
import databases
import pymysql
import pytest