Merge pull request #2 from collerek/pydantic_basemodel

Pydantic basemodel
This commit is contained in:
collerek
2020-08-23 21:20:22 +07:00
committed by GitHub
20 changed files with 883 additions and 557 deletions

BIN
.coverage

Binary file not shown.

View File

@ -1,6 +1,7 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING from typing import Any, List, Optional, TYPE_CHECKING
import sqlalchemy import sqlalchemy
from pydantic import Field
from ormar import ModelDefinitionError # noqa I101 from ormar import ModelDefinitionError # noqa I101
@ -11,72 +12,55 @@ if TYPE_CHECKING: # pragma no cover
class BaseField: class BaseField:
__type__ = None __type__ = None
def __init__(self, **kwargs: Any) -> None: column_type: sqlalchemy.Column
self.name = None constraints: List = []
self._populate_from_kwargs(kwargs)
def _populate_from_kwargs(self, kwargs: Dict) -> None: primary_key: bool
self.primary_key = kwargs.pop("primary_key", False) autoincrement: bool
self.autoincrement = kwargs.pop( nullable: bool
"autoincrement", self.primary_key and self.__type__ == int index: bool
) unique: bool
pydantic_only: bool
self.nullable = kwargs.pop("nullable", not self.primary_key) default: Any
self.default = kwargs.pop("default", None) server_default: Any
self.server_default = kwargs.pop("server_default", None)
self.index = kwargs.pop("index", None) @classmethod
self.unique = kwargs.pop("unique", None) def default_value(cls) -> Optional[Field]:
if cls.is_auto_primary_key():
return Field(default=None)
if cls.has_default():
default = cls.default if cls.default is not None else cls.server_default
if callable(default):
return Field(default_factory=default)
else:
return Field(default=default)
return None
self.pydantic_only = kwargs.pop("pydantic_only", False) @classmethod
if self.pydantic_only and self.primary_key: def has_default(cls) -> bool:
raise ModelDefinitionError("Primary key column cannot be pydantic only.") return cls.default is not None or cls.server_default is not None
@property @classmethod
def is_required(self) -> bool: def is_auto_primary_key(cls) -> bool:
return ( if cls.primary_key:
not self.nullable and not self.has_default and not self.is_auto_primary_key return cls.autoincrement
)
@property
def default_value(self) -> Any:
default = self.default
return default() if callable(default) else default
@property
def has_default(self) -> bool:
return self.default is not None or self.server_default is not None
@property
def is_auto_primary_key(self) -> bool:
if self.primary_key:
return self.autoincrement
return False return False
def get_column(self, name: str = None) -> sqlalchemy.Column: @classmethod
self.name = name def get_column(cls, name: str) -> sqlalchemy.Column:
constraints = self.get_constraints()
return sqlalchemy.Column( return sqlalchemy.Column(
self.name, name,
self.get_column_type(), cls.column_type,
*constraints, *cls.constraints,
primary_key=self.primary_key, primary_key=cls.primary_key,
autoincrement=self.autoincrement, nullable=cls.nullable and not cls.primary_key,
nullable=self.nullable, index=cls.index,
index=self.index, unique=cls.unique,
unique=self.unique, default=cls.default,
default=self.default, server_default=cls.server_default,
server_default=self.server_default,
) )
def get_column_type(self) -> sqlalchemy.types.TypeEngine: @classmethod
raise NotImplementedError() # pragma: no cover def expand_relationship(cls, value: Any, child: "Model") -> Any:
def get_constraints(self) -> Optional[List]:
return []
def expand_relationship(self, value: Any, child: "Model") -> Any:
return value return value
def __repr__(self): # pragma no cover
return str(self.__dict__)

View File

@ -1,27 +0,0 @@
from typing import Any, TYPE_CHECKING, Type
from ormar import ModelDefinitionError
if TYPE_CHECKING: # pragma no cover
from ormar.fields import BaseField
class RequiredParams:
def __init__(self, *args: str) -> None:
self._required = list(args)
def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]:
old_init = model_field_class.__init__
model_field_class._old_init = old_init
def __init__(instance: "BaseField", **kwargs: Any) -> None:
super(instance.__class__, instance).__init__(**kwargs)
for arg in self._required:
if arg not in kwargs:
raise ModelDefinitionError(
f"{instance.__class__.__name__} field requires parameter: {arg}"
)
setattr(instance, arg, kwargs.pop(arg))
model_field_class.__init__ = __init__
return model_field_class

View File

@ -1,7 +1,6 @@
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union from typing import Any, Callable, List, Optional, TYPE_CHECKING, Type, Union
import sqlalchemy import sqlalchemy
from pydantic import BaseModel
import ormar # noqa I101 import ormar # noqa I101
from ormar.exceptions import RelationshipInstanceError from ormar.exceptions import RelationshipInstanceError
@ -13,87 +12,120 @@ if TYPE_CHECKING: # pragma no cover
def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
init_dict = { init_dict = {
**{fk.__pkname__: pk or -1}, **{fk.Meta.pkname: pk or -1, "__pk_only__": True},
**{ **{
k: create_dummy_instance(v.to) k: create_dummy_instance(v.to)
for k, v in fk.__model_fields__.items() for k, v in fk.Meta.model_fields.items()
if isinstance(v, ForeignKey) and not v.nullable and not v.virtual if isinstance(v, ForeignKeyField) and not v.nullable and not v.virtual
}, },
} }
return fk(**init_dict) return fk(**init_dict)
class ForeignKey(BaseField): def ForeignKey(
def __init__(
self,
to: Type["Model"], to: Type["Model"],
*,
name: str = None, name: str = None,
related_name: str = None, unique: bool = False,
nullable: bool = True, nullable: bool = True,
related_name: str = None,
virtual: bool = False, virtual: bool = False,
) -> None: ) -> Type[object]:
super().__init__(nullable=nullable, name=name) fk_string = to.Meta.tablename + "." + to.Meta.pkname
self.virtual = virtual to_field = to.__fields__[to.Meta.pkname]
self.related_name = related_name namespace = dict(
self.to = to to=to,
name=name,
nullable=nullable,
constraints=[sqlalchemy.schema.ForeignKey(fk_string)],
unique=unique,
column_type=to_field.type_.column_type,
related_name=related_name,
virtual=virtual,
primary_key=False,
index=False,
pydantic_only=False,
default=None,
server_default=None,
)
@property return type("ForeignKey", (ForeignKeyField, BaseField), namespace)
def __type__(self) -> Type[BaseModel]:
return self.to.__pydantic_model__
def get_constraints(self) -> List[sqlalchemy.schema.ForeignKey]:
fk_string = self.to.__tablename__ + "." + self.to.__pkname__
return [sqlalchemy.schema.ForeignKey(fk_string)]
def get_column_type(self) -> sqlalchemy.Column: class ForeignKeyField(BaseField):
to_column = self.to.__model_fields__[self.to.__pkname__] to: Type["Model"]
return to_column.get_column_type() related_name: str
virtual: bool
def _extract_model_from_sequence( @classmethod
self, value: List, child: "Model" def __get_validators__(cls) -> Callable:
) -> Union["Model", List["Model"]]: yield cls.validate
return [self.expand_relationship(val, child) for val in value]
def _register_existing_model(self, value: "Model", child: "Model") -> "Model": @classmethod
self.register_relation(value, child) def validate(cls, value: Any) -> Any:
return value return value
def _construct_model_from_dict(self, value: dict, child: "Model") -> "Model": # @property
model = self.to(**value) # def __type__(self) -> Type[BaseModel]:
self.register_relation(model, child) # return self.to.__pydantic_model__
# @classmethod
# def get_column_type(cls) -> sqlalchemy.Column:
# to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname]
# return to_column.column_type
@classmethod
def _extract_model_from_sequence(
cls, value: List, child: "Model"
) -> Union["Model", List["Model"]]:
return [cls.expand_relationship(val, child) for val in value]
@classmethod
def _register_existing_model(cls, value: "Model", child: "Model") -> "Model":
cls.register_relation(value, child)
return value
@classmethod
def _construct_model_from_dict(cls, value: dict, child: "Model") -> "Model":
if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname:
value["__pk_only__"] = True
model = cls.to(**value)
cls.register_relation(model, child)
return model return model
def _construct_model_from_pk(self, value: Any, child: "Model") -> "Model": @classmethod
if not isinstance(value, self.to.pk_type()): def _construct_model_from_pk(cls, value: Any, child: "Model") -> "Model":
if not isinstance(value, cls.to.pk_type()):
raise RelationshipInstanceError( raise RelationshipInstanceError(
f"Relationship error - ForeignKey {self.to.__name__} " f"Relationship error - ForeignKey {cls.to.__name__} "
f"is of type {self.to.pk_type()} " f"is of type {cls.to.pk_type()} "
f"while {type(value)} passed as a parameter." f"while {type(value)} passed as a parameter."
) )
model = create_dummy_instance(fk=self.to, pk=value) model = create_dummy_instance(fk=cls.to, pk=value)
self.register_relation(model, child) cls.register_relation(model, child)
return model return model
def register_relation(self, model: "Model", child: "Model") -> None: @classmethod
child_model_name = self.related_name or child.get_name() def register_relation(cls, model: "Model", child: "Model") -> None:
model._orm_relationship_manager.add_relation( child_model_name = cls.related_name or child.get_name()
model, child, child_model_name, virtual=self.virtual model.Meta._orm_relationship_manager.add_relation(
model, child, child_model_name, virtual=cls.virtual
) )
@classmethod
def expand_relationship( def expand_relationship(
self, value: Any, child: "Model" cls, value: Any, child: "Model"
) -> Optional[Union["Model", List["Model"]]]: ) -> Optional[Union["Model", List["Model"]]]:
if value is None: if value is None:
return None return None
constructors = { constructors = {
f"{self.to.__name__}": self._register_existing_model, f"{cls.to.__name__}": cls._register_existing_model,
"dict": self._construct_model_from_dict, "dict": cls._construct_model_from_dict,
"list": self._extract_model_from_sequence, "list": cls._extract_model_from_sequence,
} }
model = constructors.get( model = constructors.get(
value.__class__.__name__, self._construct_model_from_pk value.__class__.__name__, cls._construct_model_from_pk
)(value, child) )(value, child)
return model return model

