mostly working

This commit is contained in:
collerek
2020-08-16 22:27:39 +02:00
parent b69ad226e6
commit a39179bc64
18 changed files with 988 additions and 536 deletions

BIN
.coverage

Binary file not shown.

View File

@ -1,6 +1,8 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING from typing import Any, Dict, List, Optional, TYPE_CHECKING
import pydantic
import sqlalchemy import sqlalchemy
from pydantic import Field
from ormar import ModelDefinitionError # noqa I101 from ormar import ModelDefinitionError # noqa I101
@ -8,75 +10,71 @@ if TYPE_CHECKING: # pragma no cover
from ormar.models import Model from ormar.models import Model
def prepare_validator(type_):
def validate_model_field(value):
return isinstance(value, type_)
return validate_model_field
class BaseField: class BaseField:
__type__ = None __type__ = None
def __init__(self, **kwargs: Any) -> None: column_type: sqlalchemy.Column
self.name = None constraints: List = []
self._populate_from_kwargs(kwargs)
def _populate_from_kwargs(self, kwargs: Dict) -> None: primary_key: bool
self.primary_key = kwargs.pop("primary_key", False) autoincrement: bool
self.autoincrement = kwargs.pop( nullable: bool
"autoincrement", self.primary_key and self.__type__ == int index: bool
) unique: bool
pydantic_only: bool
self.nullable = kwargs.pop("nullable", not self.primary_key) default: Any
self.default = kwargs.pop("default", None) server_default: Any
self.server_default = kwargs.pop("server_default", None)
self.index = kwargs.pop("index", None) @classmethod
self.unique = kwargs.pop("unique", None) def is_required(cls) -> bool:
self.pydantic_only = kwargs.pop("pydantic_only", False)
if self.pydantic_only and self.primary_key:
raise ModelDefinitionError("Primary key column cannot be pydantic only.")
@property
def is_required(self) -> bool:
return ( return (
not self.nullable and not self.has_default and not self.is_auto_primary_key not cls.nullable and not cls.has_default() and not cls.is_auto_primary_key()
) )
@property @classmethod
def default_value(self) -> Any: def default_value(cls):
default = self.default if cls.is_auto_primary_key():
return default() if callable(default) else default return Field(default=None)
if cls.has_default():
default = cls.default if cls.default is not None else cls.server_default
if callable(default):
return Field(default_factory=default)
else:
return Field(default=default)
return None
@property @classmethod
def has_default(self) -> bool: def has_default(cls):
return self.default is not None or self.server_default is not None return cls.default is not None or cls.server_default is not None
@property @classmethod
def is_auto_primary_key(self) -> bool: def is_auto_primary_key(cls) -> bool:
if self.primary_key: if cls.primary_key:
return self.autoincrement return cls.autoincrement
return False return False
def get_column(self, name: str = None) -> sqlalchemy.Column: @classmethod
self.name = name def get_column(cls, name: str) -> sqlalchemy.Column:
constraints = self.get_constraints()
return sqlalchemy.Column( return sqlalchemy.Column(
self.name, name,
self.get_column_type(), cls.column_type,
*constraints, *cls.constraints,
primary_key=self.primary_key, primary_key=cls.primary_key,
autoincrement=self.autoincrement, nullable=cls.nullable and not cls.primary_key,
nullable=self.nullable, index=cls.index,
index=self.index, unique=cls.unique,
unique=self.unique, default=cls.default,
default=self.default, server_default=cls.server_default,
server_default=self.server_default,
) )
def get_column_type(self) -> sqlalchemy.types.TypeEngine: @classmethod
raise NotImplementedError() # pragma: no cover def expand_relationship(cls, value: Any, child: "Model") -> Any:
def get_constraints(self) -> Optional[List]:
return []
def expand_relationship(self, value: Any, child: "Model") -> Any:
return value return value
def __repr__(self): # pragma no cover
return str(self.__dict__)

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union from typing import Any, List, Optional, TYPE_CHECKING, Type, Union, Callable
import sqlalchemy import sqlalchemy
from pydantic import BaseModel from pydantic import BaseModel
@ -13,87 +13,115 @@ if TYPE_CHECKING: # pragma no cover
def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
init_dict = { init_dict = {
**{fk.__pkname__: pk or -1}, **{fk.Meta.pkname: pk or -1,
'__pk_only__': True},
**{ **{
k: create_dummy_instance(v.to) k: create_dummy_instance(v.to)
for k, v in fk.__model_fields__.items() for k, v in fk.Meta.model_fields.items()
if isinstance(v, ForeignKey) and not v.nullable and not v.virtual if isinstance(v, ForeignKeyField) and not v.nullable and not v.virtual
}, },
} }
return fk(**init_dict) return fk(**init_dict)
class ForeignKey(BaseField): def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = True,
def __init__( related_name: str = None,
self, virtual: bool = False,
to: Type["Model"], ) -> Type[object]:
name: str = None, fk_string = to.Meta.tablename + "." + to.Meta.pkname
related_name: str = None, to_field = to.__fields__[to.Meta.pkname]
nullable: bool = True, namespace = dict(
virtual: bool = False, to=to,
) -> None: name=name,
super().__init__(nullable=nullable, name=name) nullable=nullable,
self.virtual = virtual constraints=[sqlalchemy.schema.ForeignKey(fk_string)],
self.related_name = related_name unique=unique,
self.to = to column_type=to_field.type_.column_type,
related_name=related_name,
virtual=virtual,
primary_key=False,
index=False,
pydantic_only=False,
default=None,
server_default=None
)
return type("ForeignKey", (ForeignKeyField, BaseField), namespace)
class ForeignKeyField(BaseField):
to: Type["Model"]
related_name: str
virtual: bool
@classmethod
def __get_validators__(cls) -> Callable:
yield cls.validate
@classmethod
def validate(cls, v: Any) -> Any:
return v
@property @property
def __type__(self) -> Type[BaseModel]: def __type__(self) -> Type[BaseModel]:
return self.to.__pydantic_model__ return self.to.__pydantic_model__
def get_constraints(self) -> List[sqlalchemy.schema.ForeignKey]: @classmethod
fk_string = self.to.__tablename__ + "." + self.to.__pkname__ def get_column_type(cls) -> sqlalchemy.Column:
return [sqlalchemy.schema.ForeignKey(fk_string)] to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname]
return to_column.column_type
def get_column_type(self) -> sqlalchemy.Column:
to_column = self.to.__model_fields__[self.to.__pkname__]
return to_column.get_column_type()
@classmethod
def _extract_model_from_sequence( def _extract_model_from_sequence(
self, value: List, child: "Model" cls, value: List, child: "Model"
) -> Union["Model", List["Model"]]: ) -> Union["Model", List["Model"]]:
return [self.expand_relationship(val, child) for val in value] return [cls.expand_relationship(val, child) for val in value]
def _register_existing_model(self, value: "Model", child: "Model") -> "Model": @classmethod
self.register_relation(value, child) def _register_existing_model(cls, value: "Model", child: "Model") -> "Model":
cls.register_relation(value, child)
return value return value
def _construct_model_from_dict(self, value: dict, child: "Model") -> "Model": @classmethod
model = self.to(**value) def _construct_model_from_dict(cls, value: dict, child: "Model") -> "Model":
self.register_relation(model, child) model = cls.to(**value)
cls.register_relation(model, child)
return model return model
def _construct_model_from_pk(self, value: Any, child: "Model") -> "Model": @classmethod
if not isinstance(value, self.to.pk_type()): def _construct_model_from_pk(cls, value: Any, child: "Model") -> "Model":
if not isinstance(value, cls.to.pk_type()):
raise RelationshipInstanceError( raise RelationshipInstanceError(
f"Relationship error - ForeignKey {self.to.__name__} " f"Relationship error - ForeignKey {cls.to.__name__} "
f"is of type {self.to.pk_type()} " f"is of type {cls.to.pk_type()} "
f"while {type(value)} passed as a parameter." f"while {type(value)} passed as a parameter."
) )
model = create_dummy_instance(fk=self.to, pk=value) model = create_dummy_instance(fk=cls.to, pk=value)
self.register_relation(model, child) cls.register_relation(model, child)
return model return model
def register_relation(self, model: "Model", child: "Model") -> None: @classmethod
child_model_name = self.related_name or child.get_name() def register_relation(cls, model: "Model", child: "Model") -> None:
model._orm_relationship_manager.add_relation( child_model_name = cls.related_name or child.get_name()
model, child, child_model_name, virtual=self.virtual model.Meta._orm_relationship_manager.add_relation(
model, child, child_model_name, virtual=cls.virtual
) )
@classmethod
def expand_relationship( def expand_relationship(
self, value: Any, child: "Model" cls, value: Any, child: "Model"
) -> Optional[Union["Model", List["Model"]]]: ) -> Optional[Union["Model", List["Model"]]]:
if value is None: if value is None:
return None return None
constructors = { constructors = {
f"{self.to.__name__}": self._register_existing_model, f"{cls.to.__name__}": cls._register_existing_model,
"dict": self._construct_model_from_dict, "dict": cls._construct_model_from_dict,
"list": self._extract_model_from_sequence, "list": cls._extract_model_from_sequence,
} }
model = constructors.get( model = constructors.get(
value.__class__.__name__, self._construct_model_from_pk value.__class__.__name__, cls._construct_model_from_pk
)(value, child) )(value, child)
return model return model

