allow change to build in type hints

This commit is contained in:
collerek
2020-10-31 15:43:34 +01:00
parent 320588a3c1
commit 8fba94efa1
18 changed files with 575 additions and 131 deletions

View File

@ -1,5 +1,6 @@
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
import pydantic
import sqlalchemy import sqlalchemy
from pydantic import Field, typing from pydantic import Field, typing
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
@ -11,8 +12,9 @@ if TYPE_CHECKING: # pragma no cover
from ormar.models import NewBaseModel from ormar.models import NewBaseModel
class BaseField: class BaseField(FieldInfo):
__type__ = None __type__ = None
__pydantic_type__ = None
column_type: sqlalchemy.Column column_type: sqlalchemy.Column
constraints: List = [] constraints: List = []
@ -32,6 +34,28 @@ class BaseField:
default: Any default: Any
server_default: Any server_default: Any
@classmethod
def is_valid_field_info_field(cls, field_name: str) -> bool:
return (
field_name not in ["default", "default_factory"]
and not field_name.startswith("__")
and hasattr(cls, field_name)
)
@classmethod
def convert_to_pydantic_field_info(cls, allow_null: bool = False) -> FieldInfo:
base = cls.default_value()
if base is None:
base = (
FieldInfo(default=None)
if (cls.nullable or allow_null)
else FieldInfo(default=pydantic.fields.Undefined)
)
for attr_name in FieldInfo.__dict__.keys():
if cls.is_valid_field_info_field(attr_name):
setattr(base, attr_name, cls.__dict__.get(attr_name))
return base
@classmethod @classmethod
def default_value(cls, use_server: bool = False) -> Optional[FieldInfo]: def default_value(cls, use_server: bool = False) -> Optional[FieldInfo]:
if cls.is_auto_primary_key(): if cls.is_auto_primary_key():

View File