View File

@ -1,87 +1,269 @@
import datetime import datetime
import decimal import decimal
from typing import Any, Optional, Type
import pydantic
import sqlalchemy import sqlalchemy
from pydantic import Json
from ormar import ModelDefinitionError # noqa I101
from ormar.fields.base import BaseField # noqa I101 from ormar.fields.base import BaseField # noqa I101
from ormar.fields.decorators import RequiredParams
@RequiredParams("length") def is_field_nullable(
class String(BaseField): nullable: Optional[bool], default: Any, server_default: Any
__type__ = str ) -> bool:
if nullable is None:
def get_column_type(self) -> sqlalchemy.Column: return default is not None or server_default is not None
return sqlalchemy.String(self.length) return nullable
class Integer(BaseField): class ModelFieldFactory:
__type__ = int _bases = None
_type = None
def get_column_type(self) -> sqlalchemy.Column: def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]:
cls.validate(**kwargs)
default = kwargs.pop("default", None)
server_default = kwargs.pop("server_default", None)
nullable = kwargs.pop("nullable", None)
namespace = dict(
__type__=cls._type,
name=kwargs.pop("name", None),
primary_key=kwargs.pop("primary_key", False),
default=default,
server_default=server_default,
nullable=is_field_nullable(nullable, default, server_default),
index=kwargs.pop("index", False),
unique=kwargs.pop("unique", False),
pydantic_only=kwargs.pop("pydantic_only", False),
autoincrement=kwargs.pop("autoincrement", False),
column_type=cls.get_column_type(**kwargs),
**kwargs
)
return type(cls.__name__, cls._bases, namespace)
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any: # pragma no cover
return None
@classmethod
def validate(cls, **kwargs: Any) -> None: # pragma no cover
pass
class String(ModelFieldFactory):
_bases = (pydantic.ConstrainedStr, BaseField)
_type = str
def __new__(
cls,
*,
allow_blank: bool = False,
strip_whitespace: bool = False,
min_length: int = None,
max_length: int = None,
curtail_length: int = None,
regex: str = None,
**kwargs: Any
) -> Type[str]:
kwargs = {
**kwargs,
**{
k: v
for k, v in locals().items()
if k not in ["cls", "__class__", "kwargs"]
},
}
return super().__new__(cls, **kwargs)
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.String(length=kwargs.get("max_length"))
@classmethod
def validate(cls, **kwargs: Any) -> None:
max_length = kwargs.get("max_length", None)
if max_length is None or max_length <= 0:
raise ModelDefinitionError(
"Parameter max_length is required for field String"
)
class Integer(ModelFieldFactory):
_bases = (pydantic.ConstrainedInt, BaseField)
_type = int
def __new__(
cls,
*,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
**kwargs: Any
) -> Type[int]:
autoincrement = kwargs.pop("autoincrement", None)
autoincrement = (
autoincrement
if autoincrement is not None
else kwargs.get("primary_key", False)
)
kwargs = {
**kwargs,
**{
k: v
for k, v in locals().items()
if k not in ["cls", "__class__", "kwargs"]
},
}
return super().__new__(cls, **kwargs)
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Integer() return sqlalchemy.Integer()
class Text(BaseField): class Text(ModelFieldFactory):
__type__ = str _bases = (pydantic.ConstrainedStr, BaseField)
_type = str
def get_column_type(self) -> sqlalchemy.Column: def __new__(
cls, *, allow_blank: bool = False, strip_whitespace: bool = False, **kwargs: Any
) -> Type[str]:
kwargs = {
**kwargs,
**{
k: v
for k, v in locals().items()
if k not in ["cls", "__class__", "kwargs"]
},
}
return super().__new__(cls, **kwargs)
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Text() return sqlalchemy.Text()
class Float(BaseField): class Float(ModelFieldFactory):
__type__ = float _bases = (pydantic.ConstrainedFloat, BaseField)
_type = float
def get_column_type(self) -> sqlalchemy.Column: def __new__(
cls,
*,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
**kwargs: Any
) -> Type[int]:
kwargs = {
**kwargs,
**{
k: v
for k, v in locals().items()
if k not in ["cls", "__class__", "kwargs"]
},
}
return super().__new__(cls, **kwargs)
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Float() return sqlalchemy.Float()
class Boolean(BaseField): class Boolean(ModelFieldFactory):
__type__ = bool _bases = (int, BaseField)
_type = bool
def get_column_type(self) -> sqlalchemy.Column: @classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Boolean() return sqlalchemy.Boolean()
class DateTime(BaseField): class DateTime(ModelFieldFactory):
__type__ = datetime.datetime _bases = (datetime.datetime, BaseField)
_type = datetime.datetime
def get_column_type(self) -> sqlalchemy.Column: @classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.DateTime() return sqlalchemy.DateTime()
class Date(BaseField): class Date(ModelFieldFactory):
__type__ = datetime.date _bases = (datetime.date, BaseField)
_type = datetime.date
def get_column_type(self) -> sqlalchemy.Column: @classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Date() return sqlalchemy.Date()
class Time(BaseField): class Time(ModelFieldFactory):
__type__ = datetime.time _bases = (datetime.time, BaseField)
_type = datetime.time
def get_column_type(self) -> sqlalchemy.Column: @classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.Time() return sqlalchemy.Time()
class JSON(BaseField): class JSON(ModelFieldFactory):
__type__ = Json _bases = (pydantic.Json, BaseField)
_type = pydantic.Json
def get_column_type(self) -> sqlalchemy.Column: @classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.JSON() return sqlalchemy.JSON()
class BigInteger(BaseField): class BigInteger(Integer):
__type__ = int _bases = (pydantic.ConstrainedInt, BaseField)
_type = int
def get_column_type(self) -> sqlalchemy.Column: @classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
return sqlalchemy.BigInteger() return sqlalchemy.BigInteger()
@RequiredParams("length", "precision") class Decimal(ModelFieldFactory):
class Decimal(BaseField): _bases = (pydantic.ConstrainedDecimal, BaseField)
__type__ = decimal.Decimal _type = decimal.Decimal
def get_column_type(self) -> sqlalchemy.Column: def __new__(
return sqlalchemy.DECIMAL(self.length, self.precision) cls,
*,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
precision: int = None,
scale: int = None,
max_digits: int = None,
decimal_places: int = None,
**kwargs: Any
) -> Type[decimal.Decimal]:
kwargs = {
**kwargs,
**{
k: v
for k, v in locals().items()
if k not in ["cls", "__class__", "kwargs"]
},
}
return super().__new__(cls, **kwargs)
@classmethod
def get_column_type(cls, **kwargs: Any) -> Any:
precision = kwargs.get("precision")
scale = kwargs.get("scale")
return sqlalchemy.DECIMAL(precision=precision, scale=scale)
@classmethod
def validate(cls, **kwargs: Any) -> None:
precision = kwargs.get("precision")
scale = kwargs.get("scale")
if precision is None or precision < 0 or scale is None or scale < 0:
raise ModelDefinitionError(
"Parameters scale and precision are required for field Decimal"
)