View File

@ -1,87 +1,373 @@
import datetime import datetime
import decimal import decimal
import re
from typing import Type, Any, Optional
import pydantic
import sqlalchemy import sqlalchemy
from pydantic import Json from pydantic import Json
from ormar import ModelDefinitionError
from ormar.fields.base import BaseField # noqa I101 from ormar.fields.base import BaseField # noqa I101
from ormar.fields.decorators import RequiredParams
@RequiredParams("length") def is_field_nullable(nullable: Optional[bool], default: Any, server_default: Any) -> bool:
class String(BaseField): if nullable is None:
__type__ = str return default is not None or server_default is not None
return False
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.String(self.length)
class Integer(BaseField): def String(
__type__ = int *,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
allow_blank: bool = False,
strip_whitespace: bool = False,
min_length: int = None,
max_length: int = None,
curtail_length: int = None,
regex: str = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
) -> Type[str]:
if max_length is None or max_length <= 0:
raise ModelDefinitionError(f'Parameter max_length is required for field String')
def get_column_type(self) -> sqlalchemy.Column: namespace = dict(
return sqlalchemy.Integer() __type__=str,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
allow_blank=allow_blank,
strip_whitespace=strip_whitespace,
min_length=min_length,
max_length=max_length,
curtail_length=curtail_length,
regex=regex and re.compile(regex),
column_type=sqlalchemy.String(length=max_length),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
)
return type("String", (pydantic.ConstrainedStr, BaseField), namespace)
class Text(BaseField): def Integer(
__type__ = str *,
name: str = None,
def get_column_type(self) -> sqlalchemy.Column: primary_key: bool = False,
return sqlalchemy.Text() autoincrement: bool = None,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
) -> Type[int]:
namespace = dict(
__type__=int,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
ge=minimum,
le=maximum,
multiple_of=multiple_of,
column_type=sqlalchemy.Integer(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=autoincrement if autoincrement is not None else primary_key
)
return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace)
class Float(BaseField): def Text(
__type__ = float *,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
allow_blank: bool = False,
strip_whitespace: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
) -> Type[str]:
namespace = dict(
__type__=str,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
allow_blank=allow_blank,
strip_whitespace=strip_whitespace,
column_type=sqlalchemy.Text(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
)
def get_column_type(self) -> sqlalchemy.Column: return type("Text", (pydantic.ConstrainedStr, BaseField), namespace)
return sqlalchemy.Float()
class Boolean(BaseField): def Float(
__type__ = bool *,
name: str = None,
def get_column_type(self) -> sqlalchemy.Column: primary_key: bool = False,
return sqlalchemy.Boolean() nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
) -> Type[int]:
namespace = dict(
__type__=float,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
ge=minimum,
le=maximum,
multiple_of=multiple_of,
column_type=sqlalchemy.Float(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
)
return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace)
class DateTime(BaseField): def Boolean(
__type__ = datetime.datetime *,
name: str = None,
def get_column_type(self) -> sqlalchemy.Column: primary_key: bool = False,
return sqlalchemy.DateTime() nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
) -> Type[bool]:
namespace = dict(
__type__=bool,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
column_type=sqlalchemy.Boolean(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
)
return type("Boolean", (int, BaseField), namespace)
class Date(BaseField): def DateTime(
__type__ = datetime.date *,
name: str = None,
def get_column_type(self) -> sqlalchemy.Column: primary_key: bool = False,
return sqlalchemy.Date() nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
) -> Type[datetime.datetime]:
namespace = dict(
__type__=datetime.datetime,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
column_type=sqlalchemy.DateTime(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
)
return type("DateTime", (datetime.datetime, BaseField), namespace)
class Time(BaseField): def Date(
__type__ = datetime.time *,
name: str = None,
def get_column_type(self) -> sqlalchemy.Column: primary_key: bool = False,
return sqlalchemy.Time() nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
) -> Type[datetime.date]:
namespace = dict(
__type__=datetime.date,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
column_type=sqlalchemy.Date(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
)
return type("Date", (datetime.date, BaseField), namespace)
class JSON(BaseField): def Time(
__type__ = Json *,
name: str = None,
def get_column_type(self) -> sqlalchemy.Column: primary_key: bool = False,
return sqlalchemy.JSON() nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
) -> Type[datetime.time]:
namespace = dict(
__type__=datetime.time,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
column_type=sqlalchemy.Time(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
)
return type("Time", (datetime.time, BaseField), namespace)
class BigInteger(BaseField): def JSON(
__type__ = int *,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
) -> Type[Json]:
namespace = dict(
__type__=pydantic.Json,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
column_type=sqlalchemy.JSON(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
)
def get_column_type(self) -> sqlalchemy.Column: return type("JSON", (pydantic.Json, BaseField), namespace)
return sqlalchemy.BigInteger()
@RequiredParams("length", "precision") def BigInteger(
class Decimal(BaseField): *,
__type__ = decimal.Decimal name: str = None,
primary_key: bool = False,
autoincrement: bool = None,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
) -> Type[int]:
namespace = dict(
__type__=int,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
ge=minimum,
le=maximum,
multiple_of=multiple_of,
column_type=sqlalchemy.BigInteger(),
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=autoincrement if autoincrement is not None else primary_key
)
return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace)
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.DECIMAL(self.length, self.precision) def Decimal(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
precision: int = None,
scale: int = None,
max_digits: int = None,
decimal_places: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
):
if precision is None or precision < 0 or scale is None or scale < 0:
raise ModelDefinitionError(f'Parameters scale and precision are required for field Decimal')
namespace = dict(
__type__=decimal.Decimal,
name=name,
primary_key=primary_key,
nullable=is_field_nullable(nullable, default, server_default),
index=index,
unique=unique,
ge=minimum,
le=maximum,
multiple_of=multiple_of,
column_type=sqlalchemy.types.DECIMAL(precision=precision, scale=scale),
precision=precision,
scale=scale,
max_digits=max_digits,
decimal_places=decimal_places,
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
)
return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace)

View File

@ -11,7 +11,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Type, Type,
TypeVar, TypeVar,
Union, Union, AbstractSet, Mapping,
) )
import databases import databases
@ -20,19 +20,27 @@ import sqlalchemy
from pydantic import BaseModel from pydantic import BaseModel
import ormar # noqa I100 import ormar # noqa I100
from ormar import ForeignKey
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.models.metaclass import ModelMetaclass from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMetaclass, ModelMeta
from ormar.relations import RelationshipManager from ormar.relations import RelationshipManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models.model import Model from ormar.models.model import Model
IntStr = Union[int, str]
DictStrAny = Dict[str, Any]
AbstractSetIntStr = AbstractSet[IntStr]
MappingIntStrAny = Mapping[IntStr, Any]
class FakePydantic(list, metaclass=ModelMetaclass):
class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
# FakePydantic inherits from list in order to be treated as # FakePydantic inherits from list in order to be treated as
# request.Body parameter in fastapi routes, # request.Body parameter in fastapi routes,
# inheriting from pydantic.BaseModel causes metaclass conflicts # inheriting from pydantic.BaseModel causes metaclass conflicts
__abstract__ = True __slots__ = ('_orm_id', '_orm_saved')
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, TypeVar[BaseField]] __model_fields__: Dict[str, TypeVar[BaseField]]
__table__: sqlalchemy.Table __table__: sqlalchemy.Table
@ -43,62 +51,88 @@ class FakePydantic(list, metaclass=ModelMetaclass):
__metadata__: sqlalchemy.MetaData __metadata__: sqlalchemy.MetaData
__database__: databases.Database __database__: databases.Database
_orm_relationship_manager: RelationshipManager _orm_relationship_manager: RelationshipManager
Meta: ModelMeta
# noinspection PyMissingConstructor
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
self._orm_id: str = uuid.uuid4().hex
self._orm_saved: bool = False
self.values: Optional[BaseModel] = None
object.__setattr__(self, "_orm_id", uuid.uuid4().hex)
object.__setattr__(self, "_orm_saved", False)
pk_only = kwargs.pop("__pk_only__", False)
if "pk" in kwargs: if "pk" in kwargs:
kwargs[self.__pkname__] = kwargs.pop("pk") kwargs[self.Meta.pkname] = kwargs.pop("pk")
kwargs = { kwargs = {
k: self.__model_fields__[k].expand_relationship(v, self) k: self.Meta.model_fields[k].expand_relationship(v, self)
for k, v in kwargs.items() for k, v in kwargs.items()
} }
self.values = self.__pydantic_model__(**kwargs)
values, fields_set, validation_error = pydantic.validate_model(
self, kwargs
)
if validation_error and not pk_only:
raise validation_error
object.__setattr__(self, '__dict__', values)
object.__setattr__(self, '__fields_set__', fields_set)
# super().__init__(**kwargs)
# self.values = self.__pydantic_model__(**kwargs)
def __del__(self) -> None: def __del__(self) -> None:
self._orm_relationship_manager.deregister(self) self.Meta._orm_relationship_manager.deregister(self)
def __setattr__(self, key: str, value: Any) -> None: def __setattr__(self, name, value):
if key in self.__fields__: if name in self.__slots__:
value = self._convert_json(key, value, op="dumps") object.__setattr__(self, name, value)
value = self.__model_fields__[key].expand_relationship(value, self) elif name == 'pk':
object.__setattr__(self, self.Meta.pkname, value)
relation_key = self.get_name(title=True) + "_" + name
if self.Meta._orm_relationship_manager.contains(relation_key, self):
self.Meta.model_fields[name].expand_relationship(value, self)
return
super().__setattr__(name, value)
relation_key = self.get_name(title=True) + "_" + key def __getattr__(self, item):
if not self._orm_relationship_manager.contains(relation_key, self): relation_key = self.get_name(title=True) + "_" + item
setattr(self.values, key, value) if self.Meta._orm_relationship_manager.contains(relation_key, self):
else: return self.Meta._orm_relationship_manager.get(relation_key, self)
super().__setattr__(key, value)
def __getattribute__(self, key: str) -> Any: # def __setattr__(self, key: str, value: Any) -> None:
if key != "__fields__" and key in self.__fields__: # if key in ('_orm_id', '_orm_relationship_manager', '_orm_saved', 'objects', '__model_fields__'):
relation_key = self.get_name(title=True) + "_" + key # return setattr(self, key, value)
if self._orm_relationship_manager.contains(relation_key, self): # # elif key in self._extract_related_names():
return self._orm_relationship_manager.get(relation_key, self) # # value = self._convert_json(key, value, op="dumps")
# # value = self.Meta.model_fields[key].expand_relationship(value, self)
# # relation_key = self.get_name(title=True) + "_" + key
# # if not self.Meta._orm_relationship_manager.contains(relation_key, self):
# # setattr(self.values, key, value)
# else:
# super().__setattr__(key, value)
item = getattr(self.values, key, None) # def __getattribute__(self, key: str) -> Any:
item = self._convert_json(key, item, op="loads") # if key != 'Meta' and key in self.Meta.model_fields:
return item # relation_key = self.get_name(title=True) + "_" + key
return super().__getattribute__(key) # if self.Meta._orm_relationship_manager.contains(relation_key, self):
# return self.Meta._orm_relationship_manager.get(relation_key, self)
def __eq__(self, other: "Model") -> bool: # item = getattr(self.__fields__, key, None)
return self.values.dict() == other.values.dict() # item = self._convert_json(key, item, op="loads")
# return item
# return super().__getattribute__(key)
def __same__(self, other: "Model") -> bool: def __same__(self, other: "Model") -> bool:
if self.__class__ != other.__class__: # pragma no cover if self.__class__ != other.__class__: # pragma no cover
return False return False
return self._orm_id == other._orm_id or ( return self._orm_id == other._orm_id or (
self.values is not None and other.values is not None and self.pk == other.pk self.__dict__ is not None and other.__dict__ is not None and self.pk == other.pk
) )
def __repr__(self) -> str: # pragma no cover # def __repr__(self) -> str: # pragma no cover
return self.values.__repr__() # return self.values.__repr__()
@classmethod # @classmethod
def __get_validators__(cls) -> Callable: # pragma no cover # def __get_validators__(cls) -> Callable: # pragma no cover
yield cls.__pydantic_model__.validate # yield cls.__pydantic_model__.validate
@classmethod @classmethod
def get_name(cls, title: bool = False, lower: bool = True) -> str: def get_name(cls, title: bool = False, lower: bool = True) -> str:
@ -109,25 +143,57 @@ class FakePydantic(list, metaclass=ModelMetaclass):
name = name.title() name = name.title()
return name return name
@property
def pk(self) -> Any:
return getattr(self, self.Meta.pkname)
@pk.setter
def pk(self, value: Any) -> None:
setattr(self, self.Meta.pkname, value)
@property @property
def pk_column(self) -> sqlalchemy.Column: def pk_column(self) -> sqlalchemy.Column:
return self.__table__.primary_key.columns.values()[0] return self.Meta.table.primary_key.columns.values()[0]
@classmethod @classmethod
def pk_type(cls) -> Any: def pk_type(cls) -> Any:
return cls.__model_fields__[cls.__pkname__].__type__ return cls.Meta.model_fields[cls.Meta.pkname].__type__
def dict(self, nested=False) -> Dict: # noqa: A003 def dict(
dict_instance = self.values.dict() self,
*,
include: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None,
exclude: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None,
by_alias: bool = False,
skip_defaults: bool = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
nested: bool = False
) -> 'DictStrAny': # noqa: A003
print('callin super', self.__class__)
print('to exclude', self._exclude_related_names_not_required(nested))
dict_instance = super().dict(include=include,
exclude=self._exclude_related_names_not_required(nested),
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none)
print('after super')
for field in self._extract_related_names(): for field in self._extract_related_names():
print(self.__class__, field, nested)
nested_model = getattr(self, field) nested_model = getattr(self, field)
if self.__model_fields__[field].virtual and nested:
if self.Meta.model_fields[field].virtual and nested:
continue continue
if isinstance(nested_model, list) and not isinstance( if isinstance(nested_model, list) and not isinstance(
nested_model, ormar.Model nested_model, ormar.Model
): ):
print('nested list')
dict_instance[field] = [x.dict(nested=True) for x in nested_model] dict_instance[field] = [x.dict(nested=True) for x in nested_model]
else: else:
print('instance')
dict_instance[field] = ( dict_instance[field] = (
nested_model.dict(nested=True) if nested_model is not None else {} nested_model.dict(nested=True) if nested_model is not None else {}
) )
@ -155,7 +221,7 @@ class FakePydantic(list, metaclass=ModelMetaclass):
return value return value
def _is_conversion_to_json_needed(self, column_name: str) -> bool: def _is_conversion_to_json_needed(self, column_name: str) -> bool:
return self.__model_fields__.get(column_name).__type__ == pydantic.Json return self.Meta.model_fields.get(column_name).__type__ == pydantic.Json
def _extract_own_model_fields(self) -> Dict: def _extract_own_model_fields(self) -> Dict:
related_names = self._extract_related_names() related_names = self._extract_related_names()
@ -165,22 +231,32 @@ class FakePydantic(list, metaclass=ModelMetaclass):
@classmethod @classmethod
def _extract_related_names(cls) -> Set: def _extract_related_names(cls) -> Set:
related_names = set() related_names = set()
for name, field in cls.__fields__.items(): for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field.type_) and issubclass( if inspect.isclass(field) and issubclass(
field.type_, pydantic.BaseModel field, ForeignKeyField
): ):
related_names.add(name) related_names.add(name)
return related_names return related_names
@classmethod
def _exclude_related_names_not_required(cls, nested:bool=False) -> Set:
if nested:
return cls._extract_related_names()
related_names = set()
for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field) and issubclass(field, ForeignKeyField) and field.nullable:
related_names.add(name)
return related_names
def _extract_model_db_fields(self) -> Dict: def _extract_model_db_fields(self) -> Dict:
self_fields = self._extract_own_model_fields() self_fields = self._extract_own_model_fields()
self_fields = { self_fields = {
k: v for k, v in self_fields.items() if k in self.__table__.columns k: v for k, v in self_fields.items() if k in self.Meta.table.columns
} }
for field in self._extract_related_names(): for field in self._extract_related_names():
if getattr(self, field) is not None: if getattr(self, field) is not None:
self_fields[field] = getattr( self_fields[field] = getattr(
getattr(self, field), self.__model_fields__[field].to.__pkname__ getattr(self, field), self.Meta.model_fields[field].to.Meta.pkname
) )
return self_fields return self_fields
@ -196,9 +272,9 @@ class FakePydantic(list, metaclass=ModelMetaclass):
@classmethod @classmethod
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
for field in one.__model_fields__.keys(): for field in one.Meta.model_fields.keys():
if isinstance(getattr(one, field), list) and not isinstance( if isinstance(getattr(one, field), list) and not isinstance(
getattr(one, field), ormar.Model getattr(one, field), ormar.Model
): ):
setattr(other, field, getattr(one, field) + getattr(other, field)) setattr(other, field, getattr(one, field) + getattr(other, field))
elif isinstance(getattr(one, field), ormar.Model): elif isinstance(getattr(one, field), ormar.Model):

