From afa1756b472615cf0424fd409cf6cf200300e3e5 Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 5 Oct 2021 18:50:02 +0200 Subject: [PATCH] very initial verson of construct --- ormar/models/newbasemodel.py | 64 +++++++++++++++-- .../test_model_construct.py | 71 +++++++++++++++++++ .../test_model_definition/test_save_status.py | 2 +- 3 files changed, 132 insertions(+), 5 deletions(-) create mode 100644 tests/test_model_definition/test_model_construct.py diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 6b4eccf..ce64fa0 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -14,6 +14,7 @@ from typing import ( TYPE_CHECKING, Tuple, Type, + TypeVar, Union, cast, ) @@ -50,8 +51,11 @@ if TYPE_CHECKING: # pragma no cover from ormar.models import Model from ormar.signals import SignalEmitter + T = TypeVar("T", bound="NewBaseModel") + IntStr = Union[int, str] DictStrAny = Dict[str, Any] + SetStr = Set[str] AbstractSetIntStr = AbstractSet[IntStr] MappingIntStrAny = Mapping[IntStr, Any] @@ -154,7 +158,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass # register the columns models after initialization for related in self.extract_related_names().union(self.extract_through_names()): model_fields[related].expand_relationship( - new_kwargs.get(related), self, to_register=True, + new_kwargs.get(related), + self, + to_register=True, ) if hasattr(self, "_init_private_attributes"): @@ -261,7 +267,11 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass k, self._convert_json( k, - model_fields[k].expand_relationship(v, self, to_register=False,) + model_fields[k].expand_relationship( + v, + self, + to_register=False, + ) if k in model_fields else (v if k in pydantic_fields else model_fields[k]), ), @@ -315,7 +325,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass self, "_orm", RelationsManager( - related_fields=self.extract_related_fields(), owner=cast("Model", self), + related_fields=self.extract_related_fields(), + owner=cast("Model", self), ), ) @@ -488,7 +499,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass @staticmethod def _get_not_excluded_fields( - fields: Union[List, Set], include: Optional[Dict], exclude: Optional[Dict], + fields: Union[List, Set], + include: Optional[Dict], + exclude: Optional[Dict], ) -> List: """ Returns related field names applying on them include and exclude set. @@ -785,6 +798,49 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass data = data["__root__"] return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs) + @classmethod + def construct( + cls: Type["T"], _fields_set: Optional["SetStr"] = None, **values: Any + ) -> "T": + own_values = { + k: v for k, v in values.items() if k not in cls.extract_related_names() + } + m = cls.__new__(cls) + fields_values: Dict[str, Any] = {} + for name, field in cls.__fields__.items(): + if name in own_values: + fields_values[name] = own_values[name] + elif not field.required: + fields_values[name] = field.get_default() + fields_values.update(own_values) + object.__setattr__(m, "__dict__", fields_values) + m._initialize_internal_attributes() + for relation in cls.extract_related_names(): + if relation in values: + relation_field = cls.Meta.model_fields[relation] + if isinstance(values[relation], list): + relation_value = [ + relation_field.to.construct(**x) if isinstance(x, dict) else x + for x in values[relation] + ] + else: + relation_value = [ + relation_field.to.construct(**values[relation]) + if isinstance(values[relation], dict) + else values[relation] + ] + + for child in relation_value: + m._orm.add( + parent=child, + child=m, + field=relation_field, + ) + if _fields_set is None: + _fields_set = set(values.keys()) + object.__setattr__(m, "__fields_set__", _fields_set) + return m + def update_from_dict(self, value_dict: Dict) -> "NewBaseModel": """ Updates self with values of fields passed in the dictionary. diff --git a/tests/test_model_definition/test_model_construct.py b/tests/test_model_definition/test_model_construct.py new file mode 100644 index 0000000..e27b768 --- /dev/null +++ b/tests/test_model_definition/test_model_construct.py @@ -0,0 +1,71 @@ +from typing import List + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class NickNames(ormar.Model): + class Meta: + tablename = "nicks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + is_lame: bool = ormar.Boolean(nullable=True) + + +class NicksHq(ormar.Model): + class Meta: + tablename = "nicks_x_hq" + metadata = metadata + database = database + + +class HQ(ormar.Model): + class Meta: + tablename = "hqs" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq) + + +class Company(ormar.Model): + class Meta: + tablename = "companies" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="company_name") + founded: int = ormar.Integer(nullable=True) + hq: HQ = ormar.ForeignKey(HQ) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_init_and_construct_has_same_effect(): + async with database: + async with database.transaction(force_rollback=True): + hq = await HQ.objects.create(name="Main") + comp = Company(name="Banzai", hq=hq, founded=1988) + comp2 = Company.construct(**dict(name="Banzai", hq=hq, founded=1988)) + assert comp.dict() == comp2.dict() diff --git a/tests/test_model_definition/test_save_status.py b/tests/test_model_definition/test_save_status.py index 93e89ac..9762810 100644 --- a/tests/test_model_definition/test_save_status.py +++ b/tests/test_model_definition/test_save_status.py @@ -63,7 +63,7 @@ def create_test_database(): @pytest.mark.asyncio -async def test_instantation_false_save_true(): +async def test_instantiation_false_save_true(): async with database: async with database.transaction(force_rollback=True): comp = Company(name="Banzai", founded=1988)