View File

@ -2,10 +2,11 @@ import inspect
import json import json
import uuid import uuid
from typing import ( from typing import (
AbstractSet,
Any, Any,
Callable,
Dict, Dict,
List, List,
Mapping,
Optional, Optional,
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
@ -21,18 +22,26 @@ from pydantic import BaseModel
import ormar # noqa I100 import ormar # noqa I100
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.models.metaclass import ModelMetaclass from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta, ModelMetaclass
from ormar.relations import RelationshipManager from ormar.relations import RelationshipManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models.model import Model from ormar.models.model import Model
IntStr = Union[int, str]
DictStrAny = Dict[str, Any]
AbstractSetIntStr = AbstractSet[IntStr]
MappingIntStrAny = Mapping[IntStr, Any]
class FakePydantic(list, metaclass=ModelMetaclass):
class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
# FakePydantic inherits from list in order to be treated as # FakePydantic inherits from list in order to be treated as
# request.Body parameter in fastapi routes, # request.Body parameter in fastapi routes,
# inheriting from pydantic.BaseModel causes metaclass conflicts # inheriting from pydantic.BaseModel causes metaclass conflicts
__slots__ = ("_orm_id", "_orm_saved")
__abstract__ = True __abstract__ = True
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, TypeVar[BaseField]] __model_fields__: Dict[str, TypeVar[BaseField]]
__table__: sqlalchemy.Table __table__: sqlalchemy.Table
@ -43,63 +52,82 @@ class FakePydantic(list, metaclass=ModelMetaclass):
__metadata__: sqlalchemy.MetaData __metadata__: sqlalchemy.MetaData
__database__: databases.Database __database__: databases.Database
_orm_relationship_manager: RelationshipManager _orm_relationship_manager: RelationshipManager
Meta: ModelMeta
# noinspection PyMissingConstructor
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
self._orm_id: str = uuid.uuid4().hex
self._orm_saved: bool = False
self.values: Optional[BaseModel] = None
object.__setattr__(self, "_orm_id", uuid.uuid4().hex)
object.__setattr__(self, "_orm_saved", False)
pk_only = kwargs.pop("__pk_only__", False)
if "pk" in kwargs: if "pk" in kwargs:
kwargs[self.__pkname__] = kwargs.pop("pk") kwargs[self.Meta.pkname] = kwargs.pop("pk")
kwargs = { kwargs = {
k: self.__model_fields__[k].expand_relationship(v, self) k: self._convert_json(
k, self.Meta.model_fields[k].expand_relationship(v, self), "dumps"
)
for k, v in kwargs.items() for k, v in kwargs.items()
} }
self.values = self.__pydantic_model__(**kwargs)
values, fields_set, validation_error = pydantic.validate_model(self, kwargs)
if validation_error and not pk_only:
raise validation_error
object.__setattr__(self, "__dict__", values)
object.__setattr__(self, "__fields_set__", fields_set)
# super().__init__(**kwargs)
# self.values = self.__pydantic_model__(**kwargs)
def __del__(self) -> None: def __del__(self) -> None:
self._orm_relationship_manager.deregister(self) self.Meta._orm_relationship_manager.deregister(self)
def __setattr__(self, key: str, value: Any) -> None: def __setattr__(self, name: str, value: Any) -> None:
if key in self.__fields__: relation_key = self.get_name(title=True) + "_" + name
value = self._convert_json(key, value, op="dumps") if name in self.__slots__:
value = self.__model_fields__[key].expand_relationship(value, self) object.__setattr__(self, name, value)
elif name == "pk":
relation_key = self.get_name(title=True) + "_" + key object.__setattr__(self, self.Meta.pkname, value)
if not self._orm_relationship_manager.contains(relation_key, self): elif self.Meta._orm_relationship_manager.contains(relation_key, self):
setattr(self.values, key, value) self.Meta.model_fields[name].expand_relationship(value, self)
else: else:
super().__setattr__(key, value) value = (
self._convert_json(name, value, "dumps")
if name in self.__fields__
else value
)
super().__setattr__(name, value)
def __getattribute__(self, key: str) -> Any: def __getattribute__(self, item: str) -> Any:
if key != "__fields__" and key in self.__fields__: if item != "__fields__" and item in self.__fields__:
relation_key = self.get_name(title=True) + "_" + key related = self._extract_related_model_instead_of_field(item)
if self._orm_relationship_manager.contains(relation_key, self): if related:
return self._orm_relationship_manager.get(relation_key, self) return related
value = object.__getattribute__(self, item)
value = self._convert_json(item, value, "loads")
return value
return super().__getattribute__(item)
item = getattr(self.values, key, None) def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]:
item = self._convert_json(key, item, op="loads") return self._extract_related_model_instead_of_field(item)
return item
return super().__getattribute__(key)
def __eq__(self, other: "Model") -> bool: def _extract_related_model_instead_of_field(
return self.values.dict() == other.values.dict() self, item: str
) -> Optional[Union["Model", List["Model"]]]:
relation_key = self.get_name(title=True) + "_" + item
if self.Meta._orm_relationship_manager.contains(relation_key, self):
return self.Meta._orm_relationship_manager.get(relation_key, self)
def __same__(self, other: "Model") -> bool: def __same__(self, other: "Model") -> bool:
if self.__class__ != other.__class__: # pragma no cover if self.__class__ != other.__class__: # pragma no cover
return False return False
return self._orm_id == other._orm_id or ( return (
self.values is not None and other.values is not None and self.pk == other.pk self._orm_id == other._orm_id
or self.__dict__ == other.__dict__
or (self.pk == other.pk and self.pk is not None)
) )
def __repr__(self) -> str: # pragma no cover
return self.values.__repr__()
@classmethod
def __get_validators__(cls) -> Callable: # pragma no cover
yield cls.__pydantic_model__.validate
@classmethod @classmethod
def get_name(cls, title: bool = False, lower: bool = True) -> str: def get_name(cls, title: bool = False, lower: bool = True) -> str:
name = cls.__name__ name = cls.__name__
@ -109,28 +137,50 @@ class FakePydantic(list, metaclass=ModelMetaclass):
name = name.title() name = name.title()
return name return name
@property
def pk(self) -> Any:
return getattr(self, self.Meta.pkname)
@property @property
def pk_column(self) -> sqlalchemy.Column: def pk_column(self) -> sqlalchemy.Column:
return self.__table__.primary_key.columns.values()[0] return self.Meta.table.primary_key.columns.values()[0]
@classmethod @classmethod
def pk_type(cls) -> Any: def pk_type(cls) -> Any:
return cls.__model_fields__[cls.__pkname__].__type__ return cls.Meta.model_fields[cls.Meta.pkname].__type__
def dict(self, nested=False) -> Dict: # noqa: A003 def dict( # noqa A003
dict_instance = self.values.dict() self,
*,
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
by_alias: bool = False,
skip_defaults: bool = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
nested: bool = False
) -> "DictStrAny": # noqa: A003'
dict_instance = super().dict(
include=include,
exclude=self._exclude_related_names_not_required(nested),
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
for field in self._extract_related_names(): for field in self._extract_related_names():
nested_model = getattr(self, field) nested_model = getattr(self, field)
if self.__model_fields__[field].virtual and nested:
if self.Meta.model_fields[field].virtual and nested:
continue continue
if isinstance(nested_model, list) and not isinstance( if isinstance(nested_model, list) and not isinstance(
nested_model, ormar.Model nested_model, ormar.Model
): ):
dict_instance[field] = [x.dict(nested=True) for x in nested_model] dict_instance[field] = [x.dict(nested=True) for x in nested_model]
else: elif nested_model is not None:
dict_instance[field] = ( dict_instance[field] = nested_model.dict(nested=True)
nested_model.dict(nested=True) if nested_model is not None else {}
)
return dict_instance return dict_instance
def from_dict(self, value_dict: Dict) -> None: def from_dict(self, value_dict: Dict) -> None:
@ -155,7 +205,7 @@ class FakePydantic(list, metaclass=ModelMetaclass):
return value return value
def _is_conversion_to_json_needed(self, column_name: str) -> bool: def _is_conversion_to_json_needed(self, column_name: str) -> bool:
return self.__model_fields__.get(column_name).__type__ == pydantic.Json return self.Meta.model_fields.get(column_name).__type__ == pydantic.Json
def _extract_own_model_fields(self) -> Dict: def _extract_own_model_fields(self) -> Dict:
related_names = self._extract_related_names() related_names = self._extract_related_names()
@ -165,9 +215,21 @@ class FakePydantic(list, metaclass=ModelMetaclass):
@classmethod @classmethod
def _extract_related_names(cls) -> Set: def _extract_related_names(cls) -> Set:
related_names = set() related_names = set()
for name, field in cls.__fields__.items(): for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field.type_) and issubclass( if inspect.isclass(field) and issubclass(field, ForeignKeyField):
field.type_, pydantic.BaseModel related_names.add(name)
return related_names
@classmethod
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
if nested:
return cls._extract_related_names()
related_names = set()
for name, field in cls.Meta.model_fields.items():
if (
inspect.isclass(field)
and issubclass(field, ForeignKeyField)
and field.nullable
): ):
related_names.add(name) related_names.add(name)
return related_names return related_names
@ -175,13 +237,12 @@ class FakePydantic(list, metaclass=ModelMetaclass):
def _extract_model_db_fields(self) -> Dict: def _extract_model_db_fields(self) -> Dict:
self_fields = self._extract_own_model_fields() self_fields = self._extract_own_model_fields()
self_fields = { self_fields = {
k: v for k, v in self_fields.items() if k in self.__table__.columns k: v for k, v in self_fields.items() if k in self.Meta.table.columns
} }
for field in self._extract_related_names(): for field in self._extract_related_names():
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
if getattr(self, field) is not None: if getattr(self, field) is not None:
self_fields[field] = getattr( self_fields[field] = getattr(getattr(self, field), target_pk_name)
getattr(self, field), self.__model_fields__[field].to.__pkname__
)
return self_fields return self_fields
@classmethod @classmethod
@ -196,18 +257,19 @@ class FakePydantic(list, metaclass=ModelMetaclass):
@classmethod @classmethod
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
for field in one.__model_fields__.keys(): for field in one.Meta.model_fields.keys():
if isinstance(getattr(one, field), list) and not isinstance( current_field = getattr(one, field)
getattr(one, field), ormar.Model if isinstance(current_field, list) and not isinstance(
current_field, ormar.Model
):
setattr(other, field, current_field + getattr(other, field))
elif (
isinstance(current_field, ormar.Model)
and current_field.pk == getattr(other, field).pk
): ):
setattr(other, field, getattr(one, field) + getattr(other, field))
elif isinstance(getattr(one, field), ormar.Model):
if getattr(one, field).pk == getattr(other, field).pk:
setattr( setattr(
other, other,
field, field,
cls.merge_two_instances( cls.merge_two_instances(current_field, getattr(other, field)),
getattr(one, field), getattr(other, field)
),
) )
return other return other