View File

@ -1,12 +1,15 @@
import copy from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type
import databases
import pydantic
import sqlalchemy import sqlalchemy
from pydantic import BaseConfig, create_model from pydantic import BaseConfig, create_model, Extra
from pydantic.fields import ModelField from pydantic.fields import ModelField, FieldInfo
from ormar import ForeignKey, ModelDefinitionError # noqa I100 from ormar import ForeignKey, ModelDefinitionError # noqa I100
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.queryset import QuerySet
from ormar.relations import RelationshipManager from ormar.relations import RelationshipManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -15,6 +18,17 @@ if TYPE_CHECKING: # pragma no cover
relationship_manager = RelationshipManager() relationship_manager = RelationshipManager()
class ModelMeta:
tablename: str
table: sqlalchemy.Table
metadata: sqlalchemy.MetaData
database: databases.Database
columns: List[sqlalchemy.Column]
pkname: str
model_fields: Dict[str, Union[BaseField, ForeignKey]]
_orm_relationship_manager: RelationshipManager
def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]:
pydantic_fields = { pydantic_fields = {
field_name: ( field_name: (
@ -29,9 +43,9 @@ def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple
def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None:
child_relation_name = ( child_relation_name = (
field.to.get_name(title=True) field.to.get_name(title=True)
+ "_" + "_"
+ (field.related_name or (name.lower() + "s")) + (field.related_name or (name.lower() + "s"))
) )
reverse_name = child_relation_name reverse_name = child_relation_name
relation_name = name.lower().title() + "_" + field.to.get_name() relation_name = name.lower().title() + "_" + field.to.get_name()
@ -41,104 +55,125 @@ def register_relation_on_build(table_name: str, field: ForeignKey, name: str) ->
def expand_reverse_relationships(model: Type["Model"]) -> None: def expand_reverse_relationships(model: Type["Model"]) -> None:
for model_field in model.__model_fields__.values(): for model_field in model.Meta.model_fields.values():
if isinstance(model_field, ForeignKey): if issubclass(model_field, ForeignKeyField):
child_model_name = model_field.related_name or model.get_name() + "s" child_model_name = model_field.related_name or model.get_name() + "s"
parent_model = model_field.to parent_model = model_field.to
child = model child = model
if ( if (
child_model_name not in parent_model.__fields__ child_model_name not in parent_model.__fields__
and child.get_name() not in parent_model.__fields__ and child.get_name() not in parent_model.__fields__
): ):
register_reverse_model_fields(parent_model, child, child_model_name) register_reverse_model_fields(parent_model, child, child_model_name)
def register_reverse_model_fields( def register_reverse_model_fields(
model: Type["Model"], child: Type["Model"], child_model_name: str model: Type["Model"], child: Type["Model"], child_model_name: str
) -> None: ) -> None:
model.__fields__[child_model_name] = ModelField( # model.__fields__[child_model_name] = ModelField(
name=child_model_name, # name=child_model_name,
type_=Optional[child.__pydantic_model__], # type_=Optional[Union[List[child], child]],
model_config=child.__pydantic_model__.__config__, # model_config=child.__config__,
class_validators=child.__pydantic_model__.__validators__, # class_validators=child.__validators__,
) # )
model.__model_fields__[child_model_name] = ForeignKey( model.Meta.model_fields[child_model_name] = ForeignKey(
child, name=child_model_name, virtual=True child, name=child_model_name, virtual=True
) )
def sqlalchemy_columns_from_model_fields( def sqlalchemy_columns_from_model_fields(
name: str, object_dict: Dict, table_name: str name: str, object_dict: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
columns = [] columns = []
pkname = None pkname = None
model_fields = { model_fields = {
field_name: field field_name: field
for field_name, field in object_dict.items() for field_name, field in object_dict['__annotations__'].items()
if isinstance(field, BaseField) if issubclass(field, BaseField)
} }
for field_name, field in model_fields.items(): for field_name, field in model_fields.items():
if field.primary_key: if field.primary_key:
if pkname is not None: if pkname is not None:
raise ModelDefinitionError("Only one primary key column is allowed.") raise ModelDefinitionError("Only one primary key column is allowed.")
if field.pydantic_only:
raise ModelDefinitionError('Primary key column cannot be pydantic only')
pkname = field_name pkname = field_name
if not field.pydantic_only: if not field.pydantic_only:
columns.append(field.get_column(field_name)) columns.append(field.get_column(field_name))
if isinstance(field, ForeignKey): if issubclass(field, ForeignKeyField):
register_relation_on_build(table_name, field, name) register_relation_on_build(table_name, field, name)
return pkname, columns, model_fields return pkname, columns, model_fields
def populate_pydantic_default_values(attrs: Dict) -> Dict:
for field, type_ in attrs['__annotations__'].items():
if issubclass(type_, BaseField):
if type_.name is None:
type_.name = field
def_value = type_.default_value()
curr_def_value = attrs.get(field, 'NONE')
if curr_def_value == 'NONE' and isinstance(def_value, FieldInfo):
attrs[field] = def_value
elif curr_def_value == 'NONE' and type_.nullable:
attrs[field] = FieldInfo(default=None)
return attrs
def get_pydantic_base_orm_config() -> Type[BaseConfig]: def get_pydantic_base_orm_config() -> Type[BaseConfig]:
class Config(BaseConfig): class Config(BaseConfig):
orm_mode = True orm_mode = True
arbitrary_types_allowed = True
# extra = Extra.allow
return Config return Config
class ModelMetaclass(type): class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
attrs['Config'] = get_pydantic_base_orm_config()
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
) )
if attrs.get("__abstract__"): if hasattr(new_model, 'Meta'):
return new_model
tablename = attrs.get("__tablename__", name.lower() + "s") if attrs.get("__abstract__"):
attrs["__tablename__"] = tablename return new_model
metadata = attrs["__metadata__"]
# sqlalchemy table creation attrs = populate_pydantic_default_values(attrs)
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
name, attrs, tablename
)
attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns)
attrs["__columns__"] = columns
attrs["__pkname__"] = pkname
if not pkname: tablename = name.lower() + "s"
raise ModelDefinitionError("Table has to have a primary key.") new_model.Meta.tablename = new_model.Meta.tablename or tablename
# pydantic model creation # sqlalchemy table creation
pydantic_fields = parse_pydantic_field_from_model_fields(attrs) pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
pydantic_model = create_model( name, attrs, new_model.Meta.tablename
name, __config__=get_pydantic_base_orm_config(), **pydantic_fields )
) new_model.Meta.table = sqlalchemy.Table(new_model.Meta.tablename, new_model.Meta.metadata, *columns)
attrs["__pydantic_fields__"] = pydantic_fields new_model.Meta.columns = columns
attrs["__pydantic_model__"] = pydantic_model new_model.Meta.pkname = pkname
attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__)
attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__)
attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__)
attrs["__model_fields__"] = model_fields if not pkname:
attrs["_orm_relationship_manager"] = relationship_manager breakpoint()
raise ModelDefinitionError("Table has to have a primary key.")
new_model = super().__new__( # type: ignore # pydantic model creation
mcs, name, bases, attrs new_model.Meta.pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
) new_model.Meta.pydantic_model = create_model(
name, __config__=get_pydantic_base_orm_config(), **new_model.Meta.pydantic_fields
)
expand_reverse_relationships(new_model) new_model.Meta.model_fields = model_fields
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs
)
expand_reverse_relationships(new_model)
new_model.Meta._orm_relationship_manager = relationship_manager
new_model.objects = QuerySet(new_model)
# breakpoint()
return new_model return new_model

