diff --git a/.coverage b/.coverage index ed0ddb4..a7dc6e2 100644 Binary files a/.coverage and b/.coverage differ diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 2448fde..9d052a7 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -23,7 +23,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": def ForeignKey( - to: "Model", + to: Type["Model"], *, name: str = None, unique: bool = False, diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index d3f41cf..dd81838 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -1,11 +1,9 @@ import datetime import decimal -import re from typing import Any, Optional, Type import pydantic import sqlalchemy -from pydantic import Json from ormar import ModelDefinitionError # noqa I101 from ormar.fields.base import BaseField # noqa I101 @@ -19,359 +17,253 @@ def is_field_nullable( return nullable -def String( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - allow_blank: bool = False, - strip_whitespace: bool = False, - min_length: int = None, - max_length: int = None, - curtail_length: int = None, - regex: str = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[str]: - if max_length is None or max_length <= 0: - raise ModelDefinitionError("Parameter max_length is required for field String") +class ModelFieldFactory: + _bases = None + _type = None - namespace = dict( - __type__=str, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - allow_blank=allow_blank, - strip_whitespace=strip_whitespace, - min_length=min_length, - max_length=max_length, - curtail_length=curtail_length, - regex=regex and re.compile(regex), - column_type=sqlalchemy.String(length=max_length), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) + def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: + cls.validate(**kwargs) - return type("String", (pydantic.ConstrainedStr, BaseField), namespace) + default = kwargs.pop("default", None) + server_default = kwargs.pop("server_default", None) + nullable = kwargs.pop("nullable", None) - -def Integer( - *, - name: str = None, - primary_key: bool = False, - autoincrement: bool = None, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: int = None, - maximum: int = None, - multiple_of: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[int]: - namespace = dict( - __type__=int, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - ge=minimum, - le=maximum, - multiple_of=multiple_of, - column_type=sqlalchemy.Integer(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=autoincrement if autoincrement is not None else primary_key, - ) - return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace) - - -def Text( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - allow_blank: bool = False, - strip_whitespace: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[str]: - namespace = dict( - __type__=str, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - allow_blank=allow_blank, - strip_whitespace=strip_whitespace, - column_type=sqlalchemy.Text(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - - return type("Text", (pydantic.ConstrainedStr, BaseField), namespace) - - -def Float( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: float = None, - maximum: float = None, - multiple_of: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[int]: - namespace = dict( - __type__=float, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - ge=minimum, - le=maximum, - multiple_of=multiple_of, - column_type=sqlalchemy.Float(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace) - - -def Boolean( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[bool]: - namespace = dict( - __type__=bool, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - column_type=sqlalchemy.Boolean(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - return type("Boolean", (int, BaseField), namespace) - - -def DateTime( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[datetime.datetime]: - namespace = dict( - __type__=datetime.datetime, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - column_type=sqlalchemy.DateTime(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - return type("DateTime", (datetime.datetime, BaseField), namespace) - - -def Date( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[datetime.date]: - namespace = dict( - __type__=datetime.date, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - column_type=sqlalchemy.Date(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - return type("Date", (datetime.date, BaseField), namespace) - - -def Time( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[datetime.time]: - namespace = dict( - __type__=datetime.time, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - column_type=sqlalchemy.Time(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - return type("Time", (datetime.time, BaseField), namespace) - - -def JSON( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[Json]: - namespace = dict( - __type__=pydantic.Json, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - column_type=sqlalchemy.JSON(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - - return type("JSON", (pydantic.Json, BaseField), namespace) - - -def BigInteger( - *, - name: str = None, - primary_key: bool = False, - autoincrement: bool = None, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: int = None, - maximum: int = None, - multiple_of: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[int]: - namespace = dict( - __type__=int, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - ge=minimum, - le=maximum, - multiple_of=multiple_of, - column_type=sqlalchemy.BigInteger(), - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=autoincrement if autoincrement is not None else primary_key, - ) - return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace) - - -def Decimal( - *, - name: str = None, - primary_key: bool = False, - nullable: bool = None, - index: bool = False, - unique: bool = False, - minimum: float = None, - maximum: float = None, - multiple_of: int = None, - precision: int = None, - scale: int = None, - max_digits: int = None, - decimal_places: int = None, - pydantic_only: bool = False, - default: Any = None, - server_default: Any = None, -) -> Type[decimal.Decimal]: - if precision is None or precision < 0 or scale is None or scale < 0: - raise ModelDefinitionError( - "Parameters scale and precision are required for field Decimal" + namespace = dict( + __type__=cls._type, + name=kwargs.pop("name", None), + primary_key=kwargs.pop("primary_key", False), + default=default, + server_default=server_default, + nullable=is_field_nullable(nullable, default, server_default), + index=kwargs.pop("index", False), + unique=kwargs.pop("unique", False), + pydantic_only=kwargs.pop("pydantic_only", False), + autoincrement=kwargs.pop("autoincrement", False), + column_type=cls.get_column_type(**kwargs), + **kwargs ) + return type(cls.__name__, cls._bases, namespace) - namespace = dict( - __type__=decimal.Decimal, - name=name, - primary_key=primary_key, - nullable=is_field_nullable(nullable, default, server_default), - index=index, - unique=unique, - ge=minimum, - le=maximum, - multiple_of=multiple_of, - column_type=sqlalchemy.types.DECIMAL(precision=precision, scale=scale), - precision=precision, - scale=scale, - max_digits=max_digits, - decimal_places=decimal_places, - pydantic_only=pydantic_only, - default=default, - server_default=server_default, - autoincrement=False, - ) - return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace) + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: # pragma no cover + return None + + @classmethod + def validate(cls, **kwargs: Any) -> None: # pragma no cover + pass + + +class String(ModelFieldFactory): + _bases = (pydantic.ConstrainedStr, BaseField) + _type = str + + def __new__( + cls, + *, + allow_blank: bool = False, + strip_whitespace: bool = False, + min_length: int = None, + max_length: int = None, + curtail_length: int = None, + regex: str = None, + **kwargs: Any + ) -> Type[str]: + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.String(length=kwargs.get("max_length")) + + @classmethod + def validate(cls, **kwargs: Any) -> None: + max_length = kwargs.get("max_length", None) + if max_length is None or max_length <= 0: + raise ModelDefinitionError( + "Parameter max_length is required for field String" + ) + + +class Integer(ModelFieldFactory): + _bases = (pydantic.ConstrainedInt, BaseField) + _type = int + + def __new__( + cls, + *, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + **kwargs: Any + ) -> Type[int]: + autoincrement = kwargs.pop("autoincrement", None) + autoincrement = ( + autoincrement + if autoincrement is not None + else kwargs.get("primary_key", False) + ) + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.Integer() + + +class Text(ModelFieldFactory): + _bases = (pydantic.ConstrainedStr, BaseField) + _type = str + + def __new__( + cls, *, allow_blank: bool = False, strip_whitespace: bool = False, **kwargs: Any + ) -> Type[str]: + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.Text() + + +class Float(ModelFieldFactory): + _bases = (pydantic.ConstrainedFloat, BaseField) + _type = float + + def __new__( + cls, + *, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + **kwargs: Any + ) -> Type[int]: + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.Float() + + +class Boolean(ModelFieldFactory): + _bases = (int, BaseField) + _type = bool + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.Boolean() + + +class DateTime(ModelFieldFactory): + _bases = (datetime.datetime, BaseField) + _type = datetime.datetime + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.DateTime() + + +class Date(ModelFieldFactory): + _bases = (datetime.date, BaseField) + _type = datetime.date + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.Date() + + +class Time(ModelFieldFactory): + _bases = (datetime.time, BaseField) + _type = datetime.time + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.Time() + + +class JSON(ModelFieldFactory): + _bases = (pydantic.Json, BaseField) + _type = pydantic.Json + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.JSON() + + +class BigInteger(Integer): + _bases = (pydantic.ConstrainedInt, BaseField) + _type = int + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy.BigInteger() + + +class Decimal(ModelFieldFactory): + _bases = (pydantic.ConstrainedDecimal, BaseField) + _type = decimal.Decimal + + def __new__( + cls, + *, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + precision: int = None, + scale: int = None, + max_digits: int = None, + decimal_places: int = None, + **kwargs: Any + ) -> Type[decimal.Decimal]: + kwargs = { + **kwargs, + **{ + k: v + for k, v in locals().items() + if k not in ["cls", "__class__", "kwargs"] + }, + } + return super().__new__(cls, **kwargs) + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + precision = kwargs.get("precision") + scale = kwargs.get("scale") + return sqlalchemy.DECIMAL(precision=precision, scale=scale) + + @classmethod + def validate(cls, **kwargs: Any) -> None: + precision = kwargs.get("precision") + scale = kwargs.get("scale") + if precision is None or precision < 0 or scale is None or scale < 0: + raise ModelDefinitionError( + "Parameters scale and precision are required for field Decimal" + ) diff --git a/ormar/models/fakepydantic.py b/ormar/models/fakepydantic.py index d6958fd..3b74c9d 100644 --- a/ormar/models/fakepydantic.py +++ b/ormar/models/fakepydantic.py @@ -240,10 +240,9 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): k: v for k, v in self_fields.items() if k in self.Meta.table.columns } for field in self._extract_related_names(): + target_pk_name = self.Meta.model_fields[field].to.Meta.pkname if getattr(self, field) is not None: - self_fields[field] = getattr( - getattr(self, field), self.Meta.model_fields[field].to.Meta.pkname - ) + self_fields[field] = getattr(getattr(self, field), target_pk_name) return self_fields @classmethod @@ -259,17 +258,18 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): @classmethod def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": for field in one.Meta.model_fields.keys(): - if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), ormar.Model + current_field = getattr(one, field) + if isinstance(current_field, list) and not isinstance( + current_field, ormar.Model ): - setattr(other, field, getattr(one, field) + getattr(other, field)) - elif isinstance(getattr(one, field), ormar.Model): - if getattr(one, field).pk == getattr(other, field).pk: - setattr( - other, - field, - cls.merge_two_instances( - getattr(one, field), getattr(other, field) - ), - ) + setattr(other, field, current_field + getattr(other, field)) + elif ( + isinstance(current_field, ormar.Model) + and current_field.pk == getattr(other, field).pk + ): + setattr( + other, + field, + cls.merge_two_instances(current_field, getattr(other, field)), + ) return other diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 166c943..13e2185 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -76,6 +76,7 @@ def event_loop(): @pytest.fixture(autouse=True, scope="module") async def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) metadata.create_all(engine) department = await Department.objects.create(id=1, name="Math Department") class1 = await SchoolClass.objects.create(name="Math", department=department)