View File

@ -1,12 +1,15 @@
import copy from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type
import databases
import pydantic
import sqlalchemy import sqlalchemy
from pydantic import BaseConfig, create_model from pydantic import BaseConfig
from pydantic.fields import ModelField from pydantic.fields import FieldInfo
from ormar import ForeignKey, ModelDefinitionError # noqa I100 from ormar import ForeignKey, ModelDefinitionError # noqa I100
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.queryset import QuerySet
from ormar.relations import RelationshipManager from ormar.relations import RelationshipManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -15,16 +18,15 @@ if TYPE_CHECKING: # pragma no cover
relationship_manager = RelationshipManager() relationship_manager = RelationshipManager()
def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: class ModelMeta:
pydantic_fields = { tablename: str
field_name: ( table: sqlalchemy.Table
base_field.__type__, metadata: sqlalchemy.MetaData
... if base_field.is_required else base_field.default_value, database: databases.Database
) columns: List[sqlalchemy.Column]
for field_name, base_field in object_dict.items() pkname: str
if isinstance(base_field, BaseField) model_fields: Dict[str, Union[BaseField, ForeignKey]]
} _orm_relationship_manager: RelationshipManager
return pydantic_fields
def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None:
@ -41,8 +43,8 @@ def register_relation_on_build(table_name: str, field: ForeignKey, name: str) ->
def expand_reverse_relationships(model: Type["Model"]) -> None: def expand_reverse_relationships(model: Type["Model"]) -> None:
for model_field in model.__model_fields__.values(): for model_field in model.Meta.model_fields.values():
if isinstance(model_field, ForeignKey): if issubclass(model_field, ForeignKeyField):
child_model_name = model_field.related_name or model.get_name() + "s" child_model_name = model_field.related_name or model.get_name() + "s"
parent_model = model_field.to parent_model = model_field.to
child = model child = model
@ -56,13 +58,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
def register_reverse_model_fields( def register_reverse_model_fields(
model: Type["Model"], child: Type["Model"], child_model_name: str model: Type["Model"], child: Type["Model"], child_model_name: str
) -> None: ) -> None:
model.__fields__[child_model_name] = ModelField( model.Meta.model_fields[child_model_name] = ForeignKey(
name=child_model_name,
type_=Optional[child.__pydantic_model__],
model_config=child.__pydantic_model__.__config__,
class_validators=child.__pydantic_model__.__validators__,
)
model.__model_fields__[child_model_name] = ForeignKey(
child, name=child_model_name, virtual=True child, name=child_model_name, virtual=True
) )
@ -74,71 +70,96 @@ def sqlalchemy_columns_from_model_fields(
pkname = None pkname = None
model_fields = { model_fields = {
field_name: field field_name: field
for field_name, field in object_dict.items() for field_name, field in object_dict["__annotations__"].items()
if isinstance(field, BaseField) if issubclass(field, BaseField)
} }
for field_name, field in model_fields.items(): for field_name, field in model_fields.items():
if field.primary_key: if field.primary_key:
if pkname is not None: if pkname is not None:
raise ModelDefinitionError("Only one primary key column is allowed.") raise ModelDefinitionError("Only one primary key column is allowed.")
if field.pydantic_only:
raise ModelDefinitionError("Primary key column cannot be pydantic only")
pkname = field_name pkname = field_name
if not field.pydantic_only: if not field.pydantic_only:
columns.append(field.get_column(field_name)) columns.append(field.get_column(field_name))
if isinstance(field, ForeignKey): if issubclass(field, ForeignKeyField):
register_relation_on_build(table_name, field, name) register_relation_on_build(table_name, field, name)
return pkname, columns, model_fields return pkname, columns, model_fields
def populate_pydantic_default_values(attrs: Dict) -> Dict:
for field, type_ in attrs["__annotations__"].items():
if issubclass(type_, BaseField):
if type_.name is None:
type_.name = field
def_value = type_.default_value()
curr_def_value = attrs.get(field, "NONE")
if curr_def_value == "NONE" and isinstance(def_value, FieldInfo):
attrs[field] = def_value
elif curr_def_value == "NONE" and type_.nullable:
attrs[field] = FieldInfo(default=None)
return attrs
def get_pydantic_base_orm_config() -> Type[BaseConfig]: def get_pydantic_base_orm_config() -> Type[BaseConfig]:
class Config(BaseConfig): class Config(BaseConfig):
orm_mode = True orm_mode = True
arbitrary_types_allowed = True
# extra = Extra.allow
return Config return Config
class ModelMetaclass(type): class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
attrs["Config"] = get_pydantic_base_orm_config()
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
) )
if attrs.get("__abstract__"): if hasattr(new_model, "Meta"):
return new_model
tablename = attrs.get("__tablename__", name.lower() + "s") annotations = attrs.get("__annotations__") or new_model.__annotations__
attrs["__tablename__"] = tablename attrs["__annotations__"] = annotations
metadata = attrs["__metadata__"] attrs = populate_pydantic_default_values(attrs)
tablename = name.lower() + "s"
new_model.Meta.tablename = new_model.Meta.tablename or tablename
# sqlalchemy table creation # sqlalchemy table creation
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
name, attrs, tablename name, attrs, new_model.Meta.tablename
) )
attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns)
attrs["__columns__"] = columns if hasattr(new_model.Meta, "model_fields") and not pkname:
attrs["__pkname__"] = pkname model_fields = new_model.Meta.model_fields
for fieldname, field in new_model.Meta.model_fields.items():
if field.primary_key:
pkname = fieldname
columns = new_model.Meta.table.columns
if not hasattr(new_model.Meta, "table"):
new_model.Meta.table = sqlalchemy.Table(
new_model.Meta.tablename, new_model.Meta.metadata, *columns
)
new_model.Meta.columns = columns
new_model.Meta.pkname = pkname
if not pkname: if not pkname:
raise ModelDefinitionError("Table has to have a primary key.") raise ModelDefinitionError("Table has to have a primary key.")
# pydantic model creation new_model.Meta.model_fields = model_fields
pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
pydantic_model = create_model(
name, __config__=get_pydantic_base_orm_config(), **pydantic_fields
)
attrs["__pydantic_fields__"] = pydantic_fields
attrs["__pydantic_model__"] = pydantic_model
attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__)
attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__)
attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__)
attrs["__model_fields__"] = model_fields
attrs["_orm_relationship_manager"] = relationship_manager
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
) )
expand_reverse_relationships(new_model) expand_reverse_relationships(new_model)
new_model.Meta._orm_relationship_manager = relationship_manager
new_model.objects = QuerySet(new_model)
# breakpoint()
return new_model return new_model