View File

@ -7,39 +7,39 @@ from ormar.models import FakePydantic # noqa I100
class Model(FakePydantic): class Model(FakePydantic):
__abstract__ = True __abstract__ = False
objects = ormar.queryset.QuerySet() # objects = ormar.queryset.QuerySet()
@classmethod @classmethod
def from_row( def from_row(
cls, cls,
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
select_related: List = None, select_related: List = None,
previous_table: str = None, previous_table: str = None,
) -> "Model": ) -> "Model":
item = {} item = {}
select_related = select_related or [] select_related = select_related or []
table_prefix = cls._orm_relationship_manager.resolve_relation_join( table_prefix = cls.Meta._orm_relationship_manager.resolve_relation_join(
previous_table, cls.__table__.name previous_table, cls.Meta.table.name
) )
previous_table = cls.__table__.name previous_table = cls.Meta.table.name
for related in select_related: for related in select_related:
if "__" in related: if "__" in related:
first_part, remainder = related.split("__", 1) first_part, remainder = related.split("__", 1)
model_cls = cls.__model_fields__[first_part].to model_cls = cls.Meta.model_fields[first_part].to
child = model_cls.from_row( child = model_cls.from_row(
row, select_related=[remainder], previous_table=previous_table row, select_related=[remainder], previous_table=previous_table
) )
item[first_part] = child item[first_part] = child
else: else:
model_cls = cls.__model_fields__[related].to model_cls = cls.Meta.model_fields[related].to
child = model_cls.from_row(row, previous_table=previous_table) child = model_cls.from_row(row, previous_table=previous_table)
item[related] = child item[related] = child
for column in cls.__table__.columns: for column in cls.Meta.table.columns:
if column.name not in item: if column.name not in item:
item[column.name] = row[ item[column.name] = row[
f'{table_prefix + "_" if table_prefix else ""}{column.name}' f'{table_prefix + "_" if table_prefix else ""}{column.name}'
@ -47,22 +47,14 @@ class Model(FakePydantic):
return cls(**item) return cls(**item)
@property
def pk(self) -> str:
return getattr(self.values, self.__pkname__)
@pk.setter
def pk(self, value: Any) -> None:
setattr(self.values, self.__pkname__, value)
async def save(self) -> "Model": async def save(self) -> "Model":
self_fields = self._extract_model_db_fields() self_fields = self._extract_model_db_fields()
if self.__model_fields__.get(self.__pkname__).autoincrement: if self.Meta.model_fields.get(self.Meta.pkname).autoincrement:
self_fields.pop(self.__pkname__, None) self_fields.pop(self.Meta.pkname, None)
expr = self.__table__.insert() expr = self.Meta.table.insert()
expr = expr.values(**self_fields) expr = expr.values(**self_fields)
item_id = await self.__database__.execute(expr) item_id = await self.Meta.database.execute(expr)
self.pk = item_id setattr(self, self.Meta.pkname, item_id)
return self return self
async def update(self, **kwargs: Any) -> int: async def update(self, **kwargs: Any) -> int:
@ -71,23 +63,23 @@ class Model(FakePydantic):
self.from_dict(new_values) self.from_dict(new_values)
self_fields = self._extract_model_db_fields() self_fields = self._extract_model_db_fields()
self_fields.pop(self.__pkname__) self_fields.pop(self.Meta.pkname)
expr = ( expr = (
self.__table__.update() self.Meta.table.update()
.values(**self_fields) .values(**self_fields)
.where(self.pk_column == getattr(self, self.__pkname__)) .where(self.pk_column == getattr(self, self.Meta.pkname))
) )
result = await self.__database__.execute(expr) result = await self.Meta.database.execute(expr)
return result return result
async def delete(self) -> int: async def delete(self) -> int:
expr = self.__table__.delete() expr = self.Meta.table.delete()
expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname)))
result = await self.__database__.execute(expr) result = await self.Meta.database.execute(expr)
return result return result
async def load(self) -> "Model": async def load(self) -> "Model":
expr = self.__table__.select().where(self.pk_column == self.pk) expr = self.Meta.table.select().where(self.pk_column == self.pk)
row = await self.__database__.fetch_one(expr) row = await self.Meta.database.fetch_one(expr)
self.from_dict(dict(row)) self.from_dict(dict(row))
return self return self