@ -1,4 +1,4 @@
from typing import Any, Generator, List, Optional, TYPE_CHECKING, Type, Union from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
import sqlalchemy import sqlalchemy
from sqlalchemy import UniqueConstraint from sqlalchemy import UniqueConstraint
@ -39,8 +39,15 @@ def ForeignKey( # noqa CFQ002
ondelete: str = None, ondelete: str = None,
) -> Type["ForeignKeyField"]: ) -> Type["ForeignKeyField"]:
fk_string = to.Meta.tablename + "." + to.get_column_alias(to.Meta.pkname) fk_string = to.Meta.tablename + "." + to.get_column_alias(to.Meta.pkname)
to_field = to.__fields__[to.Meta.pkname] to_field = to.Meta.model_fields[to.Meta.pkname]
__type__ = (
Union[to_field.__type__, to]
if not nullable
else Optional[Union[to_field.__type__, to]]
)
namespace = dict( namespace = dict(
__type__=__type__,
__pydantic_type__=__type__,
to=to, to=to,
name=name, name=name,
nullable=nullable, nullable=nullable,
@ -50,7 +57,7 @@ def ForeignKey( # noqa CFQ002
) )
], ],
unique=unique, unique=unique,
column_type=to_field.type_.column_type, column_type=to_field.column_type,
related_name=related_name, related_name=related_name,
virtual=virtual, virtual=virtual,
primary_key=False, primary_key=False,
@ -58,7 +65,6 @@ def ForeignKey( # noqa CFQ002
pydantic_only=False, pydantic_only=False,
default=None, default=None,
server_default=None, server_default=None,
__pydantic_model__=to,
) )
return type("ForeignKey", (ForeignKeyField, BaseField), namespace) return type("ForeignKey", (ForeignKeyField, BaseField), namespace)
@ -70,14 +76,6 @@ class ForeignKeyField(BaseField):
related_name: str related_name: str
virtual: bool virtual: bool
@classmethod
def __get_validators__(cls) -> Generator:
yield cls.validate
@classmethod
def validate(cls, value: Any) -> Any:
return value
@classmethod @classmethod
def _extract_model_from_sequence( def _extract_model_from_sequence(
cls, value: List, child: "Model", to_register: bool cls, value: List, child: "Model", to_register: bool

View File

@ -1,4 +1,4 @@
from typing import Dict, TYPE_CHECKING, Type from typing import Any, List, Optional, TYPE_CHECKING, Type, Union
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
@ -15,17 +15,26 @@ def ManyToMany(
*, *,
name: str = None, name: str = None,
unique: bool = False, unique: bool = False,
related_name: str = None,
virtual: bool = False, virtual: bool = False,
**kwargs: Any
) -> Type["ManyToManyField"]: ) -> Type["ManyToManyField"]:
to_field = to.__fields__[to.Meta.pkname] to_field = to.Meta.model_fields[to.Meta.pkname]
related_name = kwargs.pop("related_name", None)
nullable = kwargs.pop("nullable", True)
__type__ = (
Union[to_field.__type__, to, List[to]] # type: ignore
if not nullable
else Optional[Union[to_field.__type__, to, List[to]]] # type: ignore
)
namespace = dict( namespace = dict(
__type__=__type__,
__pydantic_type__=__type__,
to=to, to=to,
through=through, through=through,
name=name, name=name,
nullable=True, nullable=True,
unique=unique, unique=unique,
column_type=to_field.type_.column_type, column_type=to_field.column_type,
related_name=related_name, related_name=related_name,
virtual=virtual, virtual=virtual,
primary_key=False, primary_key=False,
@ -33,9 +42,6 @@ def ManyToMany(
pydantic_only=False, pydantic_only=False,
default=None, default=None,
server_default=None, server_default=None,
__pydantic_model__=to,
# __origin__=List,
# __args__=[Optional[to]]
) )
return type("ManyToMany", (ManyToManyField, BaseField), namespace) return type("ManyToMany", (ManyToManyField, BaseField), namespace)
@ -43,10 +49,3 @@ def ManyToMany(
class ManyToManyField(ForeignKeyField): class ManyToManyField(ForeignKeyField):
through: Type["Model"] through: Type["Model"]
@classmethod
def __modify_schema__(cls, field_schema: Dict) -> None:
field_schema["type"] = "array"
field_schema["title"] = cls.name.title()
field_schema["definitions"] = {f"{cls.to.__name__}": cls.to.schema()}
field_schema["items"] = {"$ref": f"{REF_PREFIX}{cls.to.__name__}"}

View File

@ -20,8 +20,9 @@ def is_field_nullable(
class ModelFieldFactory: class ModelFieldFactory:
_bases: Any = BaseField _bases: Any = (BaseField,)
_type: Any = None _type: Any = None
_pydantic_type: Any = None
def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: # type: ignore def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: # type: ignore
cls.validate(**kwargs) cls.validate(**kwargs)
@ -32,6 +33,7 @@ class ModelFieldFactory:
namespace = dict( namespace = dict(
__type__=cls._type, __type__=cls._type,
__pydantic_type__=cls._pydantic_type,
name=kwargs.pop("name", None), name=kwargs.pop("name", None),
primary_key=kwargs.pop("primary_key", False), primary_key=kwargs.pop("primary_key", False),
default=default, default=default,
@ -57,8 +59,8 @@ class ModelFieldFactory:
class String(ModelFieldFactory): class String(ModelFieldFactory):
_bases = (pydantic.ConstrainedStr, BaseField)
_type = str _type = str
_pydantic_type = pydantic.ConstrainedStr
def __new__( # type: ignore # noqa CFQ002 def __new__( # type: ignore # noqa CFQ002
cls, cls,
@ -96,8 +98,8 @@ class String(ModelFieldFactory):
class Integer(ModelFieldFactory): class Integer(ModelFieldFactory):
_bases = (pydantic.ConstrainedInt, BaseField)
_type = int _type = int
_pydantic_type = pydantic.ConstrainedInt
def __new__( # type: ignore def __new__( # type: ignore
cls, cls,
@ -131,8 +133,8 @@ class Integer(ModelFieldFactory):
class Text(ModelFieldFactory): class Text(ModelFieldFactory):
_bases = (pydantic.ConstrainedStr, BaseField)
_type = str _type = str
_pydantic_type = pydantic.ConstrainedStr
def __new__( # type: ignore def __new__( # type: ignore
cls, *, allow_blank: bool = True, strip_whitespace: bool = False, **kwargs: Any cls, *, allow_blank: bool = True, strip_whitespace: bool = False, **kwargs: Any
@ -154,8 +156,8 @@ class Text(ModelFieldFactory):
class Float(ModelFieldFactory): class Float(ModelFieldFactory):
_bases = (pydantic.ConstrainedFloat, BaseField)
_type = float _type = float
_pydantic_type = pydantic.ConstrainedFloat
def __new__( # type: ignore def __new__( # type: ignore
cls, cls,
@ -183,8 +185,8 @@ class Float(ModelFieldFactory):
class Boolean(ModelFieldFactory): class Boolean(ModelFieldFactory):
_bases = (int, BaseField)
_type = bool _type = bool
_pydantic_type = bool
@classmethod @classmethod
def get_column_type(cls, **kwargs: Any) -> Any: def get_column_type(cls, **kwargs: Any) -> Any:
@ -192,8 +194,8 @@ class Boolean(ModelFieldFactory):
class DateTime(ModelFieldFactory): class DateTime(ModelFieldFactory):
_bases = (datetime.datetime, BaseField)
_type = datetime.datetime _type = datetime.datetime
_pydantic_type = datetime.datetime
@classmethod @classmethod
def get_column_type(cls, **kwargs: Any) -> Any: def get_column_type(cls, **kwargs: Any) -> Any:
@ -201,8 +203,8 @@ class DateTime(ModelFieldFactory):
class Date(ModelFieldFactory): class Date(ModelFieldFactory):
_bases = (datetime.date, BaseField)
_type = datetime.date _type = datetime.date
_pydantic_type = datetime.date
@classmethod @classmethod
def get_column_type(cls, **kwargs: Any) -> Any: def get_column_type(cls, **kwargs: Any) -> Any:
@ -210,8 +212,8 @@ class Date(ModelFieldFactory):
class Time(ModelFieldFactory): class Time(ModelFieldFactory):
_bases = (datetime.time, BaseField)
_type = datetime.time _type = datetime.time
_pydantic_type = datetime.time
@classmethod @classmethod
def get_column_type(cls, **kwargs: Any) -> Any: def get_column_type(cls, **kwargs: Any) -> Any:
@ -219,8 +221,8 @@ class Time(ModelFieldFactory):
class JSON(ModelFieldFactory): class JSON(ModelFieldFactory):
_bases = (pydantic.Json, BaseField)
_type = pydantic.Json _type = pydantic.Json
_pydantic_type = pydantic.Json
@classmethod @classmethod
def get_column_type(cls, **kwargs: Any) -> Any: def get_column_type(cls, **kwargs: Any) -> Any:
@ -228,8 +230,8 @@ class JSON(ModelFieldFactory):
class BigInteger(Integer): class BigInteger(Integer):
_bases = (pydantic.ConstrainedInt, BaseField)
_type = int _type = int
_pydantic_type = pydantic.ConstrainedInt
def __new__( # type: ignore def __new__( # type: ignore
cls, cls,
@ -263,8 +265,8 @@ class BigInteger(Integer):
class Decimal(ModelFieldFactory): class Decimal(ModelFieldFactory):
_bases = (pydantic.ConstrainedDecimal, BaseField)
_type = decimal.Decimal _type = decimal.Decimal
_pydantic_type = pydantic.ConstrainedDecimal
def __new__( # type: ignore # noqa CFQ002 def __new__( # type: ignore # noqa CFQ002
cls, cls,
@ -318,8 +320,8 @@ class Decimal(ModelFieldFactory):
class UUID(ModelFieldFactory): class UUID(ModelFieldFactory):
_bases = (uuid.UUID, BaseField)
_type = uuid.UUID _type = uuid.UUID
_pydantic_type = uuid.UUID
@classmethod @classmethod
def get_column_type(cls, **kwargs: Any) -> Any: def get_column_type(cls, **kwargs: Any) -> Any:

View File

@ -5,7 +5,8 @@ import databases
import pydantic import pydantic
import sqlalchemy import sqlalchemy
from pydantic import BaseConfig from pydantic import BaseConfig
from pydantic.fields import FieldInfo, ModelField from pydantic.fields import ModelField
from pydantic.utils import lenient_issubclass
from sqlalchemy.sql.schema import ColumnCollectionConstraint from sqlalchemy.sql.schema import ColumnCollectionConstraint
import ormar # noqa I100 import ormar # noqa I100
@ -179,44 +180,58 @@ def register_relation_in_alias_manager(
def populate_default_pydantic_field_value( def populate_default_pydantic_field_value(
type_: Type[BaseField], field: str, attrs: dict ormar_field: Type[BaseField], field_name: str, attrs: dict
) -> dict: ) -> dict:
def_value = type_.default_value() curr_def_value = attrs.get(field_name, ormar.Undefined)
curr_def_value = attrs.get(field, "NONE") if lenient_issubclass(curr_def_value, ormar.fields.BaseField):
if curr_def_value == "NONE" and isinstance(def_value, FieldInfo): curr_def_value = ormar.Undefined
attrs[field] = def_value if curr_def_value is None:
elif curr_def_value == "NONE" and type_.nullable: attrs[field_name] = ormar_field.convert_to_pydantic_field_info(allow_null=True)
attrs[field] = FieldInfo(default=None) else:
attrs[field_name] = ormar_field.convert_to_pydantic_field_info()
return attrs return attrs
def populate_pydantic_default_values(attrs: Dict) -> Dict: def check_if_field_annotation_or_value_is_ormar(
for field, type_ in attrs["__annotations__"].items(): field: Any, field_name: str, attrs: Dict
if issubclass(type_, BaseField): ) -> bool:
if type_.name is None: return lenient_issubclass(field, BaseField) or issubclass(
type_.name = field attrs.get(field_name, type), BaseField
attrs = populate_default_pydantic_field_value(type_, field, attrs) )
return attrs
def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict: def extract_field_from_annotation_or_value(
field: Any, field_name: str, attrs: Dict
) -> Type[ormar.fields.BaseField]:
return field if lenient_issubclass(field, BaseField) else attrs.get(field_name)
def populate_pydantic_default_values(attrs: Dict) -> Tuple[Dict, Dict]:
model_fields = {}
for field_name, field in attrs["__annotations__"].items():
# ormar fields can be used as annotation or as default value
if check_if_field_annotation_or_value_is_ormar(field, field_name, attrs):
ormar_field = extract_field_from_annotation_or_value(
field, field_name, attrs
)
if ormar_field.name is None:
ormar_field.name = field_name
attrs = populate_default_pydantic_field_value(
ormar_field, field_name, attrs
)
model_fields[field_name] = ormar_field
attrs["__annotations__"][field_name] = ormar_field.__type__
return attrs, model_fields
def extract_annotations_and_default_vals(
attrs: dict, bases: Tuple
) -> Tuple[Dict, Dict]:
attrs["__annotations__"] = attrs.get("__annotations__") or bases[0].__dict__.get( attrs["__annotations__"] = attrs.get("__annotations__") or bases[0].__dict__.get(
"__annotations__", {} "__annotations__", {}
) )
attrs = populate_pydantic_default_values(attrs) attrs, model_fields = populate_pydantic_default_values(attrs)
return attrs return attrs, model_fields
def populate_meta_orm_model_fields(
attrs: dict, new_model: Type["Model"]
) -> Type["Model"]:
model_fields = {
field_name: field
for field_name, field in attrs["__annotations__"].items()
if issubclass(field, BaseField)
}
new_model.Meta.model_fields = model_fields
return new_model
def populate_meta_tablename_columns_and_pk( def populate_meta_tablename_columns_and_pk(
@ -305,7 +320,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
) -> "ModelMetaclass": ) -> "ModelMetaclass":
attrs["Config"] = get_pydantic_base_orm_config() attrs["Config"] = get_pydantic_base_orm_config()
attrs["__name__"] = name attrs["__name__"] = name
attrs = extract_annotations_and_default_vals(attrs, bases) attrs, model_fields = extract_annotations_and_default_vals(attrs, bases)
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
) )
@ -313,7 +328,8 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
if hasattr(new_model, "Meta"): if hasattr(new_model, "Meta"):
if not hasattr(new_model.Meta, "constraints"): if not hasattr(new_model.Meta, "constraints"):
new_model.Meta.constraints = [] new_model.Meta.constraints = []
new_model = populate_meta_orm_model_fields(attrs, new_model) if not hasattr(new_model.Meta, "model_fields"):
new_model.Meta.model_fields = model_fields
new_model = populate_meta_tablename_columns_and_pk(name, new_model) new_model = populate_meta_tablename_columns_and_pk(name, new_model)
new_model = populate_meta_sqlalchemy_table_if_required(new_model) new_model = populate_meta_sqlalchemy_table_if_required(new_model)
expand_reverse_relationships(new_model) expand_reverse_relationships(new_model)
@ -322,7 +338,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
if new_model.Meta.pkname not in attrs["__annotations__"]: if new_model.Meta.pkname not in attrs["__annotations__"]:
field_name = new_model.Meta.pkname field_name = new_model.Meta.pkname
field = Integer(name=field_name, primary_key=True) field = Integer(name=field_name, primary_key=True)
attrs["__annotations__"][field_name] = field attrs["__annotations__"][field_name] = Optional[int] # type: ignore
populate_default_pydantic_field_value( populate_default_pydantic_field_value(
field, field_name, attrs # type: ignore field, field_name, attrs # type: ignore
) )

View File

@ -1,11 +1,12 @@
import itertools import itertools
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar
import sqlalchemy import sqlalchemy
import ormar.queryset # noqa I100 import ormar.queryset # noqa I100
from ormar.fields.many_to_many import ManyToManyField from ormar.fields.many_to_many import ManyToManyField
from ormar.models import NewBaseModel # noqa I100 from ormar.models import NewBaseModel # noqa I100
from ormar.models.metaclass import ModelMeta
def group_related_list(list_: List) -> Dict: def group_related_list(list_: List) -> Dict:
@ -23,18 +24,30 @@ def group_related_list(list_: List) -> Dict:
return test_dict return test_dict
T = TypeVar("T", bound="Model")
class Model(NewBaseModel): class Model(NewBaseModel):
__abstract__ = False __abstract__ = False
if TYPE_CHECKING: # pragma nocover
Meta: ModelMeta
def __repr__(self) -> str: # pragma nocover
attrs_to_include = ["tablename", "columns", "pkname"]
_repr = {k: v for k, v in self.Meta.model_fields.items()}
for atr in attrs_to_include:
_repr[atr] = getattr(self.Meta, atr)
return f"{self.__class__.__name__}({str(_repr)})"
@classmethod @classmethod
def from_row( # noqa CCR001 def from_row( # noqa CCR001
cls, cls: Type[T],
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
select_related: List = None, select_related: List = None,
related_models: Any = None, related_models: Any = None,
previous_table: str = None, previous_table: str = None,
fields: List = None, fields: List = None,
) -> Optional["Model"]: ) -> Optional[T]:
item: Dict[str, Any] = {} item: Dict[str, Any] = {}
select_related = select_related or [] select_related = select_related or []
@ -66,7 +79,9 @@ class Model(NewBaseModel):
item, row, table_prefix, fields, nested=table_prefix != "" item, row, table_prefix, fields, nested=table_prefix != ""
) )
instance = cls(**item) if item.get(cls.Meta.pkname, None) is not None else None instance: Optional[T] = cls(**item) if item.get(
cls.Meta.pkname, None
) is not None else None
return instance return instance
@classmethod @classmethod
@ -124,7 +139,7 @@ class Model(NewBaseModel):
return item return item
async def save(self) -> "Model": async def save(self: T) -> T:
self_fields = self._extract_model_db_fields() self_fields = self._extract_model_db_fields()
if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement: if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement:
@ -137,7 +152,7 @@ class Model(NewBaseModel):
setattr(self, self.Meta.pkname, item_id) setattr(self, self.Meta.pkname, item_id)
return self return self
async def update(self, **kwargs: Any) -> "Model": async def update(self: T, **kwargs: Any) -> T:
if kwargs: if kwargs:
new_values = {**self.dict(), **kwargs} new_values = {**self.dict(), **kwargs}
self.from_dict(new_values) self.from_dict(new_values)
@ -151,13 +166,13 @@ class Model(NewBaseModel):
await self.Meta.database.execute(expr) await self.Meta.database.execute(expr)
return self return self
async def delete(self) -> int: async def delete(self: T) -> int:
expr = self.Meta.table.delete() expr = self.Meta.table.delete()
expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname))) expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname)))
result = await self.Meta.database.execute(expr) result = await self.Meta.database.execute(expr)
return result return result
async def load(self) -> "Model": async def load(self: T) -> T:
expr = self.Meta.table.select().where(self.pk_column == self.pk) expr = self.Meta.table.select().where(self.pk_column == self.pk)
row = await self.Meta.database.fetch_one(expr) row = await self.Meta.database.fetch_one(expr)
if not row: # pragma nocover if not row: # pragma nocover

View File

@ -1,5 +1,5 @@
import inspect import inspect
from typing import Dict, List, Set, TYPE_CHECKING, Type, TypeVar, Union from typing import Dict, List, Sequence, Set, TYPE_CHECKING, Type, TypeVar, Union
import ormar import ormar
from ormar.exceptions import RelationshipInstanceError from ormar.exceptions import RelationshipInstanceError
@ -11,6 +11,8 @@ if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
from ormar.models import NewBaseModel from ormar.models import NewBaseModel
T = TypeVar("T", bound=Model)
Field = TypeVar("Field", bound=BaseField) Field = TypeVar("Field", bound=BaseField)
@ -135,7 +137,7 @@ class ModelTableProxy:
if field.to == related.__class__ or field.to.Meta == related.Meta: if field.to == related.__class__ or field.to.Meta == related.Meta:
return name return name
# fallback for not registered relation # fallback for not registered relation
if register_missing: if register_missing: # pragma nocover
expand_reverse_relationships(related.__class__) # type: ignore expand_reverse_relationships(related.__class__) # type: ignore
return ModelTableProxy.resolve_relation_name( return ModelTableProxy.resolve_relation_name(
item, related, register_missing=False item, related, register_missing=False
@ -177,7 +179,7 @@ class ModelTableProxy:
return new_kwargs return new_kwargs
@classmethod @classmethod
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: def merge_instances_list(cls, result_rows: Sequence["Model"]) -> Sequence["Model"]:
merged_rows: List["Model"] = [] merged_rows: List["Model"] = []
for index, model in enumerate(result_rows): for index, model in enumerate(result_rows):
if index > 0 and model is not None and model.pk == merged_rows[-1].pk: if index > 0 and model is not None and model.pk == merged_rows[-1].pk:

View File

@ -5,11 +5,12 @@ from typing import (
Any, Any,
Callable, Callable,
Dict, Dict,
List,
Mapping, Mapping,
Optional, Optional,
Sequence,
TYPE_CHECKING, TYPE_CHECKING,
Type, Type,
TypeVar,
Union, Union,
) )
@ -27,7 +28,9 @@ from ormar.relations.alias_manager import AliasManager
from ormar.relations.relation_manager import RelationsManager from ormar.relations.relation_manager import RelationsManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models.model import Model from ormar import Model
T = TypeVar("T", bound=Model)
IntStr = Union[int, str] IntStr = Union[int, str]
DictStrAny = Dict[str, Any] DictStrAny = Dict[str, Any]
@ -52,7 +55,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
Meta: ModelMeta Meta: ModelMeta
# noinspection PyMissingConstructor # noinspection PyMissingConstructor
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None: # type: ignore
object.__setattr__(self, "_orm_id", uuid.uuid4().hex) object.__setattr__(self, "_orm_id", uuid.uuid4().hex)
object.__setattr__(self, "_orm_saved", False) object.__setattr__(self, "_orm_saved", False)
@ -73,7 +76,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if "pk" in kwargs: if "pk" in kwargs:
kwargs[self.Meta.pkname] = kwargs.pop("pk") kwargs[self.Meta.pkname] = kwargs.pop("pk")
# build the models to set them and validate but don't register # build the models to set them and validate but don't register
kwargs = { new_kwargs = {
k: self._convert_json( k: self._convert_json(
k, k,
self.Meta.model_fields[k].expand_relationship( self.Meta.model_fields[k].expand_relationship(
@ -85,7 +88,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
} }
values, fields_set, validation_error = pydantic.validate_model( values, fields_set, validation_error = pydantic.validate_model(
self, kwargs # type: ignore self, new_kwargs # type: ignore
) )
if validation_error and not pk_only: if validation_error and not pk_only:
raise validation_error raise validation_error
@ -96,7 +99,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
# register the columns models after initialization # register the columns models after initialization
for related in self.extract_related_names(): for related in self.extract_related_names():
self.Meta.model_fields[related].expand_relationship( self.Meta.model_fields[related].expand_relationship(
kwargs.get(related), self, to_register=True new_kwargs.get(related), self, to_register=True
) )
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
@ -133,7 +136,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
def _extract_related_model_instead_of_field( def _extract_related_model_instead_of_field(
self, item: str self, item: str
) -> Optional[Union["Model", List["Model"]]]: ) -> Optional[Union[T, Sequence[T]]]:
alias = self.get_column_alias(item) alias = self.get_column_alias(item)
if alias in self._orm: if alias in self._orm:
return self._orm.get(alias) return self._orm.get(alias)
@ -170,7 +173,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
def db_backend_name(cls) -> str: def db_backend_name(cls) -> str:
return cls.Meta.database._backend._dialect.name return cls.Meta.database._backend._dialect.name
def remove(self, name: "Model") -> None: def remove(self, name: T) -> None:
self._orm.remove_parent(self, name) self._orm.remove_parent(self, name)
def dict( # noqa A003 def dict( # noqa A003

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Type, Union
import databases import databases
import sqlalchemy import sqlalchemy
@ -59,7 +59,7 @@ class QuerySet:
raise ValueError("Model class of QuerySet is not initialized") raise ValueError("Model class of QuerySet is not initialized")
return self.model_cls return self.model_cls
def _process_query_result_rows(self, rows: List) -> List[Optional["Model"]]: def _process_query_result_rows(self, rows: List) -> Sequence[Optional["Model"]]:
result_rows = [ result_rows = [
self.model.from_row( self.model.from_row(
row, select_related=self._select_related, fields=self._columns row, select_related=self._select_related, fields=self._columns
@ -87,7 +87,7 @@ class QuerySet:
return new_kwargs return new_kwargs
@staticmethod @staticmethod
def check_single_result_rows_count(rows: List[Optional["Model"]]) -> None: def check_single_result_rows_count(rows: Sequence[Optional["Model"]]) -> None:
if not rows or rows[0] is None: if not rows or rows[0] is None:
raise NoMatch() raise NoMatch()
if len(rows) > 1: if len(rows) > 1:
@ -267,7 +267,7 @@ class QuerySet:
model = await self.get(pk=kwargs[pk_name]) model = await self.get(pk=kwargs[pk_name])
return await model.update(**kwargs) return await model.update(**kwargs)
async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003 async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003
if kwargs: if kwargs:
return await self.filter(**kwargs).all() return await self.filter(**kwargs).all()

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, TYPE_CHECKING, Union from typing import Any, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union
import ormar import ormar
@ -7,6 +7,8 @@ if TYPE_CHECKING: # pragma no cover
from ormar.models import Model from ormar.models import Model
from ormar.queryset import QuerySet from ormar.queryset import QuerySet
T = TypeVar("T", bound=Model)
class QuerysetProxy: class QuerysetProxy:
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -26,27 +28,28 @@ class QuerysetProxy:
def queryset(self, value: "QuerySet") -> None: def queryset(self, value: "QuerySet") -> None:
self._queryset = value self._queryset = value
def _assign_child_to_parent(self, child: Optional["Model"]) -> None: def _assign_child_to_parent(self, child: Optional[T]) -> None:
if child: if child:
owner = self.relation._owner owner = self.relation._owner
rel_name = owner.resolve_relation_name(owner, child) rel_name = owner.resolve_relation_name(owner, child)
setattr(owner, rel_name, child) setattr(owner, rel_name, child)
def _register_related(self, child: Union["Model", List[Optional["Model"]]]) -> None: def _register_related(self, child: Union[T, Sequence[Optional[T]]]) -> None:
if isinstance(child, list): if isinstance(child, list):
for subchild in child: for subchild in child:
self._assign_child_to_parent(subchild) self._assign_child_to_parent(subchild)
else: else:
assert isinstance(child, Model)
self._assign_child_to_parent(child) self._assign_child_to_parent(child)
async def create_through_instance(self, child: "Model") -> None: async def create_through_instance(self, child: T) -> None:
queryset = ormar.QuerySet(model_cls=self.relation.through) queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name() owner_column = self.relation._owner.get_name()
child_column = child.get_name() child_column = child.get_name()
kwargs = {owner_column: self.relation._owner, child_column: child} kwargs = {owner_column: self.relation._owner, child_column: child}
await queryset.create(**kwargs) await queryset.create(**kwargs)
async def delete_through_instance(self, child: "Model") -> None: async def delete_through_instance(self, child: T) -> None:
queryset = ormar.QuerySet(model_cls=self.relation.through) queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name() owner_column = self.relation._owner.get_name()
child_column = child.get_name() child_column = child.get_name()
@ -88,7 +91,7 @@ class QuerysetProxy:
self._register_related(get) self._register_related(get)
return get return get
async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003 async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003
all_items = await self.queryset.all(**kwargs) all_items = await self.queryset.all(**kwargs)
self._register_related(all_items) self._register_related(all_items)
return all_items return all_items

View File

@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import List, Optional, TYPE_CHECKING, Type, Union from typing import List, Optional, TYPE_CHECKING, Type, TypeVar, Union
import ormar # noqa I100 import ormar # noqa I100
from ormar.exceptions import RelationshipInstanceError # noqa I100 from ormar.exceptions import RelationshipInstanceError # noqa I100
@ -11,6 +11,8 @@ if TYPE_CHECKING: # pragma no cover
from ormar.relations import RelationsManager from ormar.relations import RelationsManager
from ormar.models import NewBaseModel from ormar.models import NewBaseModel
T = TypeVar("T", bound=Model)
class RelationType(Enum): class RelationType(Enum):
PRIMARY = 1 PRIMARY = 1
@ -23,15 +25,15 @@ class Relation:
self, self,
manager: "RelationsManager", manager: "RelationsManager",
type_: RelationType, type_: RelationType,
to: Type["Model"], to: Type[T],
through: Type["Model"] = None, through: Type[T] = None,
) -> None: ) -> None:
self.manager = manager self.manager = manager
self._owner: "Model" = manager.owner self._owner: "Model" = manager.owner
self._type: RelationType = type_ self._type: RelationType = type_
self.to: Type["Model"] = to self.to: Type[T] = to
self.through: Optional[Type["Model"]] = through self.through: Optional[Type[T]] = through
self.related_models: Optional[Union[RelationProxy, "Model"]] = ( self.related_models: Optional[Union[RelationProxy, T]] = (
RelationProxy(relation=self) RelationProxy(relation=self)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None else None
@ -50,7 +52,7 @@ class Relation:
self.related_models.pop(ind) self.related_models.pop(ind)
return None return None
def add(self, child: "Model") -> None: def add(self, child: T) -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child) relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY: if self._type == RelationType.PRIMARY:
self.related_models = child self.related_models = child
@ -77,7 +79,7 @@ class Relation:
self.related_models.pop(position) # type: ignore self.related_models.pop(position) # type: ignore
del self._owner.__dict__[relation_name][position] del self._owner.__dict__[relation_name][position]
def get(self) -> Optional[Union[List["Model"], "Model"]]: def get(self) -> Optional[Union[List[T], T]]:
return self.related_models return self.related_models
def __repr__(self) -> str: # pragma no cover def __repr__(self) -> str: # pragma no cover

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, TYPE_CHECKING, Type, Union from typing import Dict, List, Optional, Sequence, TYPE_CHECKING, Type, TypeVar, Union
from weakref import proxy from weakref import proxy
from ormar.fields import BaseField from ormar.fields import BaseField
@ -14,6 +14,8 @@ if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
from ormar.models import NewBaseModel from ormar.models import NewBaseModel
T = TypeVar("T", bound=Model)
class RelationsManager: class RelationsManager:
def __init__( def __init__(
@ -46,7 +48,7 @@ class RelationsManager:
def __contains__(self, item: str) -> bool: def __contains__(self, item: str) -> bool:
return item in self._related_names return item in self._related_names
def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]: def get(self, name: str) -> Optional[Union[T, Sequence[T]]]:
relation = self._relations.get(name, None) relation = self._relations.get(name, None)
if relation is not None: if relation is not None:
return relation.get() return relation.get()

View File

@ -72,6 +72,6 @@ class RelationProxy(list):
if self.relation._type == ormar.RelationType.MULTIPLE: if self.relation._type == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item) await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner) rel_name = item.resolve_relation_name(item, self._owner)
if rel_name not in item._orm: if rel_name not in item._orm: # pragma nocover
item._orm._add_relation(item.Meta.model_fields[rel_name]) item._orm._add_relation(item.Meta.model_fields[rel_name])
setattr(item, rel_name, self._owner) setattr(item, rel_name, self._owner)

View File

@ -125,7 +125,9 @@ def test_all_endpoints():
def test_schema_modification(): def test_schema_modification():
schema = Item.schema() schema = Item.schema()
assert schema["properties"]["categories"]["type"] == "array" assert any(
x.get("type") == "array" for x in schema["properties"]["categories"]["anyOf"]
)
assert schema["properties"]["categories"]["title"] == "Categories" assert schema["properties"]["categories"]["title"] == "Categories"

View File

@ -98,9 +98,9 @@ async def create_test_database():
def test_model_class(): def test_model_class():
assert list(User.Meta.model_fields.keys()) == ["id", "name"] assert list(User.Meta.model_fields.keys()) == ["id", "name"]
assert issubclass(User.Meta.model_fields["id"], pydantic.ConstrainedInt) assert issubclass(User.Meta.model_fields["id"], pydantic.fields.FieldInfo)
assert User.Meta.model_fields["id"].primary_key is True assert User.Meta.model_fields["id"].primary_key is True
assert issubclass(User.Meta.model_fields["name"], pydantic.ConstrainedStr) assert issubclass(User.Meta.model_fields["name"], pydantic.fields.FieldInfo)
assert User.Meta.model_fields["name"].max_length == 100 assert User.Meta.model_fields["name"].max_length == 100
assert isinstance(User.Meta.table, sqlalchemy.Table) assert isinstance(User.Meta.table, sqlalchemy.Table)

View File

@ -0,0 +1,374 @@
from typing import Optional
import databases
import pytest
import sqlalchemy
import ormar
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Album(ormar.Model):
class Meta:
tablename = "albums"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Track(ormar.Model):
class Meta:
tablename = "tracks"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album)
title: str = ormar.String(max_length=100)
position: int = ormar.Integer()
class Cover(ormar.Model):
class Meta:
tablename = "covers"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
album: Album = ormar.ForeignKey(Album, related_name="cover_pictures")
title: str = ormar.String(max_length=100)
class Organisation(ormar.Model):
class Meta:
tablename = "org"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"])
class Team(ormar.Model):
class Meta:
tablename = "teams"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
org: Optional[Organisation] = ormar.ForeignKey(Organisation)
name: str = ormar.String(max_length=100)
class Member(ormar.Model):
class Meta:
tablename = "members"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
team: Optional[Team] = ormar.ForeignKey(Team)
email: str = ormar.String(max_length=100)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_wrong_query_foreign_key_type():
async with database:
with pytest.raises(RelationshipInstanceError):
Track(title="The Error", album="wrong_pk_type")
@pytest.mark.asyncio
async def test_setting_explicitly_empty_relation():
async with database:
track = Track(album=None, title="The Bird", position=1)
assert track.album is None
@pytest.mark.asyncio
async def test_related_name():
async with database:
async with database.transaction(force_rollback=True):
album = await Album.objects.create(name="Vanilla")
await Cover.objects.create(album=album, title="The cover file")
assert len(album.cover_pictures) == 1
@pytest.mark.asyncio
async def test_model_crud():
async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Jamaica")
await album.save()
track1 = Track(album=album, title="The Bird", position=1)
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
track3 = Track(album=album, title="The Waters", position=3)
await track1.save()
await track2.save()
await track3.save()
track = await Track.objects.get(title="The Bird")
assert track.album.pk == album.pk
assert isinstance(track.album, ormar.Model)
assert track.album.name is None
await track.album.load()
assert track.album.name == "Jamaica"
assert len(album.tracks) == 3
assert album.tracks[1].title == "Heart don't stand a chance"
album1 = await Album.objects.get(name="Jamaica")
assert album1.pk == album.pk
assert album1.tracks == []
await Track.objects.create(
album={"id": track.album.pk}, title="The Bird2", position=4
)
@pytest.mark.asyncio
async def test_select_related():
async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Malibu")
await album.save()
track1 = Track(album=album, title="The Bird", position=1)
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
track3 = Track(album=album, title="The Waters", position=3)
await track1.save()
await track2.save()
await track3.save()
fantasies = Album(name="Fantasies")
await fantasies.save()
track4 = Track(album=fantasies, title="Help I'm Alive", position=1)
track5 = Track(album=fantasies, title="Sick Muse", position=2)
track6 = Track(album=fantasies, title="Satellite Mind", position=3)
await track4.save()
await track5.save()
await track6.save()
track = await Track.objects.select_related("album").get(title="The Bird")
assert track.album.name == "Malibu"
tracks = await Track.objects.select_related("album").all()
assert len(tracks) == 6
@pytest.mark.asyncio
async def test_model_removal_from_relations():
async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Chichi")
await album.save()
track1 = Track(album=album, title="The Birdman", position=1)
track2 = Track(album=album, title="Superman", position=2)
track3 = Track(album=album, title="Wonder Woman", position=3)
await track1.save()
await track2.save()
await track3.save()
assert len(album.tracks) == 3
await album.tracks.remove(track1)
assert len(album.tracks) == 2
assert track1.album is None
await track1.update()
track1 = await Track.objects.get(title="The Birdman")
assert track1.album is None
await album.tracks.add(track1)
assert len(album.tracks) == 3
assert track1.album == album
await track1.update()
track1 = await Track.objects.select_related("album__tracks").get(
title="The Birdman"
)
album = await Album.objects.select_related("tracks").get(name="Chichi")
assert track1.album == album
track1.remove(album)
assert track1.album is None
assert len(album.tracks) == 2
track2.remove(album)
assert track2.album is None
assert len(album.tracks) == 1
@pytest.mark.asyncio
async def test_fk_filter():
async with database:
async with database.transaction(force_rollback=True):
malibu = Album(name="Malibu%")
await malibu.save()
await Track.objects.create(album=malibu, title="The Bird", position=1)
await Track.objects.create(
album=malibu, title="Heart don't stand a chance", position=2
)
await Track.objects.create(album=malibu, title="The Waters", position=3)
fantasies = await Album.objects.create(name="Fantasies")
await Track.objects.create(
album=fantasies, title="Help I'm Alive", position=1
)
await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
await Track.objects.create(
album=fantasies, title="Satellite Mind", position=3
)
tracks = (
await Track.objects.select_related("album")
.filter(album__name="Fantasies")
.all()
)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = (
await Track.objects.select_related("album")
.filter(album__name__icontains="fan")
.all()
)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = await Track.objects.filter(album__name__contains="Fan").all()
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = await Track.objects.filter(album__name__contains="Malibu%").all()
assert len(tracks) == 3
tracks = (
await Track.objects.filter(album=malibu).select_related("album").all()
)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Malibu%"
tracks = await Track.objects.select_related("album").all(album=malibu)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Malibu%"
@pytest.mark.asyncio
async def test_multiple_fk():
async with database:
async with database.transaction(force_rollback=True):
acme = await Organisation.objects.create(ident="ACME Ltd")
red_team = await Team.objects.create(org=acme, name="Red Team")
blue_team = await Team.objects.create(org=acme, name="Blue Team")
await Member.objects.create(team=red_team, email="a@example.org")
await Member.objects.create(team=red_team, email="b@example.org")
await Member.objects.create(team=blue_team, email="c@example.org")
await Member.objects.create(team=blue_team, email="d@example.org")
other = await Organisation.objects.create(ident="Other ltd")
team = await Team.objects.create(org=other, name="Green Team")
await Member.objects.create(team=team, email="e@example.org")
members = (
await Member.objects.select_related("team__org")
.filter(team__org__ident="ACME Ltd")
.all()
)
assert len(members) == 4
for member in members:
assert member.team.org.ident == "ACME Ltd"
@pytest.mark.asyncio
async def test_wrong_choices():
async with database:
async with database.transaction(force_rollback=True):
with pytest.raises(ValueError):
await Organisation.objects.create(ident="Test 1")
@pytest.mark.asyncio
async def test_pk_filter():
async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Test")
track = await Track.objects.create(
album=fantasies, title="Test1", position=1
)
await Track.objects.create(album=fantasies, title="Test2", position=2)
await Track.objects.create(album=fantasies, title="Test3", position=3)
tracks = (
await Track.objects.select_related("album").filter(pk=track.pk).all()
)
assert len(tracks) == 1
tracks = (
await Track.objects.select_related("album")
.filter(position=2, album__name="Test")
.all()
)
assert len(tracks) == 1
@pytest.mark.asyncio
async def test_limit_and_offset():
async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Limitless")
await Track.objects.create(
id=None, album=fantasies, title="Sample", position=1
)
await Track.objects.create(album=fantasies, title="Sample2", position=2)
await Track.objects.create(album=fantasies, title="Sample3", position=3)
tracks = await Track.objects.limit(1).all()
assert len(tracks) == 1
assert tracks[0].title == "Sample"
tracks = await Track.objects.limit(1).offset(1).all()
assert len(tracks) == 1
assert tracks[0].title == "Sample2"
@pytest.mark.asyncio
async def test_get_exceptions():
async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Test")
with pytest.raises(NoMatch):
await Album.objects.get(name="Test2")
await Track.objects.create(album=fantasies, title="Test1", position=1)
await Track.objects.create(album=fantasies, title="Test2", position=2)
await Track.objects.create(album=fantasies, title="Test3", position=3)
with pytest.raises(MultipleMatches):
await Track.objects.select_related("album").get(album=fantasies)
@pytest.mark.asyncio
async def test_wrong_model_passed_as_fk():
async with database:
async with database.transaction(force_rollback=True):
with pytest.raises(RelationshipInstanceError):
org = await Organisation.objects.create(ident="ACME Ltd")
await Track.objects.create(album=org, title="Test1", position=1)

View File

@ -22,7 +22,7 @@ class Product(ormar.Model):
id: ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100) name: ormar.String(max_length=100)
company: ormar.String(max_length=200, server_default='Acme') company: ormar.String(max_length=200, server_default="Acme")
sort_order: ormar.Integer(server_default=text("10")) sort_order: ormar.Integer(server_default=text("10"))
created: ormar.DateTime(server_default=func.now()) created: ormar.DateTime(server_default=func.now())
@ -44,42 +44,44 @@ async def create_test_database():
def test_table_defined_properly(): def test_table_defined_properly():
assert Product.Meta.model_fields['created'].nullable assert Product.Meta.model_fields["created"].nullable
assert not Product.__fields__['created'].required assert not Product.__fields__["created"].required
assert Product.Meta.table.columns['created'].server_default.arg.name == 'now' assert Product.Meta.table.columns["created"].server_default.arg.name == "now"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_creation(): async def test_model_creation():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
p1 = Product(name='Test') p1 = Product(name="Test")
assert p1.created is None assert p1.created is None
await p1.save() await p1.save()
await p1.load() await p1.load()
assert p1.created is not None assert p1.created is not None
assert p1.company == 'Acme' assert p1.company == "Acme"
assert p1.sort_order == 10 assert p1.sort_order == 10
date = datetime.strptime('2020-10-27 11:30', '%Y-%m-%d %H:%M') date = datetime.strptime("2020-10-27 11:30", "%Y-%m-%d %H:%M")
p3 = await Product.objects.create(name='Test2', created=date, company='Roadrunner', sort_order=1) p3 = await Product.objects.create(
name="Test2", created=date, company="Roadrunner", sort_order=1
)
assert p3.created is not None assert p3.created is not None
assert p3.created == date assert p3.created == date
assert p1.created != p3.created assert p1.created != p3.created
assert p3.company == 'Roadrunner' assert p3.company == "Roadrunner"
assert p3.sort_order == 1 assert p3.sort_order == 1
p3 = await Product.objects.get(name='Test2') p3 = await Product.objects.get(name="Test2")
assert p3.company == 'Roadrunner' assert p3.company == "Roadrunner"
assert p3.sort_order == 1 assert p3.sort_order == 1
time.sleep(1) time.sleep(1)
p2 = await Product.objects.create(name='Test3') p2 = await Product.objects.create(name="Test3")
assert p2.created is not None assert p2.created is not None
assert p2.company == 'Acme' assert p2.company == "Acme"
assert p2.sort_order == 10 assert p2.sort_order == 10
if Product.db_backend_name() != 'postgresql': if Product.db_backend_name() != "postgresql":
# postgres use transaction timestamp so it will remain the same # postgres use transaction timestamp so it will remain the same
assert p1.created != p2.created # pragma nocover assert p1.created != p2.created # pragma nocover