View File

@ -7,9 +7,9 @@ from ormar.models import FakePydantic # noqa I100
class Model(FakePydantic): class Model(FakePydantic):
__abstract__ = True __abstract__ = False
objects = ormar.queryset.QuerySet() # objects = ormar.queryset.QuerySet()
@classmethod @classmethod
def from_row( def from_row(
@ -22,24 +22,24 @@ class Model(FakePydantic):
item = {} item = {}
select_related = select_related or [] select_related = select_related or []
table_prefix = cls._orm_relationship_manager.resolve_relation_join( table_prefix = cls.Meta._orm_relationship_manager.resolve_relation_join(
previous_table, cls.__table__.name previous_table, cls.Meta.table.name
) )
previous_table = cls.__table__.name previous_table = cls.Meta.table.name
for related in select_related: for related in select_related:
if "__" in related: if "__" in related:
first_part, remainder = related.split("__", 1) first_part, remainder = related.split("__", 1)
model_cls = cls.__model_fields__[first_part].to model_cls = cls.Meta.model_fields[first_part].to
child = model_cls.from_row( child = model_cls.from_row(
row, select_related=[remainder], previous_table=previous_table row, select_related=[remainder], previous_table=previous_table
) )
item[first_part] = child item[first_part] = child
else: else:
model_cls = cls.__model_fields__[related].to model_cls = cls.Meta.model_fields[related].to
child = model_cls.from_row(row, previous_table=previous_table) child = model_cls.from_row(row, previous_table=previous_table)
item[related] = child item[related] = child
for column in cls.__table__.columns: for column in cls.Meta.table.columns:
if column.name not in item: if column.name not in item:
item[column.name] = row[ item[column.name] = row[
f'{table_prefix + "_" if table_prefix else ""}{column.name}' f'{table_prefix + "_" if table_prefix else ""}{column.name}'
@ -47,22 +47,14 @@ class Model(FakePydantic):
return cls(**item) return cls(**item)
@property
def pk(self) -> str:
return getattr(self.values, self.__pkname__)
@pk.setter
def pk(self, value: Any) -> None:
setattr(self.values, self.__pkname__, value)
async def save(self) -> "Model": async def save(self) -> "Model":
self_fields = self._extract_model_db_fields() self_fields = self._extract_model_db_fields()
if self.__model_fields__.get(self.__pkname__).autoincrement: if self.Meta.model_fields.get(self.Meta.pkname).autoincrement:
self_fields.pop(self.__pkname__, None) self_fields.pop(self.Meta.pkname, None)
expr = self.__table__.insert() expr = self.Meta.table.insert()
expr = expr.values(**self_fields) expr = expr.values(**self_fields)
item_id = await self.__database__.execute(expr) item_id = await self.Meta.database.execute(expr)
self.pk = item_id setattr(self, self.Meta.pkname, item_id)
return self return self
async def update(self, **kwargs: Any) -> int: async def update(self, **kwargs: Any) -> int:
@ -71,23 +63,23 @@ class Model(FakePydantic):
self.from_dict(new_values) self.from_dict(new_values)
self_fields = self._extract_model_db_fields() self_fields = self._extract_model_db_fields()
self_fields.pop(self.__pkname__) self_fields.pop(self.Meta.pkname)
expr = ( expr = (
self.__table__.update() self.Meta.table.update()
.values(**self_fields) .values(**self_fields)
.where(self.pk_column == getattr(self, self.__pkname__)) .where(self.pk_column == getattr(self, self.Meta.pkname))
) )
result = await self.__database__.execute(expr) result = await self.Meta.database.execute(expr)
return result return result
async def delete(self) -> int: async def delete(self) -> int:
expr = self.__table__.delete() expr = self.Meta.table.delete()
expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname)))
result = await self.__database__.execute(expr) result = await self.Meta.database.execute(expr)
return result return result
async def load(self) -> "Model": async def load(self) -> "Model":
expr = self.__table__.select().where(self.pk_column == self.pk) expr = self.Meta.table.select().where(self.pk_column == self.pk)
row = await self.__database__.fetch_one(expr) row = await self.Meta.database.fetch_one(expr)
self.from_dict(dict(row)) self.from_dict(dict(row))
return self return self

View File

@ -32,7 +32,7 @@ class QueryClause:
self.filter_clauses = filter_clauses self.filter_clauses = filter_clauses
self.model_cls = model_cls self.model_cls = model_cls
self.table = self.model_cls.__table__ self.table = self.model_cls.Meta.table
def filter( # noqa: A003 def filter( # noqa: A003
self, **kwargs: Any self, **kwargs: Any
@ -41,7 +41,7 @@ class QueryClause:
select_related = list(self._select_related) select_related = list(self._select_related)
if kwargs.get("pk"): if kwargs.get("pk"):
pk_name = self.model_cls.__pkname__ pk_name = self.model_cls.Meta.pkname
kwargs[pk_name] = kwargs.pop("pk") kwargs[pk_name] = kwargs.pop("pk")
for key, value in kwargs.items(): for key, value in kwargs.items():
@ -65,8 +65,8 @@ class QueryClause:
related_parts, select_related related_parts, select_related
) )
table = model_cls.__table__ table = model_cls.Meta.table
column = model_cls.__table__.columns[field_name] column = model_cls.Meta.table.columns[field_name]
else: else:
op = "exact" op = "exact"
@ -106,12 +106,12 @@ class QueryClause:
# Walk the relationships to the actual model class # Walk the relationships to the actual model class
# against which the comparison is being made. # against which the comparison is being made.
previous_table = model_cls.__tablename__ previous_table = model_cls.Meta.tablename
for part in related_parts: for part in related_parts:
current_table = model_cls.__model_fields__[part].to.__tablename__ current_table = model_cls.Meta.model_fields[part].to.Meta.tablename
manager = model_cls._orm_relationship_manager manager = model_cls.Meta._orm_relationship_manager
table_prefix = manager.resolve_relation_join(previous_table, current_table) table_prefix = manager.resolve_relation_join(previous_table, current_table)
model_cls = model_cls.__model_fields__[part].to model_cls = model_cls.Meta.model_fields[part].to
previous_table = current_table previous_table = current_table
return select_related, table_prefix, model_cls return select_related, table_prefix, model_cls
@ -128,7 +128,7 @@ class QueryClause:
clause_text = str( clause_text = str(
clause.compile( clause.compile(
dialect=self.model_cls.__database__._backend._dialect, dialect=self.model_cls.Meta.database._backend._dialect,
compile_kwargs={"literal_binds": True}, compile_kwargs={"literal_binds": True},
) )
) )