View File

@ -32,7 +32,7 @@ class QueryClause:
self.filter_clauses = filter_clauses self.filter_clauses = filter_clauses
self.model_cls = model_cls self.model_cls = model_cls
self.table = self.model_cls.__table__ self.table = self.model_cls.Meta.table
def filter( # noqa: A003 def filter( # noqa: A003
self, **kwargs: Any self, **kwargs: Any
@ -41,7 +41,7 @@ class QueryClause:
select_related = list(self._select_related) select_related = list(self._select_related)
if kwargs.get("pk"): if kwargs.get("pk"):
pk_name = self.model_cls.__pkname__ pk_name = self.model_cls.Meta.pkname
kwargs[pk_name] = kwargs.pop("pk") kwargs[pk_name] = kwargs.pop("pk")
for key, value in kwargs.items(): for key, value in kwargs.items():
@ -65,8 +65,8 @@ class QueryClause:
related_parts, select_related related_parts, select_related
) )
table = model_cls.__table__ table = model_cls.Meta.table
column = model_cls.__table__.columns[field_name] column = model_cls.Meta.table.columns[field_name]
else: else:
op = "exact" op = "exact"
@ -106,12 +106,12 @@ class QueryClause:
# Walk the relationships to the actual model class # Walk the relationships to the actual model class
# against which the comparison is being made. # against which the comparison is being made.
previous_table = model_cls.__tablename__ previous_table = model_cls.Meta.tablename
for part in related_parts: for part in related_parts:
current_table = model_cls.__model_fields__[part].to.__tablename__ current_table = model_cls.Meta.model_fields[part].to.Meta.tablename
manager = model_cls._orm_relationship_manager manager = model_cls.Meta._orm_relationship_manager
table_prefix = manager.resolve_relation_join(previous_table, current_table) table_prefix = manager.resolve_relation_join(previous_table, current_table)
model_cls = model_cls.__model_fields__[part].to model_cls = model_cls.Meta.model_fields[part].to
previous_table = current_table previous_table = current_table
return select_related, table_prefix, model_cls return select_related, table_prefix, model_cls
@ -128,7 +128,7 @@ class QueryClause:
clause_text = str( clause_text = str(
clause.compile( clause.compile(
dialect=self.model_cls.__database__._backend._dialect, dialect=self.model_cls.Meta.database._backend._dialect,
compile_kwargs={"literal_binds": True}, compile_kwargs={"literal_binds": True},
) )
) )

View File

