diff --git a/.coverage b/.coverage index bf04874..a7dc6e2 100644 Binary files a/.coverage and b/.coverage differ diff --git a/ormar/fields/base.py b/ormar/fields/base.py index cef323f..67bf190 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, List, Optional, TYPE_CHECKING import sqlalchemy +from pydantic import Field from ormar import ModelDefinitionError # noqa I101 @@ -11,72 +12,55 @@ if TYPE_CHECKING: # pragma no cover class BaseField: __type__ = None - def __init__(self, **kwargs: Any) -> None: - self.name = None - self._populate_from_kwargs(kwargs) + column_type: sqlalchemy.Column + constraints: List = [] - def _populate_from_kwargs(self, kwargs: Dict) -> None: - self.primary_key = kwargs.pop("primary_key", False) - self.autoincrement = kwargs.pop( - "autoincrement", self.primary_key and self.__type__ == int - ) + primary_key: bool + autoincrement: bool + nullable: bool + index: bool + unique: bool + pydantic_only: bool - self.nullable = kwargs.pop("nullable", not self.primary_key) - self.default = kwargs.pop("default", None) - self.server_default = kwargs.pop("server_default", None) + default: Any + server_default: Any - self.index = kwargs.pop("index", None) - self.unique = kwargs.pop("unique", None) + @classmethod + def default_value(cls) -> Optional[Field]: + if cls.is_auto_primary_key(): + return Field(default=None) + if cls.has_default(): + default = cls.default if cls.default is not None else cls.server_default + if callable(default): + return Field(default_factory=default) + else: + return Field(default=default) + return None - self.pydantic_only = kwargs.pop("pydantic_only", False) - if self.pydantic_only and self.primary_key: - raise ModelDefinitionError("Primary key column cannot be pydantic only.") + @classmethod + def has_default(cls) -> bool: + return cls.default is not None or cls.server_default is not None - @property - def is_required(self) -> bool: - return ( - not self.nullable and not self.has_default and not self.is_auto_primary_key - ) - - @property - def default_value(self) -> Any: - default = self.default - return default() if callable(default) else default - - @property - def has_default(self) -> bool: - return self.default is not None or self.server_default is not None - - @property - def is_auto_primary_key(self) -> bool: - if self.primary_key: - return self.autoincrement + @classmethod + def is_auto_primary_key(cls) -> bool: + if cls.primary_key: + return cls.autoincrement return False - def get_column(self, name: str = None) -> sqlalchemy.Column: - self.name = name - constraints = self.get_constraints() + @classmethod + def get_column(cls, name: str) -> sqlalchemy.Column: return sqlalchemy.Column( - self.name, - self.get_column_type(), - *constraints, - primary_key=self.primary_key, - autoincrement=self.autoincrement, - nullable=self.nullable, - index=self.index, - unique=self.unique, - default=self.default, - server_default=self.server_default, + name, + cls.column_type, + *cls.constraints, + primary_key=cls.primary_key, + nullable=cls.nullable and not cls.primary_key, + index=cls.index, + unique=cls.unique, + default=cls.default, + server_default=cls.server_default, ) - def get_column_type(self) -> sqlalchemy.types.TypeEngine: - raise NotImplementedError() # pragma: no cover - - def get_constraints(self) -> Optional[List]: - return [] - - def expand_relationship(self, value: Any, child: "Model") -> Any: + @classmethod + def expand_relationship(cls, value: Any, child: "Model") -> Any: return value - - def __repr__(self): # pragma no cover - return str(self.__dict__) diff --git a/ormar/fields/decorators.py b/ormar/fields/decorators.py deleted file mode 100644 index 842e864..0000000 --- a/ormar/fields/decorators.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Any, TYPE_CHECKING, Type - -from ormar import ModelDefinitionError - -if TYPE_CHECKING: # pragma no cover - from ormar.fields import BaseField - - -class RequiredParams: - def __init__(self, *args: str) -> None: - self._required = list(args) - - def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]: - old_init = model_field_class.__init__ - model_field_class._old_init = old_init - - def __init__(instance: "BaseField", **kwargs: Any) -> None: - super(instance.__class__, instance).__init__(**kwargs) - for arg in self._required: - if arg not in kwargs: - raise ModelDefinitionError( - f"{instance.__class__.__name__} field requires parameter: {arg}" - ) - setattr(instance, arg, kwargs.pop(arg)) - - model_field_class.__init__ = __init__ - return model_field_class diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 77c2da7..9d052a7 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -1,7 +1,6 @@ -from typing import Any, List, Optional, TYPE_CHECKING, Type, Union +from typing import Any, Callable, List, Optional, TYPE_CHECKING, Type, Union import sqlalchemy -from pydantic import BaseModel import ormar # noqa I101 from ormar.exceptions import RelationshipInstanceError @@ -13,87 +12,120 @@ if TYPE_CHECKING: # pragma no cover def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": init_dict = { - **{fk.__pkname__: pk or -1}, + **{fk.Meta.pkname: pk or -1, "__pk_only__": True}, **{ k: create_dummy_instance(v.to) - for k, v in fk.__model_fields__.items() - if isinstance(v, ForeignKey) and not v.nullable and not v.virtual + for k, v in fk.Meta.model_fields.items() + if isinstance(v, ForeignKeyField) and not v.nullable and not v.virtual }, } return fk(**init_dict) -class ForeignKey(BaseField): - def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, - ) -> None: - super().__init__(nullable=nullable, name=name) - self.virtual = virtual - self.related_name = related_name - self.to = to +def ForeignKey( + to: Type["Model"], + *, + name: str = None, + unique: bool = False, + nullable: bool = True, + related_name: str = None, + virtual: bool = False, +) -> Type[object]: + fk_string = to.Meta.tablename + "." + to.Meta.pkname + to_field = to.__fields__[to.Meta.pkname] + namespace = dict( + to=to, + name=name, + nullable=nullable, + constraints=[sqlalchemy.schema.ForeignKey(fk_string)], + unique=unique, + column_type=to_field.type_.column_type, + related_name=related_name, + virtual=virtual, + primary_key=False, + index=False, + pydantic_only=False, + default=None, + server_default=None, + ) - @property - def __type__(self) -> Type[BaseModel]: - return self.to.__pydantic_model__ + return type("ForeignKey", (ForeignKeyField, BaseField), namespace) - def get_constraints(self) -> List[sqlalchemy.schema.ForeignKey]: - fk_string = self.to.__tablename__ + "." + self.to.__pkname__ - return [sqlalchemy.schema.ForeignKey(fk_string)] - def get_column_type(self) -> sqlalchemy.Column: - to_column = self.to.__model_fields__[self.to.__pkname__] - return to_column.get_column_type() +class ForeignKeyField(BaseField): + to: Type["Model"] + related_name: str + virtual: bool - def _extract_model_from_sequence( - self, value: List, child: "Model" - ) -> Union["Model", List["Model"]]: - return [self.expand_relationship(val, child) for val in value] + @classmethod + def __get_validators__(cls) -> Callable: + yield cls.validate - def _register_existing_model(self, value: "Model", child: "Model") -> "Model": - self.register_relation(value, child) + @classmethod + def validate(cls, value: Any) -> Any: return value - def _construct_model_from_dict(self, value: dict, child: "Model") -> "Model": - model = self.to(**value) - self.register_relation(model, child) + # @property + # def __type__(self) -> Type[BaseModel]: + # return self.to.__pydantic_model__ + + # @classmethod + # def get_column_type(cls) -> sqlalchemy.Column: + # to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname] + # return to_column.column_type + + @classmethod + def _extract_model_from_sequence( + cls, value: List, child: "Model" + ) -> Union["Model", List["Model"]]: + return [cls.expand_relationship(val, child) for val in value] + + @classmethod + def _register_existing_model(cls, value: "Model", child: "Model") -> "Model": + cls.register_relation(value, child) + return value + + @classmethod + def _construct_model_from_dict(cls, value: dict, child: "Model") -> "Model": + if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname: + value["__pk_only__"] = True + model = cls.to(**value) + cls.register_relation(model, child) return model - def _construct_model_from_pk(self, value: Any, child: "Model") -> "Model": - if not isinstance(value, self.to.pk_type()): + @classmethod + def _construct_model_from_pk(cls, value: Any, child: "Model") -> "Model": + if not isinstance(value, cls.to.pk_type()): raise RelationshipInstanceError( - f"Relationship error - ForeignKey {self.to.__name__} " - f"is of type {self.to.pk_type()} " + f"Relationship error - ForeignKey {cls.to.__name__} " + f"is of type {cls.to.pk_type()} " f"while {type(value)} passed as a parameter." ) - model = create_dummy_instance(fk=self.to, pk=value) - self.register_relation(model, child) + model = create_dummy_instance(fk=cls.to, pk=value) + cls.register_relation(model, child) return model - def register_relation(self, model: "Model", child: "Model") -> None: - child_model_name = self.related_name or child.get_name() - model._orm_relationship_manager.add_relation( - model, child, child_model_name, virtual=self.virtual + @classmethod + def register_relation(cls, model: "Model", child: "Model") -> None: + child_model_name = cls.related_name or child.get_name() + model.Meta._orm_relationship_manager.add_relation( + model, child, child_model_name, virtual=cls.virtual ) + @classmethod def expand_relationship( - self, value: Any, child: "Model" + cls, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: - if value is None: return None constructors = { - f"{self.to.__name__}": self._register_existing_model, - "dict": self._construct_model_from_dict, - "list": self._extract_model_from_sequence, + f"{cls.to.__name__}": cls._register_existing_model, + "dict": cls._construct_model_from_dict, + "list": cls._extract_model_from_sequence, } model = constructors.get( - value.__class__.__name__, self._construct_model_from_pk + value.__class__.__name__, cls._construct_model_from_pk )(value, child) return model diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index 4f9be11..dd81838 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -1,87 +1,269 @@ import datetime import decimal +from typing import Any, Optional, Type +import pydantic import sqlalchemy -from pydantic import Json +from ormar import ModelDefinitionError # noqa I101 from ormar.fields.base import BaseField # noqa I101 -from ormar.fields.decorators import RequiredParams -@RequiredParams("length") -class String(BaseField): - __type__ = str - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.String(self.length) +def is_field_nullable( + nullable: Optional[bool], default: Any, server_default: Any +) -> bool: + if nullable is None: + return default is not None or server_default is not None + return nullable -class Integer(BaseField): - __type__ = int +class ModelFieldFactory: + _bases = None + _type = None - def get_column_type(self) -> sqlalchemy.Column: + def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: + cls.validate(**kwargs) + + default = kwargs.pop("default", None) + server_default = kwargs.pop("server_default", None) + nullable = kwargs.pop("nullable", None) + + namespace = dict( + __type__=cls._type, + name=kwargs.pop("name", None), + primary_key=kwargs.pop("primary_key", False), + default=default, + server_default=server_default, + nullable=is_field_nullable(nullable, default, server_default), + index=kwargs.pop("index", False), + unique=kwargs.pop("unique", False), + pydantic_only=kwargs.pop("pydantic_only", False), + autoincrement=kwargs.pop("autoincrement", False), + column_type=cls.get_column_type(**kwargs), + **kwargs + ) + return type(cls.__name__, cls._bases, namespace) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: # pragma no cover + return None + + @classmethod + def validate(cls, **kwargs: Any) -> None: # pragma no cover + pass + + +class String(ModelFieldFactory): + _bases = (pydantic.ConstrainedStr, BaseField) + _type = str + + def __new__( + cls, + *, + allow_blank: bool = False, + strip_whitespace: bool = False, + min_length: int = None, + max_length: int = None, + curtail_length: int = None, + regex: str = None, + **kwargs: Any + ) -> Type[str]: + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.String(length=kwargs.get("max_length")) + + @classmethod + def validate(cls, **kwargs: Any) -> None: + max_length = kwargs.get("max_length", None) + if max_length is None or max_length <= 0: + raise ModelDefinitionError( + "Parameter max_length is required for field String" + ) + + +class Integer(ModelFieldFactory): + _bases = (pydantic.ConstrainedInt, BaseField) + _type = int + + def __new__( + cls, + *, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + **kwargs: Any + ) -> Type[int]: + autoincrement = kwargs.pop("autoincrement", None) + autoincrement = ( + autoincrement + if autoincrement is not None + else kwargs.get("primary_key", False) + ) + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: return sqlalchemy.Integer() -class Text(BaseField): - __type__ = str +class Text(ModelFieldFactory): + _bases = (pydantic.ConstrainedStr, BaseField) + _type = str - def get_column_type(self) -> sqlalchemy.Column: + def __new__( + cls, *, allow_blank: bool = False, strip_whitespace: bool = False, **kwargs: Any + ) -> Type[str]: + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: return sqlalchemy.Text() -class Float(BaseField): - __type__ = float +class Float(ModelFieldFactory): + _bases = (pydantic.ConstrainedFloat, BaseField) + _type = float - def get_column_type(self) -> sqlalchemy.Column: + def __new__( + cls, + *, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + **kwargs: Any + ) -> Type[int]: + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: return sqlalchemy.Float() -class Boolean(BaseField): - __type__ = bool +class Boolean(ModelFieldFactory): + _bases = (int, BaseField) + _type = bool - def get_column_type(self) -> sqlalchemy.Column: + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: return sqlalchemy.Boolean() -class DateTime(BaseField): - __type__ = datetime.datetime +class DateTime(ModelFieldFactory): + _bases = (datetime.datetime, BaseField) + _type = datetime.datetime - def get_column_type(self) -> sqlalchemy.Column: + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: return sqlalchemy.DateTime() -class Date(BaseField): - __type__ = datetime.date +class Date(ModelFieldFactory): + _bases = (datetime.date, BaseField) + _type = datetime.date - def get_column_type(self) -> sqlalchemy.Column: + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: return sqlalchemy.Date() -class Time(BaseField): - __type__ = datetime.time +class Time(ModelFieldFactory): + _bases = (datetime.time, BaseField) + _type = datetime.time - def get_column_type(self) -> sqlalchemy.Column: + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: return sqlalchemy.Time() -class JSON(BaseField): - __type__ = Json +class JSON(ModelFieldFactory): + _bases = (pydantic.Json, BaseField) + _type = pydantic.Json - def get_column_type(self) -> sqlalchemy.Column: + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: return sqlalchemy.JSON() -class BigInteger(BaseField): - __type__ = int +class BigInteger(Integer): + _bases = (pydantic.ConstrainedInt, BaseField) + _type = int - def get_column_type(self) -> sqlalchemy.Column: + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: return sqlalchemy.BigInteger() -@RequiredParams("length", "precision") -class Decimal(BaseField): - __type__ = decimal.Decimal +class Decimal(ModelFieldFactory): + _bases = (pydantic.ConstrainedDecimal, BaseField) + _type = decimal.Decimal - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.DECIMAL(self.length, self.precision) + def __new__( + cls, + *, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + precision: int = None, + scale: int = None, + max_digits: int = None, + decimal_places: int = None, + **kwargs: Any + ) -> Type[decimal.Decimal]: + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + precision = kwargs.get("precision") + scale = kwargs.get("scale") + return sqlalchemy.DECIMAL(precision=precision, scale=scale) + + @classmethod + def validate(cls, **kwargs: Any) -> None: + precision = kwargs.get("precision") + scale = kwargs.get("scale") + if precision is None or precision < 0 or scale is None or scale < 0: + raise ModelDefinitionError( + "Parameters scale and precision are required for field Decimal" + ) diff --git a/ormar/models/fakepydantic.py b/ormar/models/fakepydantic.py index d5109b4..3b74c9d 100644 --- a/ormar/models/fakepydantic.py +++ b/ormar/models/fakepydantic.py @@ -2,10 +2,11 @@ import inspect import json import uuid from typing import ( + AbstractSet, Any, - Callable, Dict, List, + Mapping, Optional, Set, TYPE_CHECKING, @@ -21,18 +22,26 @@ from pydantic import BaseModel import ormar # noqa I100 from ormar.fields import BaseField -from ormar.models.metaclass import ModelMetaclass +from ormar.fields.foreign_key import ForeignKeyField +from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.relations import RelationshipManager if TYPE_CHECKING: # pragma no cover from ormar.models.model import Model + IntStr = Union[int, str] + DictStrAny = Dict[str, Any] + AbstractSetIntStr = AbstractSet[IntStr] + MappingIntStrAny = Mapping[IntStr, Any] -class FakePydantic(list, metaclass=ModelMetaclass): + +class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): # FakePydantic inherits from list in order to be treated as # request.Body parameter in fastapi routes, # inheriting from pydantic.BaseModel causes metaclass conflicts + __slots__ = ("_orm_id", "_orm_saved") __abstract__ = True + if TYPE_CHECKING: # pragma no cover __model_fields__: Dict[str, TypeVar[BaseField]] __table__: sqlalchemy.Table @@ -43,63 +52,82 @@ class FakePydantic(list, metaclass=ModelMetaclass): __metadata__: sqlalchemy.MetaData __database__: databases.Database _orm_relationship_manager: RelationshipManager + Meta: ModelMeta + # noinspection PyMissingConstructor def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__() - self._orm_id: str = uuid.uuid4().hex - self._orm_saved: bool = False - self.values: Optional[BaseModel] = None + object.__setattr__(self, "_orm_id", uuid.uuid4().hex) + object.__setattr__(self, "_orm_saved", False) + + pk_only = kwargs.pop("__pk_only__", False) if "pk" in kwargs: - kwargs[self.__pkname__] = kwargs.pop("pk") + kwargs[self.Meta.pkname] = kwargs.pop("pk") kwargs = { - k: self.__model_fields__[k].expand_relationship(v, self) + k: self._convert_json( + k, self.Meta.model_fields[k].expand_relationship(v, self), "dumps" + ) for k, v in kwargs.items() } - self.values = self.__pydantic_model__(**kwargs) + + values, fields_set, validation_error = pydantic.validate_model(self, kwargs) + if validation_error and not pk_only: + raise validation_error + + object.__setattr__(self, "__dict__", values) + object.__setattr__(self, "__fields_set__", fields_set) + + # super().__init__(**kwargs) + # self.values = self.__pydantic_model__(**kwargs) def __del__(self) -> None: - self._orm_relationship_manager.deregister(self) + self.Meta._orm_relationship_manager.deregister(self) - def __setattr__(self, key: str, value: Any) -> None: - if key in self.__fields__: - value = self._convert_json(key, value, op="dumps") - value = self.__model_fields__[key].expand_relationship(value, self) - - relation_key = self.get_name(title=True) + "_" + key - if not self._orm_relationship_manager.contains(relation_key, self): - setattr(self.values, key, value) + def __setattr__(self, name: str, value: Any) -> None: + relation_key = self.get_name(title=True) + "_" + name + if name in self.__slots__: + object.__setattr__(self, name, value) + elif name == "pk": + object.__setattr__(self, self.Meta.pkname, value) + elif self.Meta._orm_relationship_manager.contains(relation_key, self): + self.Meta.model_fields[name].expand_relationship(value, self) else: - super().__setattr__(key, value) + value = ( + self._convert_json(name, value, "dumps") + if name in self.__fields__ + else value + ) + super().__setattr__(name, value) - def __getattribute__(self, key: str) -> Any: - if key != "__fields__" and key in self.__fields__: - relation_key = self.get_name(title=True) + "_" + key - if self._orm_relationship_manager.contains(relation_key, self): - return self._orm_relationship_manager.get(relation_key, self) + def __getattribute__(self, item: str) -> Any: + if item != "__fields__" and item in self.__fields__: + related = self._extract_related_model_instead_of_field(item) + if related: + return related + value = object.__getattribute__(self, item) + value = self._convert_json(item, value, "loads") + return value + return super().__getattribute__(item) - item = getattr(self.values, key, None) - item = self._convert_json(key, item, op="loads") - return item - return super().__getattribute__(key) + def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]: + return self._extract_related_model_instead_of_field(item) - def __eq__(self, other: "Model") -> bool: - return self.values.dict() == other.values.dict() + def _extract_related_model_instead_of_field( + self, item: str + ) -> Optional[Union["Model", List["Model"]]]: + relation_key = self.get_name(title=True) + "_" + item + if self.Meta._orm_relationship_manager.contains(relation_key, self): + return self.Meta._orm_relationship_manager.get(relation_key, self) def __same__(self, other: "Model") -> bool: if self.__class__ != other.__class__: # pragma no cover return False - return self._orm_id == other._orm_id or ( - self.values is not None and other.values is not None and self.pk == other.pk + return ( + self._orm_id == other._orm_id + or self.__dict__ == other.__dict__ + or (self.pk == other.pk and self.pk is not None) ) - def __repr__(self) -> str: # pragma no cover - return self.values.__repr__() - - @classmethod - def __get_validators__(cls) -> Callable: # pragma no cover - yield cls.__pydantic_model__.validate - @classmethod def get_name(cls, title: bool = False, lower: bool = True) -> str: name = cls.__name__ @@ -109,28 +137,50 @@ class FakePydantic(list, metaclass=ModelMetaclass): name = name.title() return name + @property + def pk(self) -> Any: + return getattr(self, self.Meta.pkname) + @property def pk_column(self) -> sqlalchemy.Column: - return self.__table__.primary_key.columns.values()[0] + return self.Meta.table.primary_key.columns.values()[0] @classmethod def pk_type(cls) -> Any: - return cls.__model_fields__[cls.__pkname__].__type__ + return cls.Meta.model_fields[cls.Meta.pkname].__type__ - def dict(self, nested=False) -> Dict: # noqa: A003 - dict_instance = self.values.dict() + def dict( # noqa A003 + self, + *, + include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + nested: bool = False + ) -> "DictStrAny": # noqa: A003' + dict_instance = super().dict( + include=include, + exclude=self._exclude_related_names_not_required(nested), + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) for field in self._extract_related_names(): nested_model = getattr(self, field) - if self.__model_fields__[field].virtual and nested: + + if self.Meta.model_fields[field].virtual and nested: continue if isinstance(nested_model, list) and not isinstance( nested_model, ormar.Model ): dict_instance[field] = [x.dict(nested=True) for x in nested_model] - else: - dict_instance[field] = ( - nested_model.dict(nested=True) if nested_model is not None else {} - ) + elif nested_model is not None: + dict_instance[field] = nested_model.dict(nested=True) return dict_instance def from_dict(self, value_dict: Dict) -> None: @@ -155,7 +205,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): return value def _is_conversion_to_json_needed(self, column_name: str) -> bool: - return self.__model_fields__.get(column_name).__type__ == pydantic.Json + return self.Meta.model_fields.get(column_name).__type__ == pydantic.Json def _extract_own_model_fields(self) -> Dict: related_names = self._extract_related_names() @@ -165,9 +215,21 @@ class FakePydantic(list, metaclass=ModelMetaclass): @classmethod def _extract_related_names(cls) -> Set: related_names = set() - for name, field in cls.__fields__.items(): - if inspect.isclass(field.type_) and issubclass( - field.type_, pydantic.BaseModel + for name, field in cls.Meta.model_fields.items(): + if inspect.isclass(field) and issubclass(field, ForeignKeyField): + related_names.add(name) + return related_names + + @classmethod + def _exclude_related_names_not_required(cls, nested: bool = False) -> Set: + if nested: + return cls._extract_related_names() + related_names = set() + for name, field in cls.Meta.model_fields.items(): + if ( + inspect.isclass(field) + and issubclass(field, ForeignKeyField) + and field.nullable ): related_names.add(name) return related_names @@ -175,13 +237,12 @@ class FakePydantic(list, metaclass=ModelMetaclass): def _extract_model_db_fields(self) -> Dict: self_fields = self._extract_own_model_fields() self_fields = { - k: v for k, v in self_fields.items() if k in self.__table__.columns + k: v for k, v in self_fields.items() if k in self.Meta.table.columns } for field in self._extract_related_names(): + target_pk_name = self.Meta.model_fields[field].to.Meta.pkname if getattr(self, field) is not None: - self_fields[field] = getattr( - getattr(self, field), self.__model_fields__[field].to.__pkname__ - ) + self_fields[field] = getattr(getattr(self, field), target_pk_name) return self_fields @classmethod @@ -196,18 +257,19 @@ class FakePydantic(list, metaclass=ModelMetaclass): @classmethod def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": - for field in one.__model_fields__.keys(): - if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), ormar.Model + for field in one.Meta.model_fields.keys(): + current_field = getattr(one, field) + if isinstance(current_field, list) and not isinstance( + current_field, ormar.Model ): - setattr(other, field, getattr(one, field) + getattr(other, field)) - elif isinstance(getattr(one, field), ormar.Model): - if getattr(one, field).pk == getattr(other, field).pk: - setattr( - other, - field, - cls.merge_two_instances( - getattr(one, field), getattr(other, field) - ), - ) + setattr(other, field, current_field + getattr(other, field)) + elif ( + isinstance(current_field, ormar.Model) + and current_field.pk == getattr(other, field).pk + ): + setattr( + other, + field, + cls.merge_two_instances(current_field, getattr(other, field)), + ) return other diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 7ea2fe7..0d621bd 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -1,12 +1,15 @@ -import copy -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union +import databases +import pydantic import sqlalchemy -from pydantic import BaseConfig, create_model -from pydantic.fields import ModelField +from pydantic import BaseConfig +from pydantic.fields import FieldInfo from ormar import ForeignKey, ModelDefinitionError # noqa I100 from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField +from ormar.queryset import QuerySet from ormar.relations import RelationshipManager if TYPE_CHECKING: # pragma no cover @@ -15,16 +18,15 @@ if TYPE_CHECKING: # pragma no cover relationship_manager = RelationshipManager() -def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: - pydantic_fields = { - field_name: ( - base_field.__type__, - ... if base_field.is_required else base_field.default_value, - ) - for field_name, base_field in object_dict.items() - if isinstance(base_field, BaseField) - } - return pydantic_fields +class ModelMeta: + tablename: str + table: sqlalchemy.Table + metadata: sqlalchemy.MetaData + database: databases.Database + columns: List[sqlalchemy.Column] + pkname: str + model_fields: Dict[str, Union[BaseField, ForeignKey]] + _orm_relationship_manager: RelationshipManager def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: @@ -41,8 +43,8 @@ def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> def expand_reverse_relationships(model: Type["Model"]) -> None: - for model_field in model.__model_fields__.values(): - if isinstance(model_field, ForeignKey): + for model_field in model.Meta.model_fields.values(): + if issubclass(model_field, ForeignKeyField): child_model_name = model_field.related_name or model.get_name() + "s" parent_model = model_field.to child = model @@ -56,13 +58,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: def register_reverse_model_fields( model: Type["Model"], child: Type["Model"], child_model_name: str ) -> None: - model.__fields__[child_model_name] = ModelField( - name=child_model_name, - type_=Optional[child.__pydantic_model__], - model_config=child.__pydantic_model__.__config__, - class_validators=child.__pydantic_model__.__validators__, - ) - model.__model_fields__[child_model_name] = ForeignKey( + model.Meta.model_fields[child_model_name] = ForeignKey( child, name=child_model_name, virtual=True ) @@ -74,71 +70,96 @@ def sqlalchemy_columns_from_model_fields( pkname = None model_fields = { field_name: field - for field_name, field in object_dict.items() - if isinstance(field, BaseField) + for field_name, field in object_dict["__annotations__"].items() + if issubclass(field, BaseField) } for field_name, field in model_fields.items(): if field.primary_key: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") + if field.pydantic_only: + raise ModelDefinitionError("Primary key column cannot be pydantic only") pkname = field_name if not field.pydantic_only: columns.append(field.get_column(field_name)) - if isinstance(field, ForeignKey): + if issubclass(field, ForeignKeyField): register_relation_on_build(table_name, field, name) return pkname, columns, model_fields +def populate_pydantic_default_values(attrs: Dict) -> Dict: + for field, type_ in attrs["__annotations__"].items(): + if issubclass(type_, BaseField): + if type_.name is None: + type_.name = field + def_value = type_.default_value() + curr_def_value = attrs.get(field, "NONE") + if curr_def_value == "NONE" and isinstance(def_value, FieldInfo): + attrs[field] = def_value + elif curr_def_value == "NONE" and type_.nullable: + attrs[field] = FieldInfo(default=None) + return attrs + + def get_pydantic_base_orm_config() -> Type[BaseConfig]: class Config(BaseConfig): orm_mode = True + arbitrary_types_allowed = True + # extra = Extra.allow return Config -class ModelMetaclass(type): +class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: + + attrs["Config"] = get_pydantic_base_orm_config() new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) - if attrs.get("__abstract__"): - return new_model + if hasattr(new_model, "Meta"): - tablename = attrs.get("__tablename__", name.lower() + "s") - attrs["__tablename__"] = tablename - metadata = attrs["__metadata__"] + annotations = attrs.get("__annotations__") or new_model.__annotations__ + attrs["__annotations__"] = annotations + attrs = populate_pydantic_default_values(attrs) - # sqlalchemy table creation - pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( - name, attrs, tablename - ) - attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns) - attrs["__columns__"] = columns - attrs["__pkname__"] = pkname + tablename = name.lower() + "s" + new_model.Meta.tablename = new_model.Meta.tablename or tablename - if not pkname: - raise ModelDefinitionError("Table has to have a primary key.") + # sqlalchemy table creation - # pydantic model creation - pydantic_fields = parse_pydantic_field_from_model_fields(attrs) - pydantic_model = create_model( - name, __config__=get_pydantic_base_orm_config(), **pydantic_fields - ) - attrs["__pydantic_fields__"] = pydantic_fields - attrs["__pydantic_model__"] = pydantic_model - attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__) - attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__) - attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__) + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( + name, attrs, new_model.Meta.tablename + ) - attrs["__model_fields__"] = model_fields - attrs["_orm_relationship_manager"] = relationship_manager + if hasattr(new_model.Meta, "model_fields") and not pkname: + model_fields = new_model.Meta.model_fields + for fieldname, field in new_model.Meta.model_fields.items(): + if field.primary_key: + pkname = fieldname + columns = new_model.Meta.table.columns - new_model = super().__new__( # type: ignore - mcs, name, bases, attrs - ) + if not hasattr(new_model.Meta, "table"): + new_model.Meta.table = sqlalchemy.Table( + new_model.Meta.tablename, new_model.Meta.metadata, *columns + ) - expand_reverse_relationships(new_model) + new_model.Meta.columns = columns + new_model.Meta.pkname = pkname + if not pkname: + raise ModelDefinitionError("Table has to have a primary key.") + + new_model.Meta.model_fields = model_fields + new_model = super().__new__( # type: ignore + mcs, name, bases, attrs + ) + expand_reverse_relationships(new_model) + + new_model.Meta._orm_relationship_manager = relationship_manager + new_model.objects = QuerySet(new_model) + + # breakpoint() return new_model diff --git a/ormar/models/model.py b/ormar/models/model.py index b16d3e7..5cb9ed0 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -7,9 +7,9 @@ from ormar.models import FakePydantic # noqa I100 class Model(FakePydantic): - __abstract__ = True + __abstract__ = False - objects = ormar.queryset.QuerySet() + # objects = ormar.queryset.QuerySet() @classmethod def from_row( @@ -22,24 +22,24 @@ class Model(FakePydantic): item = {} select_related = select_related or [] - table_prefix = cls._orm_relationship_manager.resolve_relation_join( - previous_table, cls.__table__.name + table_prefix = cls.Meta._orm_relationship_manager.resolve_relation_join( + previous_table, cls.Meta.table.name ) - previous_table = cls.__table__.name + previous_table = cls.Meta.table.name for related in select_related: if "__" in related: first_part, remainder = related.split("__", 1) - model_cls = cls.__model_fields__[first_part].to + model_cls = cls.Meta.model_fields[first_part].to child = model_cls.from_row( row, select_related=[remainder], previous_table=previous_table ) item[first_part] = child else: - model_cls = cls.__model_fields__[related].to + model_cls = cls.Meta.model_fields[related].to child = model_cls.from_row(row, previous_table=previous_table) item[related] = child - for column in cls.__table__.columns: + for column in cls.Meta.table.columns: if column.name not in item: item[column.name] = row[ f'{table_prefix + "_" if table_prefix else ""}{column.name}' @@ -47,22 +47,14 @@ class Model(FakePydantic): return cls(**item) - @property - def pk(self) -> str: - return getattr(self.values, self.__pkname__) - - @pk.setter - def pk(self, value: Any) -> None: - setattr(self.values, self.__pkname__, value) - async def save(self) -> "Model": self_fields = self._extract_model_db_fields() - if self.__model_fields__.get(self.__pkname__).autoincrement: - self_fields.pop(self.__pkname__, None) - expr = self.__table__.insert() + if self.Meta.model_fields.get(self.Meta.pkname).autoincrement: + self_fields.pop(self.Meta.pkname, None) + expr = self.Meta.table.insert() expr = expr.values(**self_fields) - item_id = await self.__database__.execute(expr) - self.pk = item_id + item_id = await self.Meta.database.execute(expr) + setattr(self, self.Meta.pkname, item_id) return self async def update(self, **kwargs: Any) -> int: @@ -71,23 +63,23 @@ class Model(FakePydantic): self.from_dict(new_values) self_fields = self._extract_model_db_fields() - self_fields.pop(self.__pkname__) + self_fields.pop(self.Meta.pkname) expr = ( - self.__table__.update() + self.Meta.table.update() .values(**self_fields) - .where(self.pk_column == getattr(self, self.__pkname__)) + .where(self.pk_column == getattr(self, self.Meta.pkname)) ) - result = await self.__database__.execute(expr) + result = await self.Meta.database.execute(expr) return result async def delete(self) -> int: - expr = self.__table__.delete() - expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) - result = await self.__database__.execute(expr) + expr = self.Meta.table.delete() + expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname))) + result = await self.Meta.database.execute(expr) return result async def load(self) -> "Model": - expr = self.__table__.select().where(self.pk_column == self.pk) - row = await self.__database__.fetch_one(expr) + expr = self.Meta.table.select().where(self.pk_column == self.pk) + row = await self.Meta.database.fetch_one(expr) self.from_dict(dict(row)) return self diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index 3d5d14b..dc94e6f 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -32,7 +32,7 @@ class QueryClause: self.filter_clauses = filter_clauses self.model_cls = model_cls - self.table = self.model_cls.__table__ + self.table = self.model_cls.Meta.table def filter( # noqa: A003 self, **kwargs: Any @@ -41,7 +41,7 @@ class QueryClause: select_related = list(self._select_related) if kwargs.get("pk"): - pk_name = self.model_cls.__pkname__ + pk_name = self.model_cls.Meta.pkname kwargs[pk_name] = kwargs.pop("pk") for key, value in kwargs.items(): @@ -65,8 +65,8 @@ class QueryClause: related_parts, select_related ) - table = model_cls.__table__ - column = model_cls.__table__.columns[field_name] + table = model_cls.Meta.table + column = model_cls.Meta.table.columns[field_name] else: op = "exact" @@ -106,12 +106,12 @@ class QueryClause: # Walk the relationships to the actual model class # against which the comparison is being made. - previous_table = model_cls.__tablename__ + previous_table = model_cls.Meta.tablename for part in related_parts: - current_table = model_cls.__model_fields__[part].to.__tablename__ - manager = model_cls._orm_relationship_manager + current_table = model_cls.Meta.model_fields[part].to.Meta.tablename + manager = model_cls.Meta._orm_relationship_manager table_prefix = manager.resolve_relation_join(previous_table, current_table) - model_cls = model_cls.__model_fields__[part].to + model_cls = model_cls.Meta.model_fields[part].to previous_table = current_table return select_related, table_prefix, model_cls @@ -128,7 +128,7 @@ class QueryClause: clause_text = str( clause.compile( - dialect=self.model_cls.__database__._backend._dialect, + dialect=self.model_cls.Meta.database._backend._dialect, compile_kwargs={"literal_binds": True}, ) ) diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 202b249..55e936a 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -4,8 +4,8 @@ import sqlalchemy from sqlalchemy import text import ormar # noqa I100 -from ormar import ForeignKey from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -34,7 +34,7 @@ class Query: self.filter_clauses = filter_clauses self.model_cls = model_cls - self.table = self.model_cls.__table__ + self.table = self.model_cls.Meta.table self.auto_related = [] self.used_aliases = [] @@ -46,18 +46,18 @@ class Query: def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: self.columns = list(self.table.columns) - self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] + self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")] self.select_from = self.table - for key in self.model_cls.__model_fields__: - if ( - not self.model_cls.__model_fields__[key].nullable - and isinstance( - self.model_cls.__model_fields__[key], ormar.fields.ForeignKey, - ) - and key not in self._select_related - ): - self._select_related = [key] + self._select_related + # for key in self.model_cls.Meta.model_fields: + # if ( + # not self.model_cls.Meta.model_fields[key].nullable + # and isinstance( + # self.model_cls.Meta.model_fields[key], ForeignKeyField, + # ) + # and key not in self._select_related + # ): + # self._select_related = [key] + self._select_related start_params = JoinParameters( self.model_cls, "", self.table.name, self.model_cls @@ -97,12 +97,12 @@ class Query: @staticmethod def _field_is_a_foreign_key_and_no_circular_reference( - field: BaseField, field_name: str, rel_part: str + field: Type[BaseField], field_name: str, rel_part: str ) -> bool: - return isinstance(field, ForeignKey) and field_name not in rel_part + return issubclass(field, ForeignKeyField) and field_name not in rel_part def _field_qualifies_to_deeper_search( - self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str + self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str ) -> bool: prev_part_of_related = "__".join(rel_part.split("__")[:-1]) partial_match = any( @@ -126,25 +126,26 @@ class Query: def _build_join_parameters( self, part: str, join_params: JoinParameters ) -> JoinParameters: - model_cls = join_params.model_cls.__model_fields__[part].to - to_table = model_cls.__table__.name + model_cls = join_params.model_cls.Meta.model_fields[part].to + to_table = model_cls.Meta.table.name - alias = model_cls._orm_relationship_manager.resolve_relation_join( + alias = model_cls.Meta._orm_relationship_manager.resolve_relation_join( join_params.from_table, to_table ) if alias not in self.used_aliases: - if join_params.prev_model.__model_fields__[part].virtual: + if join_params.prev_model.Meta.model_fields[part].virtual: to_key = next( ( v - for k, v in model_cls.__model_fields__.items() - if isinstance(v, ForeignKey) and v.to == join_params.prev_model + for k, v in model_cls.Meta.model_fields.items() + if issubclass(v, ForeignKeyField) + and v.to == join_params.prev_model ), None, ).name - from_key = model_cls.__pkname__ + from_key = model_cls.Meta.pkname else: - to_key = model_cls.__pkname__ + to_key = model_cls.Meta.pkname from_key = part on_clause = self.on_clause( @@ -157,8 +158,8 @@ class Query: self.select_from = sqlalchemy.sql.outerjoin( self.select_from, target_table, on_clause ) - self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}")) - self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) + self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}")) + self.columns.extend(self.prefixed_columns(alias, model_cls.Meta.table)) self.used_aliases.append(alias) previous_alias = alias @@ -173,18 +174,24 @@ class Query: nested: bool = False, parent_virtual: bool = False, ) -> None: - for field_name, field in prev_model.__model_fields__.items(): + for field_name, field in prev_model.Meta.model_fields.items(): if self._field_is_a_foreign_key_and_no_circular_reference( field, field_name, rel_part ): rel_part = field_name if not rel_part else rel_part + "__" + field_name if not field.nullable: if rel_part not in self._select_related: - self.auto_related.append("__".join(rel_part.split("__")[:-1])) + new_related = ( + "__".join(rel_part.split("__")[:-1]) + if len(rel_part.split("__")) > 1 + else rel_part + ) + self.auto_related.append(new_related) rel_part = "" elif self._field_qualifies_to_deeper_search( field, parent_virtual, nested, rel_part ): + self._extract_auto_required_relations( prev_model=field.to, rel_part=rel_part, diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 65192b8..56d1ee2 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -33,11 +33,11 @@ class QuerySet: @property def database(self) -> databases.Database: - return self.model_cls.__database__ + return self.model_cls.Meta.database @property def table(self) -> sqlalchemy.Table: - return self.model_cls.__table__ + return self.model_cls.Meta.table def build_select_expression(self) -> sqlalchemy.sql.select: qry = Query( @@ -148,8 +148,8 @@ class QuerySet: new_kwargs = dict(**kwargs) # Remove primary key when None to prevent not null constraint in postgresql. - pkname = self.model_cls.__pkname__ - pk = self.model_cls.__model_fields__[pkname] + pkname = self.model_cls.Meta.pkname + pk = self.model_cls.Meta.model_fields[pkname] if ( pkname in new_kwargs and new_kwargs.get(pkname) is None @@ -163,11 +163,11 @@ class QuerySet: if isinstance(new_kwargs.get(field), ormar.Model): new_kwargs[field] = getattr( new_kwargs.get(field), - self.model_cls.__model_fields__[field].to.__pkname__, + self.model_cls.Meta.model_fields[field].to.Meta.pkname, ) else: new_kwargs[field] = new_kwargs.get(field).get( - self.model_cls.__model_fields__[field].to.__pkname__ + self.model_cls.Meta.model_fields[field].to.Meta.pkname ) # Build the insert expression. @@ -176,5 +176,6 @@ class QuerySet: # Execute the insert, and return a new model instance. instance = self.model_cls(**kwargs) - instance.pk = await self.database.execute(expr) + pk = await self.database.execute(expr) + setattr(instance, self.model_cls.Meta.pkname, pk) return instance diff --git a/ormar/relations.py b/ormar/relations.py index 45c3ddd..008e566 100644 --- a/ormar/relations.py +++ b/ormar/relations.py @@ -5,7 +5,7 @@ from random import choices from typing import List, TYPE_CHECKING, Union from weakref import proxy -from ormar import ForeignKey +from ormar.fields.foreign_key import ForeignKeyField if TYPE_CHECKING: # pragma no cover from ormar.models import FakePydantic, Model @@ -21,14 +21,18 @@ class RelationshipManager: self._aliases = dict() def add_relation_type( - self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str + self, + relations_key: str, + reverse_key: str, + field: ForeignKeyField, + table_name: str, ) -> None: if relations_key not in self._relations: self._relations[relations_key] = {"type": "primary"} - self._aliases[f"{table_name}_{field.to.__tablename__}"] = get_table_alias() + self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias() if reverse_key not in self._relations: self._relations[reverse_key] = {"type": "reverse"} - self._aliases[f"{field.to.__tablename__}_{table_name}"] = get_table_alias() + self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias() def deregister(self, model: "FakePydantic") -> None: for rel_type in self._relations.keys(): diff --git a/scripts/publish.sh b/scripts/publish.sh old mode 100644 new mode 100755 diff --git a/tests/test_columns.py b/tests/test_columns.py index e75cb0a..c8c9d3b 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -16,17 +16,19 @@ def time(): class Example(ormar.Model): - __tablename__ = "example" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "example" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - created = ormar.DateTime(default=datetime.datetime.now) - created_day = ormar.Date(default=datetime.date.today) - created_time = ormar.Time(default=time) - description = ormar.Text(nullable=True) - value = ormar.Float(nullable=True) - data = ormar.JSON(default={}) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=200, default='aaa') + created: ormar.DateTime(default=datetime.datetime.now) + created_day: ormar.Date(default=datetime.date.today) + created_time: ormar.Time(default=time) + description: ormar.Text(nullable=True) + value: ormar.Float(nullable=True) + data: ormar.JSON(default={}) @pytest.fixture(autouse=True, scope="module") diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index 25b1310..f7f2625 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -13,22 +13,24 @@ metadata = sqlalchemy.MetaData() class Category(ormar.Model): - __tablename__ = "categories" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "categories" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Item(ormar.Model): - __tablename__ = "items" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "items" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + category: ormar.ForeignKey(Category, nullable=True) @app.post("/items/", response_model=Item) diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index bffa783..6463eb6 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -1,9 +1,11 @@ import databases import pytest import sqlalchemy +from pydantic import ValidationError import ormar from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError +from ormar.fields.foreign_key import ForeignKeyField from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -11,62 +13,68 @@ metadata = sqlalchemy.MetaData() class Album(ormar.Model): - __tablename__ = "album" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "albums" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Track(ormar.Model): - __tablename__ = "track" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "tracks" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - album = ormar.ForeignKey(Album) - title = ormar.String(length=100) - position = ormar.Integer() + id: ormar.Integer(primary_key=True) + album: ormar.ForeignKey(Album) + title: ormar.String(max_length=100) + position: ormar.Integer() class Cover(ormar.Model): - __tablename__ = "covers" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "covers" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - album = ormar.ForeignKey(Album, related_name="cover_pictures") - title = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + album: ormar.ForeignKey(Album, related_name="cover_pictures") + title: ormar.String(max_length=100) class Organisation(ormar.Model): - __tablename__ = "org" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "org" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - ident = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + ident: ormar.String(max_length=100) class Team(ormar.Model): - __tablename__ = "team" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "teams" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - org = ormar.ForeignKey(Organisation) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + org: ormar.ForeignKey(Organisation) + name: ormar.String(max_length=100) class Member(ormar.Model): - __tablename__ = "member" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "members" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - team = ormar.ForeignKey(Team) - email = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + team: ormar.ForeignKey(Team) + email: ormar.String(max_length=100) @pytest.fixture(autouse=True, scope="module") @@ -113,6 +121,7 @@ async def test_model_crud(): track = await Track.objects.get(title="The Bird") assert track.album.pk == album.pk + assert isinstance(track.album, ormar.Model) assert track.album.name is None await track.album.load() assert track.album.name == "Malibu" @@ -124,6 +133,8 @@ async def test_model_crud(): assert album1.pk == 1 assert album1.tracks is None + await Track.objects.create(album={"id": track.album.pk}, title="The Bird2", position=4) + @pytest.mark.asyncio async def test_select_related(): @@ -171,8 +182,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -180,8 +191,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -223,8 +234,8 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: @@ -243,8 +254,8 @@ async def test_pk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1 diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index f6dc722..ebb0619 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -1,4 +1,5 @@ import datetime +import decimal import pydantic import pytest @@ -12,19 +13,21 @@ metadata = sqlalchemy.MetaData() class ExampleModel(Model): - __tablename__ = "example" - __metadata__ = metadata - test = fields.Integer(primary_key=True) - test_string = fields.String(length=250) - test_text = fields.Text(default="") - test_bool = fields.Boolean(nullable=False) - test_float = fields.Float() - test_datetime = fields.DateTime(default=datetime.datetime.now) - test_date = fields.Date(default=datetime.date.today) - test_time = fields.Time(default=datetime.time) - test_json = fields.JSON(default={}) - test_bigint = fields.BigInteger(default=0) - test_decimal = fields.Decimal(length=10, precision=2) + class Meta: + tablename = "example" + metadata = metadata + + test: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250) + test_text: fields.Text(default="") + test_bool: fields.Boolean(nullable=False) + test_float: fields.Float() = None + test_datetime: fields.DateTime(default=datetime.datetime.now) + test_date: fields.Date(default=datetime.date.today) + test_time: fields.Time(default=datetime.time) + test_json: fields.JSON(default={}) + test_bigint: fields.BigInteger(default=0) + test_decimal: fields.Decimal(scale=10, precision=2) fields_to_check = [ @@ -41,15 +44,17 @@ fields_to_check = [ class ExampleModel2(Model): - __tablename__ = "example2" - __metadata__ = metadata - test = fields.Integer(primary_key=True) - test_string = fields.String(length=250) + class Meta: + tablename = "example2" + metadata = metadata + + test: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250) @pytest.fixture() def example(): - return ExampleModel(pk=1, test_string="test", test_bool=True) + return ExampleModel(pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5)) def test_not_nullable_field_is_required(): @@ -70,8 +75,18 @@ def test_model_attribute_access(example): example.test = 12 assert example.test == 12 - example.new_attr = 12 - assert "new_attr" in example.__dict__ + example._orm_saved = True + assert example._orm_saved + + +def test_model_attribute_json_access(example): + example.test_json = dict(aa=12) + assert example.test_json == dict(aa=12) + + +def test_non_existing_attr(example): + with pytest.raises(ValueError): + example.new_attr = 12 def test_primary_key_access_and_setting(example): @@ -83,60 +98,65 @@ def test_primary_key_access_and_setting(example): def test_pydantic_model_is_created(example): - assert issubclass(example.values.__class__, pydantic.BaseModel) - assert all([field in example.values.__fields__ for field in fields_to_check]) - assert example.values.test == 1 + assert issubclass(example.__class__, pydantic.BaseModel) + assert all([field in example.__fields__ for field in fields_to_check]) + assert example.test == 1 def test_sqlalchemy_table_is_created(example): - assert issubclass(example.__table__.__class__, sqlalchemy.Table) - assert all([field in example.__table__.columns for field in fields_to_check]) + assert issubclass(example.Meta.table.__class__, sqlalchemy.Table) + assert all([field in example.Meta.table.columns for field in fields_to_check]) def test_no_pk_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example3" - __metadata__ = metadata - test_string = fields.String(length=250) + class Meta: + tablename = "example3" + metadata = metadata + + test_string: fields.String(max_length=250) def test_two_pks_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example3" - __metadata__ = metadata - id = fields.Integer(primary_key=True) - test_string = fields.String(length=250, primary_key=True) + class Meta: + tablename = "example3" + metadata = metadata + + id: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250, primary_key=True) def test_setting_pk_column_as_pydantic_only_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example4" - __metadata__ = metadata - test = fields.Integer(primary_key=True, pydantic_only=True) + class Meta: + tablename = "example4" + metadata = metadata + + test: fields.Integer(primary_key=True, pydantic_only=True) def test_decimal_error_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example4" - __metadata__ = metadata - test = fields.Decimal(primary_key=True) + class Meta: + tablename = "example5" + metadata = metadata + + test: fields.Decimal(primary_key=True) def test_string_error_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example4" - __metadata__ = metadata - test = fields.String(primary_key=True) + class Meta: + tablename = "example6" + metadata = metadata + + test: fields.String(primary_key=True) def test_json_conversion_in_model(): diff --git a/tests/test_models.py b/tests/test_models.py index 69d421c..f21e70c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,4 +1,5 @@ import databases +import pydantic import pytest import sqlalchemy @@ -10,24 +11,36 @@ database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() -class User(ormar.Model): - __tablename__ = "users" - __metadata__ = metadata - __database__ = database +class JsonSample(ormar.Model): + class Meta: + tablename = "jsons" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + test_json: ormar.JSON(nullable=True) + + +class User(ormar.Model): + class Meta: + tablename = "users" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100, default='') class Product(ormar.Model): - __tablename__ = "product" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "product" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - rating = ormar.Integer(minimum=1, maximum=5) - in_stock = ormar.Boolean(default=False) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + rating: ormar.Integer(minimum=1, maximum=5) + in_stock: ormar.Boolean(default=False) @pytest.fixture(autouse=True, scope="module") @@ -39,12 +52,12 @@ def create_test_database(): def test_model_class(): - assert list(User.__model_fields__.keys()) == ["id", "name"] - assert isinstance(User.__model_fields__["id"], ormar.Integer) - assert User.__model_fields__["id"].primary_key is True - assert isinstance(User.__model_fields__["name"], ormar.String) - assert User.__model_fields__["name"].length == 100 - assert isinstance(User.__table__, sqlalchemy.Table) + assert list(User.Meta.model_fields.keys()) == ["id", "name"] + assert issubclass(User.Meta.model_fields["id"], pydantic.ConstrainedInt) + assert User.Meta.model_fields["id"].primary_key is True + assert issubclass(User.Meta.model_fields["name"], pydantic.ConstrainedStr) + assert User.Meta.model_fields["name"].max_length == 100 + assert isinstance(User.Meta.table, sqlalchemy.Table) def test_model_pk(): @@ -53,6 +66,18 @@ def test_model_pk(): assert user.id == 1 +@pytest.mark.asyncio +async def test_json_column(): + async with database: + await JsonSample.objects.create(test_json=dict(aa=12)) + await JsonSample.objects.create(test_json='{"aa": 12}') + + items = await JsonSample.objects.all() + assert len(items) == 2 + assert items[0].test_json == dict(aa=12) + assert items[1].test_json == dict(aa=12) + + @pytest.mark.asyncio async def test_model_crud(): async with database: diff --git a/tests/test_more_reallife_fastapi.py b/tests/test_more_reallife_fastapi.py index d7670cc..31e31b9 100644 --- a/tests/test_more_reallife_fastapi.py +++ b/tests/test_more_reallife_fastapi.py @@ -30,22 +30,24 @@ async def shutdown() -> None: class Category(ormar.Model): - __tablename__ = "categories" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "categories" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Item(ormar.Model): - __tablename__ = "items" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "items" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + category: ormar.ForeignKey(Category, nullable=True) @pytest.fixture(autouse=True, scope="module") @@ -59,19 +61,19 @@ def create_test_database(): @app.get("/items/", response_model=List[Item]) async def get_items(): items = await Item.objects.select_related("category").all() - return [item.dict() for item in items] + return items @app.post("/items/", response_model=Item) async def create_item(item: Item): - item = await Item.objects.create(**item.dict()) - return item.dict() + await item.save() + return item @app.post("/categories/", response_model=Category) async def create_category(category: Category): - category = await Category.objects.create(**category.dict()) - return category.dict() + await category.save() + return category @app.put("/items/{item_id}") diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 60ddddc..13e2185 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -12,53 +12,58 @@ metadata = sqlalchemy.MetaData() class Department(ormar.Model): - __tablename__ = "departments" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "departments" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True, autoincrement=False) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True, autoincrement=False) + name: ormar.String(max_length=100) class SchoolClass(ormar.Model): - __tablename__ = "schoolclasses" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "schoolclasses" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - department = ormar.ForeignKey(Department, nullable=False) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + department: ormar.ForeignKey(Department, nullable=False) class Category(ormar.Model): - __tablename__ = "categories" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "categories" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Student(ormar.Model): - __tablename__ = "students" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "students" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - schoolclass = ormar.ForeignKey(SchoolClass) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + schoolclass: ormar.ForeignKey(SchoolClass) + category: ormar.ForeignKey(Category, nullable=True) class Teacher(ormar.Model): - __tablename__ = "teachers" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "teachers" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - schoolclass = ormar.ForeignKey(SchoolClass) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + schoolclass: ormar.ForeignKey(SchoolClass) + category: ormar.ForeignKey(Category, nullable=True) @pytest.fixture(scope="module") @@ -71,6 +76,7 @@ def event_loop(): @pytest.fixture(autouse=True, scope="module") async def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) metadata.create_all(engine) department = await Department.objects.create(id=1, name="Math Department") class1 = await SchoolClass.objects.create(name="Math", department=department)