View File

@ -4,8 +4,8 @@ import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
import ormar # noqa I100 import ormar # noqa I100
from ormar import ForeignKey
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -34,7 +34,7 @@ class Query:
self.filter_clauses = filter_clauses self.filter_clauses = filter_clauses
self.model_cls = model_cls self.model_cls = model_cls
self.table = self.model_cls.__table__ self.table = self.model_cls.Meta.table
self.auto_related = [] self.auto_related = []
self.used_aliases = [] self.used_aliases = []
@ -46,18 +46,18 @@ class Query:
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
self.columns = list(self.table.columns) self.columns = list(self.table.columns)
self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")]
self.select_from = self.table self.select_from = self.table
for key in self.model_cls.__model_fields__: # for key in self.model_cls.Meta.model_fields:
if ( # if (
not self.model_cls.__model_fields__[key].nullable # not self.model_cls.Meta.model_fields[key].nullable
and isinstance( # and isinstance(
self.model_cls.__model_fields__[key], ormar.fields.ForeignKey, # self.model_cls.Meta.model_fields[key], ForeignKeyField,
) # )
and key not in self._select_related # and key not in self._select_related
): # ):
self._select_related = [key] + self._select_related # self._select_related = [key] + self._select_related
start_params = JoinParameters( start_params = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls self.model_cls, "", self.table.name, self.model_cls
@ -97,12 +97,12 @@ class Query:
@staticmethod @staticmethod
def _field_is_a_foreign_key_and_no_circular_reference( def _field_is_a_foreign_key_and_no_circular_reference(
field: BaseField, field_name: str, rel_part: str field: Type[BaseField], field_name: str, rel_part: str
) -> bool: ) -> bool:
return isinstance(field, ForeignKey) and field_name not in rel_part return issubclass(field, ForeignKeyField) and field_name not in rel_part
def _field_qualifies_to_deeper_search( def _field_qualifies_to_deeper_search(
self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str
) -> bool: ) -> bool:
prev_part_of_related = "__".join(rel_part.split("__")[:-1]) prev_part_of_related = "__".join(rel_part.split("__")[:-1])
partial_match = any( partial_match = any(
@ -126,25 +126,26 @@ class Query:
def _build_join_parameters( def _build_join_parameters(
self, part: str, join_params: JoinParameters self, part: str, join_params: JoinParameters
) -> JoinParameters: ) -> JoinParameters:
model_cls = join_params.model_cls.__model_fields__[part].to model_cls = join_params.model_cls.Meta.model_fields[part].to
to_table = model_cls.__table__.name to_table = model_cls.Meta.table.name
alias = model_cls._orm_relationship_manager.resolve_relation_join( alias = model_cls.Meta._orm_relationship_manager.resolve_relation_join(
join_params.from_table, to_table join_params.from_table, to_table
) )
if alias not in self.used_aliases: if alias not in self.used_aliases:
if join_params.prev_model.__model_fields__[part].virtual: if join_params.prev_model.Meta.model_fields[part].virtual:
to_key = next( to_key = next(
( (
v v
for k, v in model_cls.__model_fields__.items() for k, v in model_cls.Meta.model_fields.items()
if isinstance(v, ForeignKey) and v.to == join_params.prev_model if issubclass(v, ForeignKeyField)
and v.to == join_params.prev_model
), ),
None, None,
).name ).name
from_key = model_cls.__pkname__ from_key = model_cls.Meta.pkname
else: else:
to_key = model_cls.__pkname__ to_key = model_cls.Meta.pkname
from_key = part from_key = part
on_clause = self.on_clause( on_clause = self.on_clause(
@ -157,8 +158,8 @@ class Query:
self.select_from = sqlalchemy.sql.outerjoin( self.select_from = sqlalchemy.sql.outerjoin(
self.select_from, target_table, on_clause self.select_from, target_table, on_clause
) )
self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}")) self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}"))
self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) self.columns.extend(self.prefixed_columns(alias, model_cls.Meta.table))
self.used_aliases.append(alias) self.used_aliases.append(alias)
previous_alias = alias previous_alias = alias
@ -173,18 +174,24 @@ class Query:
nested: bool = False, nested: bool = False,
parent_virtual: bool = False, parent_virtual: bool = False,
) -> None: ) -> None:
for field_name, field in prev_model.__model_fields__.items(): for field_name, field in prev_model.Meta.model_fields.items():
if self._field_is_a_foreign_key_and_no_circular_reference( if self._field_is_a_foreign_key_and_no_circular_reference(
field, field_name, rel_part field, field_name, rel_part
): ):
rel_part = field_name if not rel_part else rel_part + "__" + field_name rel_part = field_name if not rel_part else rel_part + "__" + field_name
if not field.nullable: if not field.nullable:
if rel_part not in self._select_related: if rel_part not in self._select_related:
self.auto_related.append("__".join(rel_part.split("__")[:-1])) new_related = (
"__".join(rel_part.split("__")[:-1])
if len(rel_part.split("__")) > 1
else rel_part
)
self.auto_related.append(new_related)
rel_part = "" rel_part = ""
elif self._field_qualifies_to_deeper_search( elif self._field_qualifies_to_deeper_search(
field, parent_virtual, nested, rel_part field, parent_virtual, nested, rel_part
): ):
self._extract_auto_required_relations( self._extract_auto_required_relations(
prev_model=field.to, prev_model=field.to,
rel_part=rel_part, rel_part=rel_part,

View File

@ -33,11 +33,11 @@ class QuerySet:
@property @property
def database(self) -> databases.Database: def database(self) -> databases.Database:
return self.model_cls.__database__ return self.model_cls.Meta.database
@property @property
def table(self) -> sqlalchemy.Table: def table(self) -> sqlalchemy.Table:
return self.model_cls.__table__ return self.model_cls.Meta.table
def build_select_expression(self) -> sqlalchemy.sql.select: def build_select_expression(self) -> sqlalchemy.sql.select:
qry = Query( qry = Query(
@ -148,8 +148,8 @@ class QuerySet:
new_kwargs = dict(**kwargs) new_kwargs = dict(**kwargs)
# Remove primary key when None to prevent not null constraint in postgresql. # Remove primary key when None to prevent not null constraint in postgresql.
pkname = self.model_cls.__pkname__ pkname = self.model_cls.Meta.pkname
pk = self.model_cls.__model_fields__[pkname] pk = self.model_cls.Meta.model_fields[pkname]
if ( if (
pkname in new_kwargs pkname in new_kwargs
and new_kwargs.get(pkname) is None and new_kwargs.get(pkname) is None
@ -163,11 +163,11 @@ class QuerySet:
if isinstance(new_kwargs.get(field), ormar.Model): if isinstance(new_kwargs.get(field), ormar.Model):
new_kwargs[field] = getattr( new_kwargs[field] = getattr(
new_kwargs.get(field), new_kwargs.get(field),
self.model_cls.__model_fields__[field].to.__pkname__, self.model_cls.Meta.model_fields[field].to.Meta.pkname,
) )
else: else:
new_kwargs[field] = new_kwargs.get(field).get( new_kwargs[field] = new_kwargs.get(field).get(
self.model_cls.__model_fields__[field].to.__pkname__ self.model_cls.Meta.model_fields[field].to.Meta.pkname
) )
# Build the insert expression. # Build the insert expression.
@ -176,5 +176,6 @@ class QuerySet:
# Execute the insert, and return a new model instance. # Execute the insert, and return a new model instance.
instance = self.model_cls(**kwargs) instance = self.model_cls(**kwargs)
instance.pk = await self.database.execute(expr) pk = await self.database.execute(expr)
setattr(instance, self.model_cls.Meta.pkname, pk)
return instance return instance

View File

@ -5,7 +5,7 @@ from random import choices
from typing import List, TYPE_CHECKING, Union from typing import List, TYPE_CHECKING, Union
from weakref import proxy from weakref import proxy
from ormar import ForeignKey from ormar.fields.foreign_key import ForeignKeyField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models import FakePydantic, Model from ormar.models import FakePydantic, Model
@ -21,14 +21,18 @@ class RelationshipManager:
self._aliases = dict() self._aliases = dict()
def add_relation_type( def add_relation_type(
self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str self,
relations_key: str,
reverse_key: str,
field: ForeignKeyField,
table_name: str,
) -> None: ) -> None:
if relations_key not in self._relations: if relations_key not in self._relations:
self._relations[relations_key] = {"type": "primary"} self._relations[relations_key] = {"type": "primary"}
self._aliases[f"{table_name}_{field.to.__tablename__}"] = get_table_alias() self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias()
if reverse_key not in self._relations: if reverse_key not in self._relations:
self._relations[reverse_key] = {"type": "reverse"} self._relations[reverse_key] = {"type": "reverse"}
self._aliases[f"{field.to.__tablename__}_{table_name}"] = get_table_alias() self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias()
def deregister(self, model: "FakePydantic") -> None: def deregister(self, model: "FakePydantic") -> None:
for rel_type in self._relations.keys(): for rel_type in self._relations.keys():

0
scripts/publish.sh Normal file → Executable file
View File

View File

@ -16,17 +16,19 @@ def time():
class Example(ormar.Model): class Example(ormar.Model):
__tablename__ = "example" class Meta:
__metadata__ = metadata tablename = "example"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
created = ormar.DateTime(default=datetime.datetime.now) name: ormar.String(max_length=200, default='aaa')
created_day = ormar.Date(default=datetime.date.today) created: ormar.DateTime(default=datetime.datetime.now)
created_time = ormar.Time(default=time) created_day: ormar.Date(default=datetime.date.today)
description = ormar.Text(nullable=True) created_time: ormar.Time(default=time)
value = ormar.Float(nullable=True) description: ormar.Text(nullable=True)
data = ormar.JSON(default={}) value: ormar.Float(nullable=True)
data: ormar.JSON(default={})
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")

View File

@ -13,22 +13,24 @@ metadata = sqlalchemy.MetaData()
class Category(ormar.Model): class Category(ormar.Model):
__tablename__ = "categories" class Meta:
__metadata__ = metadata tablename = "categories"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
class Item(ormar.Model): class Item(ormar.Model):
__tablename__ = "items" class Meta:
__metadata__ = metadata tablename = "items"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
category = ormar.ForeignKey(Category, nullable=True) category: ormar.ForeignKey(Category, nullable=True)
@app.post("/items/", response_model=Item) @app.post("/items/", response_model=Item)

View File

@ -1,9 +1,11 @@
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
from pydantic import ValidationError
import ormar import ormar
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
from ormar.fields.foreign_key import ForeignKeyField
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
@ -11,62 +13,68 @@ metadata = sqlalchemy.MetaData()
class Album(ormar.Model): class Album(ormar.Model):
__tablename__ = "album" class Meta:
__metadata__ = metadata tablename = "albums"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
class Track(ormar.Model): class Track(ormar.Model):
__tablename__ = "track" class Meta:
__metadata__ = metadata tablename = "tracks"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
album = ormar.ForeignKey(Album) album: ormar.ForeignKey(Album)
title = ormar.String(length=100) title: ormar.String(max_length=100)
position = ormar.Integer() position: ormar.Integer()
class Cover(ormar.Model): class Cover(ormar.Model):
__tablename__ = "covers" class Meta:
__metadata__ = metadata tablename = "covers"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
album = ormar.ForeignKey(Album, related_name="cover_pictures") album: ormar.ForeignKey(Album, related_name="cover_pictures")
title = ormar.String(length=100) title: ormar.String(max_length=100)
class Organisation(ormar.Model): class Organisation(ormar.Model):
__tablename__ = "org" class Meta:
__metadata__ = metadata tablename = "org"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
ident = ormar.String(length=100) ident: ormar.String(max_length=100)
class Team(ormar.Model): class Team(ormar.Model):
__tablename__ = "team" class Meta:
__metadata__ = metadata tablename = "teams"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
org = ormar.ForeignKey(Organisation) org: ormar.ForeignKey(Organisation)
name = ormar.String(length=100) name: ormar.String(max_length=100)
class Member(ormar.Model): class Member(ormar.Model):
__tablename__ = "member" class Meta:
__metadata__ = metadata tablename = "members"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
team = ormar.ForeignKey(Team) team: ormar.ForeignKey(Team)
email = ormar.String(length=100) email: ormar.String(max_length=100)
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
@ -113,6 +121,7 @@ async def test_model_crud():
track = await Track.objects.get(title="The Bird") track = await Track.objects.get(title="The Bird")
assert track.album.pk == album.pk assert track.album.pk == album.pk
assert isinstance(track.album, ormar.Model)
assert track.album.name is None assert track.album.name is None
await track.album.load() await track.album.load()
assert track.album.name == "Malibu" assert track.album.name == "Malibu"
@ -124,6 +133,8 @@ async def test_model_crud():
assert album1.pk == 1 assert album1.pk == 1
assert album1.tracks is None assert album1.tracks is None
await Track.objects.create(album={"id": track.album.pk}, title="The Bird2", position=4)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_select_related(): async def test_select_related():

View File

@ -1,4 +1,5 @@
import datetime import datetime
import decimal
import pydantic import pydantic
import pytest import pytest
@ -12,19 +13,21 @@ metadata = sqlalchemy.MetaData()
class ExampleModel(Model): class ExampleModel(Model):
__tablename__ = "example" class Meta:
__metadata__ = metadata tablename = "example"
test = fields.Integer(primary_key=True) metadata = metadata
test_string = fields.String(length=250)
test_text = fields.Text(default="") test: fields.Integer(primary_key=True)
test_bool = fields.Boolean(nullable=False) test_string: fields.String(max_length=250)
test_float = fields.Float() test_text: fields.Text(default="")
test_datetime = fields.DateTime(default=datetime.datetime.now) test_bool: fields.Boolean(nullable=False)
test_date = fields.Date(default=datetime.date.today) test_float: fields.Float() = None
test_time = fields.Time(default=datetime.time) test_datetime: fields.DateTime(default=datetime.datetime.now)
test_json = fields.JSON(default={}) test_date: fields.Date(default=datetime.date.today)
test_bigint = fields.BigInteger(default=0) test_time: fields.Time(default=datetime.time)
test_decimal = fields.Decimal(length=10, precision=2) test_json: fields.JSON(default={})
test_bigint: fields.BigInteger(default=0)
test_decimal: fields.Decimal(scale=10, precision=2)
fields_to_check = [ fields_to_check = [
@ -41,15 +44,17 @@ fields_to_check = [
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example2" class Meta:
__metadata__ = metadata tablename = "example2"
test = fields.Integer(primary_key=True) metadata = metadata
test_string = fields.String(length=250)
test: fields.Integer(primary_key=True)
test_string: fields.String(max_length=250)
@pytest.fixture() @pytest.fixture()
def example(): def example():
return ExampleModel(pk=1, test_string="test", test_bool=True) return ExampleModel(pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5))
def test_not_nullable_field_is_required(): def test_not_nullable_field_is_required():
@ -70,8 +75,18 @@ def test_model_attribute_access(example):
example.test = 12 example.test = 12
assert example.test == 12 assert example.test == 12
example._orm_saved = True
assert example._orm_saved
def test_model_attribute_json_access(example):
example.test_json = dict(aa=12)
assert example.test_json == dict(aa=12)
def test_non_existing_attr(example):
with pytest.raises(ValueError):
example.new_attr = 12 example.new_attr = 12
assert "new_attr" in example.__dict__
def test_primary_key_access_and_setting(example): def test_primary_key_access_and_setting(example):
@ -83,60 +98,65 @@ def test_primary_key_access_and_setting(example):
def test_pydantic_model_is_created(example): def test_pydantic_model_is_created(example):
assert issubclass(example.values.__class__, pydantic.BaseModel) assert issubclass(example.__class__, pydantic.BaseModel)
assert all([field in example.values.__fields__ for field in fields_to_check]) assert all([field in example.__fields__ for field in fields_to_check])
assert example.values.test == 1 assert example.test == 1
def test_sqlalchemy_table_is_created(example): def test_sqlalchemy_table_is_created(example):
assert issubclass(example.__table__.__class__, sqlalchemy.Table) assert issubclass(example.Meta.table.__class__, sqlalchemy.Table)
assert all([field in example.__table__.columns for field in fields_to_check]) assert all([field in example.Meta.table.columns for field in fields_to_check])
def test_no_pk_in_model_definition(): def test_no_pk_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example3" class Meta:
__metadata__ = metadata tablename = "example3"
test_string = fields.String(length=250) metadata = metadata
test_string: fields.String(max_length=250)
def test_two_pks_in_model_definition(): def test_two_pks_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example3" class Meta:
__metadata__ = metadata tablename = "example3"
id = fields.Integer(primary_key=True) metadata = metadata
test_string = fields.String(length=250, primary_key=True)
id: fields.Integer(primary_key=True)
test_string: fields.String(max_length=250, primary_key=True)
def test_setting_pk_column_as_pydantic_only_in_model_definition(): def test_setting_pk_column_as_pydantic_only_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example4" class Meta:
__metadata__ = metadata tablename = "example4"
test = fields.Integer(primary_key=True, pydantic_only=True) metadata = metadata
test: fields.Integer(primary_key=True, pydantic_only=True)
def test_decimal_error_in_model_definition(): def test_decimal_error_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example4" class Meta:
__metadata__ = metadata tablename = "example5"
test = fields.Decimal(primary_key=True) metadata = metadata
test: fields.Decimal(primary_key=True)
def test_string_error_in_model_definition(): def test_string_error_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example4" class Meta:
__metadata__ = metadata tablename = "example6"
test = fields.String(primary_key=True) metadata = metadata
test: fields.String(primary_key=True)
def test_json_conversion_in_model(): def test_json_conversion_in_model():

View File

@ -1,4 +1,5 @@
import databases import databases
import pydantic
import pytest import pytest
import sqlalchemy import sqlalchemy
@ -10,24 +11,36 @@ database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
class User(ormar.Model): class JsonSample(ormar.Model):
__tablename__ = "users" class Meta:
__metadata__ = metadata tablename = "jsons"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) test_json: ormar.JSON(nullable=True)
class User(ormar.Model):
class Meta:
tablename = "users"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100, default='')
class Product(ormar.Model): class Product(ormar.Model):
__tablename__ = "product" class Meta:
__metadata__ = metadata tablename = "product"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
rating = ormar.Integer(minimum=1, maximum=5) rating: ormar.Integer(minimum=1, maximum=5)
in_stock = ormar.Boolean(default=False) in_stock: ormar.Boolean(default=False)
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
@ -39,12 +52,12 @@ def create_test_database():
def test_model_class(): def test_model_class():
assert list(User.__model_fields__.keys()) == ["id", "name"] assert list(User.Meta.model_fields.keys()) == ["id", "name"]
assert isinstance(User.__model_fields__["id"], ormar.Integer) assert issubclass(User.Meta.model_fields["id"], pydantic.ConstrainedInt)
assert User.__model_fields__["id"].primary_key is True assert User.Meta.model_fields["id"].primary_key is True
assert isinstance(User.__model_fields__["name"], ormar.String) assert issubclass(User.Meta.model_fields["name"], pydantic.ConstrainedStr)
assert User.__model_fields__["name"].length == 100 assert User.Meta.model_fields["name"].max_length == 100
assert isinstance(User.__table__, sqlalchemy.Table) assert isinstance(User.Meta.table, sqlalchemy.Table)
def test_model_pk(): def test_model_pk():
@ -53,6 +66,18 @@ def test_model_pk():
assert user.id == 1 assert user.id == 1
@pytest.mark.asyncio
async def test_json_column():
async with database:
await JsonSample.objects.create(test_json=dict(aa=12))
await JsonSample.objects.create(test_json='{"aa": 12}')
items = await JsonSample.objects.all()
assert len(items) == 2
assert items[0].test_json == dict(aa=12)
assert items[1].test_json == dict(aa=12)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_crud(): async def test_model_crud():
async with database: async with database:

View File

@ -30,22 +30,24 @@ async def shutdown() -> None:
class Category(ormar.Model): class Category(ormar.Model):
__tablename__ = "categories" class Meta:
__metadata__ = metadata tablename = "categories"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
class Item(ormar.Model): class Item(ormar.Model):
__tablename__ = "items" class Meta:
__metadata__ = metadata tablename = "items"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
category = ormar.ForeignKey(Category, nullable=True) category: ormar.ForeignKey(Category, nullable=True)
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
@ -59,19 +61,19 @@ def create_test_database():
@app.get("/items/", response_model=List[Item]) @app.get("/items/", response_model=List[Item])
async def get_items(): async def get_items():
items = await Item.objects.select_related("category").all() items = await Item.objects.select_related("category").all()
return [item.dict() for item in items] return items
@app.post("/items/", response_model=Item) @app.post("/items/", response_model=Item)
async def create_item(item: Item): async def create_item(item: Item):
item = await Item.objects.create(**item.dict()) await item.save()
return item.dict() return item
@app.post("/categories/", response_model=Category) @app.post("/categories/", response_model=Category)
async def create_category(category: Category): async def create_category(category: Category):
category = await Category.objects.create(**category.dict()) await category.save()
return category.dict() return category
@app.put("/items/{item_id}") @app.put("/items/{item_id}")

View File

@ -12,53 +12,58 @@ metadata = sqlalchemy.MetaData()
class Department(ormar.Model): class Department(ormar.Model):
__tablename__ = "departments" class Meta:
__metadata__ = metadata tablename = "departments"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True, autoincrement=False) id: ormar.Integer(primary_key=True, autoincrement=False)
name = ormar.String(length=100) name: ormar.String(max_length=100)
class SchoolClass(ormar.Model): class SchoolClass(ormar.Model):
__tablename__ = "schoolclasses" class Meta:
__metadata__ = metadata tablename = "schoolclasses"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
department = ormar.ForeignKey(Department, nullable=False) department: ormar.ForeignKey(Department, nullable=False)
class Category(ormar.Model): class Category(ormar.Model):
__tablename__ = "categories" class Meta:
__metadata__ = metadata tablename = "categories"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
class Student(ormar.Model): class Student(ormar.Model):
__tablename__ = "students" class Meta:
__metadata__ = metadata tablename = "students"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
schoolclass = ormar.ForeignKey(SchoolClass) schoolclass: ormar.ForeignKey(SchoolClass)
category = ormar.ForeignKey(Category, nullable=True) category: ormar.ForeignKey(Category, nullable=True)
class Teacher(ormar.Model): class Teacher(ormar.Model):
__tablename__ = "teachers" class Meta:
__metadata__ = metadata tablename = "teachers"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
schoolclass = ormar.ForeignKey(SchoolClass) schoolclass: ormar.ForeignKey(SchoolClass)
category = ormar.ForeignKey(Category, nullable=True) category: ormar.ForeignKey(Category, nullable=True)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -71,6 +76,7 @@ def event_loop():
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
async def create_test_database(): async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL) engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine) metadata.create_all(engine)
department = await Department.objects.create(id=1, name="Math Department") department = await Department.objects.create(id=1, name="Math Department")
class1 = await SchoolClass.objects.create(name="Math", department=department) class1 = await SchoolClass.objects.create(name="Math", department=department)