@ -4,8 +4,8 @@ import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
import ormar # noqa I100 import ormar # noqa I100
from ormar import ForeignKey
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -20,12 +20,12 @@ class JoinParameters(NamedTuple):
class Query: class Query:
def __init__( def __init__(
self, self,
model_cls: Type["Model"], model_cls: Type["Model"],
filter_clauses: List, filter_clauses: List,
select_related: List, select_related: List,
limit_count: int, limit_count: int,
offset: int, offset: int,
) -> None: ) -> None:
self.query_offset = offset self.query_offset = offset
@ -34,7 +34,7 @@ class Query:
self.filter_clauses = filter_clauses self.filter_clauses = filter_clauses
self.model_cls = model_cls self.model_cls = model_cls
self.table = self.model_cls.__table__ self.table = self.model_cls.Meta.table
self.auto_related = [] self.auto_related = []
self.used_aliases = [] self.used_aliases = []
@ -46,16 +46,16 @@ class Query:
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
self.columns = list(self.table.columns) self.columns = list(self.table.columns)
self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")]
self.select_from = self.table self.select_from = self.table
for key in self.model_cls.__model_fields__: for key in self.model_cls.Meta.model_fields:
if ( if (
not self.model_cls.__model_fields__[key].nullable not self.model_cls.Meta.model_fields[key].nullable
and isinstance( and isinstance(
self.model_cls.__model_fields__[key], ormar.fields.ForeignKey, self.model_cls.Meta.model_fields[key], ForeignKeyField,
) )
and key not in self._select_related and key not in self._select_related
): ):
self._select_related = [key] + self._select_related self._select_related = [key] + self._select_related
@ -79,7 +79,7 @@ class Query:
expr = self._apply_expression_modifiers(expr) expr = self._apply_expression_modifiers(expr)
# print(expr.compile(compile_kwargs={"literal_binds": True})) print(expr.compile(compile_kwargs={"literal_binds": True}))
self._reset_query_parameters() self._reset_query_parameters()
return expr, self._select_related return expr, self._select_related
@ -97,12 +97,12 @@ class Query:
@staticmethod @staticmethod
def _field_is_a_foreign_key_and_no_circular_reference( def _field_is_a_foreign_key_and_no_circular_reference(
field: BaseField, field_name: str, rel_part: str field: BaseField, field_name: str, rel_part: str
) -> bool: ) -> bool:
return isinstance(field, ForeignKey) and field_name not in rel_part return issubclass(field, ForeignKeyField) and field_name not in rel_part
def _field_qualifies_to_deeper_search( def _field_qualifies_to_deeper_search(
self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str
) -> bool: ) -> bool:
prev_part_of_related = "__".join(rel_part.split("__")[:-1]) prev_part_of_related = "__".join(rel_part.split("__")[:-1])
partial_match = any( partial_match = any(
@ -112,39 +112,39 @@ class Query:
[x.startswith(rel_part) for x in (self.auto_related + self.already_checked)] [x.startswith(rel_part) for x in (self.auto_related + self.already_checked)]
) )
return ( return (
(field.virtual and parent_virtual) (field.virtual and parent_virtual)
or (partial_match and not already_checked) or (partial_match and not already_checked)
) or not nested ) or not nested
def on_clause( def on_clause(
self, previous_alias: str, alias: str, from_clause: str, to_clause: str, self, previous_alias: str, alias: str, from_clause: str, to_clause: str,
) -> text: ) -> text:
left_part = f"{alias}_{to_clause}" left_part = f"{alias}_{to_clause}"
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
return text(f"{left_part}={right_part}") return text(f"{left_part}={right_part}")
def _build_join_parameters( def _build_join_parameters(
self, part: str, join_params: JoinParameters self, part: str, join_params: JoinParameters
) -> JoinParameters: ) -> JoinParameters:
model_cls = join_params.model_cls.__model_fields__[part].to model_cls = join_params.model_cls.Meta.model_fields[part].to
to_table = model_cls.__table__.name to_table = model_cls.Meta.table.name
alias = model_cls._orm_relationship_manager.resolve_relation_join( alias = model_cls.Meta._orm_relationship_manager.resolve_relation_join(
join_params.from_table, to_table join_params.from_table, to_table
) )
if alias not in self.used_aliases: if alias not in self.used_aliases:
if join_params.prev_model.__model_fields__[part].virtual: if join_params.prev_model.Meta.model_fields[part].virtual:
to_key = next( to_key = next(
( (
v v
for k, v in model_cls.__model_fields__.items() for k, v in model_cls.Meta.model_fields.items()
if isinstance(v, ForeignKey) and v.to == join_params.prev_model if issubclass(v, ForeignKeyField) and v.to == join_params.prev_model
), ),
None, None,
).name ).name
from_key = model_cls.__pkname__ from_key = model_cls.Meta.pkname
else: else:
to_key = model_cls.__pkname__ to_key = model_cls.Meta.pkname
from_key = part from_key = part
on_clause = self.on_clause( on_clause = self.on_clause(
@ -157,8 +157,8 @@ class Query:
self.select_from = sqlalchemy.sql.outerjoin( self.select_from = sqlalchemy.sql.outerjoin(
self.select_from, target_table, on_clause self.select_from, target_table, on_clause
) )
self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}")) self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}"))
self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) self.columns.extend(self.prefixed_columns(alias, model_cls.Meta.table))
self.used_aliases.append(alias) self.used_aliases.append(alias)
previous_alias = alias previous_alias = alias
@ -167,24 +167,28 @@ class Query:
return JoinParameters(prev_model, previous_alias, from_table, model_cls) return JoinParameters(prev_model, previous_alias, from_table, model_cls)
def _extract_auto_required_relations( def _extract_auto_required_relations(
self, self,
prev_model: Type["Model"], prev_model: Type["Model"],
rel_part: str = "", rel_part: str = "",
nested: bool = False, nested: bool = False,
parent_virtual: bool = False, parent_virtual: bool = False,
) -> None: ) -> None:
for field_name, field in prev_model.__model_fields__.items(): for field_name, field in prev_model.Meta.model_fields.items():
if self._field_is_a_foreign_key_and_no_circular_reference( if self._field_is_a_foreign_key_and_no_circular_reference(
field, field_name, rel_part field, field_name, rel_part
): ):
rel_part = field_name if not rel_part else rel_part + "__" + field_name rel_part = field_name if not rel_part else rel_part + "__" + field_name
if not field.nullable: if not field.nullable:
print('add', rel_part, field)
if rel_part not in self._select_related: if rel_part not in self._select_related:
self.auto_related.append("__".join(rel_part.split("__")[:-1])) new_related = "__".join(rel_part.split("__")[:-1]) if len(
rel_part.split("__")) > 1 else rel_part
self.auto_related.append(new_related)
rel_part = "" rel_part = ""
elif self._field_qualifies_to_deeper_search( elif self._field_qualifies_to_deeper_search(
field, parent_virtual, nested, rel_part field, parent_virtual, nested, rel_part
): ):
print('deeper', rel_part, field, field.to)
self._extract_auto_required_relations( self._extract_auto_required_relations(
prev_model=field.to, prev_model=field.to,
rel_part=rel_part, rel_part=rel_part,
@ -204,7 +208,7 @@ class Query:
self._select_related = new_joins + self.auto_related self._select_related = new_joins + self.auto_related
def _apply_expression_modifiers( def _apply_expression_modifiers(
self, expr: sqlalchemy.sql.select self, expr: sqlalchemy.sql.select
) -> sqlalchemy.sql.select: ) -> sqlalchemy.sql.select:
if self.filter_clauses: if self.filter_clauses:
if len(self.filter_clauses) == 1: if len(self.filter_clauses) == 1:

View File

@ -14,12 +14,12 @@ if TYPE_CHECKING: # pragma no cover
class QuerySet: class QuerySet:
def __init__( def __init__(
self, self,
model_cls: Type["Model"] = None, model_cls: Type["Model"] = None,
filter_clauses: List = None, filter_clauses: List = None,
select_related: List = None, select_related: List = None,
limit_count: int = None, limit_count: int = None,
offset: int = None, offset: int = None,
) -> None: ) -> None:
self.model_cls = model_cls self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses self.filter_clauses = [] if filter_clauses is None else filter_clauses
@ -33,11 +33,11 @@ class QuerySet:
@property @property
def database(self) -> databases.Database: def database(self) -> databases.Database:
return self.model_cls.__database__ return self.model_cls.Meta.database
@property @property
def table(self) -> sqlalchemy.Table: def table(self) -> sqlalchemy.Table:
return self.model_cls.__table__ return self.model_cls.Meta.table
def build_select_expression(self) -> sqlalchemy.sql.select: def build_select_expression(self) -> sqlalchemy.sql.select:
qry = Query( qry = Query(
@ -148,12 +148,12 @@ class QuerySet:
new_kwargs = dict(**kwargs) new_kwargs = dict(**kwargs)
# Remove primary key when None to prevent not null constraint in postgresql. # Remove primary key when None to prevent not null constraint in postgresql.
pkname = self.model_cls.__pkname__ pkname = self.model_cls.Meta.pkname
pk = self.model_cls.__model_fields__[pkname] pk = self.model_cls.Meta.model_fields[pkname]
if ( if (
pkname in new_kwargs pkname in new_kwargs
and new_kwargs.get(pkname) is None and new_kwargs.get(pkname) is None
and (pk.nullable or pk.autoincrement) and (pk.nullable or pk.autoincrement)
): ):
del new_kwargs[pkname] del new_kwargs[pkname]
@ -163,11 +163,11 @@ class QuerySet:
if isinstance(new_kwargs.get(field), ormar.Model): if isinstance(new_kwargs.get(field), ormar.Model):
new_kwargs[field] = getattr( new_kwargs[field] = getattr(
new_kwargs.get(field), new_kwargs.get(field),
self.model_cls.__model_fields__[field].to.__pkname__, self.model_cls.Meta.model_fields[field].to.Meta.pkname,
) )
else: else:
new_kwargs[field] = new_kwargs.get(field).get( new_kwargs[field] = new_kwargs.get(field).get(
self.model_cls.__model_fields__[field].to.__pkname__ self.model_cls.Meta.model_fields[field].to.Meta.pkname
) )
# Build the insert expression. # Build the insert expression.
@ -176,5 +176,6 @@ class QuerySet:
# Execute the insert, and return a new model instance. # Execute the insert, and return a new model instance.
instance = self.model_cls(**kwargs) instance = self.model_cls(**kwargs)
instance.pk = await self.database.execute(expr) pk = await self.database.execute(expr)
setattr(instance, self.model_cls.Meta.pkname, pk)
return instance return instance

View File

@ -6,6 +6,7 @@ from typing import List, TYPE_CHECKING, Union
from weakref import proxy from weakref import proxy
from ormar import ForeignKey from ormar import ForeignKey
from ormar.fields.foreign_key import ForeignKeyField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models import FakePydantic, Model from ormar.models import FakePydantic, Model
@ -21,14 +22,14 @@ class RelationshipManager:
self._aliases = dict() self._aliases = dict()
def add_relation_type( def add_relation_type(
self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str self, relations_key: str, reverse_key: str, field: ForeignKeyField, table_name: str
) -> None: ) -> None:
if relations_key not in self._relations: if relations_key not in self._relations:
self._relations[relations_key] = {"type": "primary"} self._relations[relations_key] = {"type": "primary"}
self._aliases[f"{table_name}_{field.to.__tablename__}"] = get_table_alias() self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias()
if reverse_key not in self._relations: if reverse_key not in self._relations:
self._relations[reverse_key] = {"type": "reverse"} self._relations[reverse_key] = {"type": "reverse"}
self._aliases[f"{field.to.__tablename__}_{table_name}"] = get_table_alias() self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias()
def deregister(self, model: "FakePydantic") -> None: def deregister(self, model: "FakePydantic") -> None:
for rel_type in self._relations.keys(): for rel_type in self._relations.keys():

View File

@ -16,17 +16,19 @@ def time():
class Example(ormar.Model): class Example(ormar.Model):
__tablename__ = "example" class Meta:
__metadata__ = metadata tablename = "example"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
created = ormar.DateTime(default=datetime.datetime.now) name: ormar.String(max_length=200, default='aaa')
created_day = ormar.Date(default=datetime.date.today) created: ormar.DateTime(default=datetime.datetime.now)
created_time = ormar.Time(default=time) created_day: ormar.Date(default=datetime.date.today)
description = ormar.Text(nullable=True) created_time: ormar.Time(default=time)
value = ormar.Float(nullable=True) description: ormar.Text(nullable=True)
data = ormar.JSON(default={}) value: ormar.Float(nullable=True)
data: ormar.JSON(default={})
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")

View File

@ -13,22 +13,24 @@ metadata = sqlalchemy.MetaData()
class Category(ormar.Model): class Category(ormar.Model):
__tablename__ = "categories" class Meta:
__metadata__ = metadata tablename = "categories"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
class Item(ormar.Model): class Item(ormar.Model):
__tablename__ = "items" class Meta:
__metadata__ = metadata tablename = "items"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
category = ormar.ForeignKey(Category, nullable=True) category: ormar.ForeignKey(Category, nullable=True)
@app.post("/items/", response_model=Item) @app.post("/items/", response_model=Item)

View File

@ -1,6 +1,7 @@
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
from pydantic import ValidationError
import ormar import ormar
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
@ -11,62 +12,68 @@ metadata = sqlalchemy.MetaData()
class Album(ormar.Model): class Album(ormar.Model):
__tablename__ = "album" class Meta:
__metadata__ = metadata tablename = "albums"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
class Track(ormar.Model): class Track(ormar.Model):
__tablename__ = "track" class Meta:
__metadata__ = metadata tablename = "tracks"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
album = ormar.ForeignKey(Album) album: ormar.ForeignKey(Album)
title = ormar.String(length=100) title: ormar.String(max_length=100)
position = ormar.Integer() position: ormar.Integer()
class Cover(ormar.Model): class Cover(ormar.Model):
__tablename__ = "covers" class Meta:
__metadata__ = metadata tablename = "covers"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
album = ormar.ForeignKey(Album, related_name="cover_pictures") album: ormar.ForeignKey(Album, related_name="cover_pictures")
title = ormar.String(length=100) title: ormar.String(max_length=100)
class Organisation(ormar.Model): class Organisation(ormar.Model):
__tablename__ = "org" class Meta:
__metadata__ = metadata tablename = "org"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
ident = ormar.String(length=100) ident: ormar.String(max_length=100)
class Team(ormar.Model): class Team(ormar.Model):
__tablename__ = "team" class Meta:
__metadata__ = metadata tablename = "teams"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
org = ormar.ForeignKey(Organisation) org: ormar.ForeignKey(Organisation)
name = ormar.String(length=100) name: ormar.String(max_length=100)
class Member(ormar.Model): class Member(ormar.Model):
__tablename__ = "member" class Meta:
__metadata__ = metadata tablename = "members"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
team = ormar.ForeignKey(Team) team: ormar.ForeignKey(Team)
email = ormar.String(length=100) email: ormar.String(max_length=100)
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
@ -171,8 +178,8 @@ async def test_fk_filter():
tracks = ( tracks = (
await Track.objects.select_related("album") await Track.objects.select_related("album")
.filter(album__name="Fantasies") .filter(album__name="Fantasies")
.all() .all()
) )
assert len(tracks) == 3 assert len(tracks) == 3
for track in tracks: for track in tracks:
@ -180,8 +187,8 @@ async def test_fk_filter():
tracks = ( tracks = (
await Track.objects.select_related("album") await Track.objects.select_related("album")
.filter(album__name__icontains="fan") .filter(album__name__icontains="fan")
.all() .all()
) )
assert len(tracks) == 3 assert len(tracks) == 3
for track in tracks: for track in tracks:
@ -223,8 +230,8 @@ async def test_multiple_fk():
members = ( members = (
await Member.objects.select_related("team__org") await Member.objects.select_related("team__org")
.filter(team__org__ident="ACME Ltd") .filter(team__org__ident="ACME Ltd")
.all() .all()
) )
assert len(members) == 4 assert len(members) == 4
for member in members: for member in members:
@ -243,8 +250,8 @@ async def test_pk_filter():
tracks = ( tracks = (
await Track.objects.select_related("album") await Track.objects.select_related("album")
.filter(position=2, album__name="Test") .filter(position=2, album__name="Test")
.all() .all()
) )
assert len(tracks) == 1 assert len(tracks) == 1

View File

@ -1,4 +1,5 @@
import datetime import datetime
import decimal
import pydantic import pydantic
import pytest import pytest
@ -12,19 +13,21 @@ metadata = sqlalchemy.MetaData()
class ExampleModel(Model): class ExampleModel(Model):
__tablename__ = "example" class Meta:
__metadata__ = metadata tablename = "example"
test = fields.Integer(primary_key=True) metadata = metadata
test_string = fields.String(length=250)
test_text = fields.Text(default="") test: fields.Integer(primary_key=True)
test_bool = fields.Boolean(nullable=False) test_string: fields.String(max_length=250)
test_float = fields.Float() test_text: fields.Text(default="")
test_datetime = fields.DateTime(default=datetime.datetime.now) test_bool: fields.Boolean(nullable=False)
test_date = fields.Date(default=datetime.date.today) test_float: fields.Float() = None
test_time = fields.Time(default=datetime.time) test_datetime: fields.DateTime(default=datetime.datetime.now)
test_json = fields.JSON(default={}) test_date: fields.Date(default=datetime.date.today)
test_bigint = fields.BigInteger(default=0) test_time: fields.Time(default=datetime.time)
test_decimal = fields.Decimal(length=10, precision=2) test_json: fields.JSON(default={})
test_bigint: fields.BigInteger(default=0)
test_decimal: fields.Decimal(scale=10, precision=2)
fields_to_check = [ fields_to_check = [
@ -41,15 +44,17 @@ fields_to_check = [
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example2" class Meta:
__metadata__ = metadata tablename = "example2"
test = fields.Integer(primary_key=True) metadata = metadata
test_string = fields.String(length=250)
test: fields.Integer(primary_key=True)
test_string: fields.String(max_length=250)
@pytest.fixture() @pytest.fixture()
def example(): def example():
return ExampleModel(pk=1, test_string="test", test_bool=True) return ExampleModel(pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5))
def test_not_nullable_field_is_required(): def test_not_nullable_field_is_required():
@ -83,60 +88,65 @@ def test_primary_key_access_and_setting(example):
def test_pydantic_model_is_created(example): def test_pydantic_model_is_created(example):
assert issubclass(example.values.__class__, pydantic.BaseModel) assert issubclass(example.__class__, pydantic.BaseModel)
assert all([field in example.values.__fields__ for field in fields_to_check]) assert all([field in example.__fields__ for field in fields_to_check])
assert example.values.test == 1 assert example.test == 1
def test_sqlalchemy_table_is_created(example): def test_sqlalchemy_table_is_created(example):
assert issubclass(example.__table__.__class__, sqlalchemy.Table) assert issubclass(example.Meta.table.__class__, sqlalchemy.Table)
assert all([field in example.__table__.columns for field in fields_to_check]) assert all([field in example.Meta.table.columns for field in fields_to_check])
def test_no_pk_in_model_definition(): def test_no_pk_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example3" class Meta:
__metadata__ = metadata tablename = "example3"
test_string = fields.String(length=250) metadata = metadata
test_string: fields.String(max_length=250)
def test_two_pks_in_model_definition(): def test_two_pks_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example3" class Meta:
__metadata__ = metadata tablename = "example3"
id = fields.Integer(primary_key=True) metadata = metadata
test_string = fields.String(length=250, primary_key=True)
id: fields.Integer(primary_key=True)
test_string: fields.String(max_length=250, primary_key=True)
def test_setting_pk_column_as_pydantic_only_in_model_definition(): def test_setting_pk_column_as_pydantic_only_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example4" class Meta:
__metadata__ = metadata tablename = "example4"
test = fields.Integer(primary_key=True, pydantic_only=True) metadata = metadata
test: fields.Integer(primary_key=True, pydantic_only=True)
def test_decimal_error_in_model_definition(): def test_decimal_error_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example4" class Meta:
__metadata__ = metadata tablename = "example5"
test = fields.Decimal(primary_key=True) metadata = metadata
test: fields.Decimal(primary_key=True)
def test_string_error_in_model_definition(): def test_string_error_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example4" class Meta:
__metadata__ = metadata tablename = "example6"
test = fields.String(primary_key=True) metadata = metadata
test: fields.String(primary_key=True)
def test_json_conversion_in_model(): def test_json_conversion_in_model():

View File

@ -1,4 +1,5 @@
import databases import databases
import pydantic
import pytest import pytest
import sqlalchemy import sqlalchemy
@ -11,23 +12,25 @@ metadata = sqlalchemy.MetaData()
class User(ormar.Model): class User(ormar.Model):
__tablename__ = "users" class Meta:
__metadata__ = metadata tablename = "users"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100, default='')
class Product(ormar.Model): class Product(ormar.Model):
__tablename__ = "product" class Meta:
__metadata__ = metadata tablename = "product"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
rating = ormar.Integer(minimum=1, maximum=5) rating: ormar.Integer(minimum=1, maximum=5)
in_stock = ormar.Boolean(default=False) in_stock: ormar.Boolean(default=False)
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
@ -39,12 +42,12 @@ def create_test_database():
def test_model_class(): def test_model_class():
assert list(User.__model_fields__.keys()) == ["id", "name"] assert list(User.Meta.model_fields.keys()) == ["id", "name"]
assert isinstance(User.__model_fields__["id"], ormar.Integer) assert issubclass(User.Meta.model_fields["id"], pydantic.ConstrainedInt)
assert User.__model_fields__["id"].primary_key is True assert User.Meta.model_fields["id"].primary_key is True
assert isinstance(User.__model_fields__["name"], ormar.String) assert issubclass(User.Meta.model_fields["name"], pydantic.ConstrainedStr)
assert User.__model_fields__["name"].length == 100 assert User.Meta.model_fields["name"].max_length == 100
assert isinstance(User.__table__, sqlalchemy.Table) assert isinstance(User.Meta.table, sqlalchemy.Table)
def test_model_pk(): def test_model_pk():

View File

@ -30,22 +30,24 @@ async def shutdown() -> None:
class Category(ormar.Model): class Category(ormar.Model):
__tablename__ = "categories" class Meta:
__metadata__ = metadata tablename = "categories"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
class Item(ormar.Model): class Item(ormar.Model):
__tablename__ = "items" class Meta:
__metadata__ = metadata tablename = "items"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
category = ormar.ForeignKey(Category, nullable=True) category: ormar.ForeignKey(Category, nullable=True)
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")

View File

@ -12,53 +12,58 @@ metadata = sqlalchemy.MetaData()
class Department(ormar.Model): class Department(ormar.Model):
__tablename__ = "departments" class Meta:
__metadata__ = metadata tablename = "departments"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True, autoincrement=False) id: ormar.Integer(primary_key=True, autoincrement=False)
name = ormar.String(length=100) name: ormar.String(max_length=100)
class SchoolClass(ormar.Model): class SchoolClass(ormar.Model):
__tablename__ = "schoolclasses" class Meta:
__metadata__ = metadata tablename = "schoolclasses"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
department = ormar.ForeignKey(Department, nullable=False) department: ormar.ForeignKey(Department, nullable=False)
class Category(ormar.Model): class Category(ormar.Model):
__tablename__ = "categories" class Meta:
__metadata__ = metadata tablename = "categories"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
class Student(ormar.Model): class Student(ormar.Model):
__tablename__ = "students" class Meta:
__metadata__ = metadata tablename = "students"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
schoolclass = ormar.ForeignKey(SchoolClass) schoolclass: ormar.ForeignKey(SchoolClass)
category = ormar.ForeignKey(Category, nullable=True) category: ormar.ForeignKey(Category, nullable=True)
class Teacher(ormar.Model): class Teacher(ormar.Model):
__tablename__ = "teachers" class Meta:
__metadata__ = metadata tablename = "teachers"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(max_length=100)
schoolclass = ormar.ForeignKey(SchoolClass) schoolclass: ormar.ForeignKey(SchoolClass)
category = ormar.ForeignKey(Category, nullable=True) category: ormar.ForeignKey(Category, nullable=True)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")