diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 924308d..ab105cd 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -266,7 +266,6 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass self._convert_json( k, model_fields[k].expand_relationship(v, self, to_register=False) - if k in model_fields else (v if k in pydantic_fields else model_fields[k]), ), @@ -797,7 +796,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass own_values = { k: v for k, v in values.items() if k not in cls.extract_related_names() } - m = cls.__new__(cls) + model = cls.__new__(cls) fields_values: Dict[str, Any] = {} for name, field in cls.__fields__.items(): if name in own_values: @@ -805,33 +804,41 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass elif not field.required: fields_values[name] = field.get_default() fields_values.update(own_values) - object.__setattr__(m, "__dict__", fields_values) - m._initialize_internal_attributes() - for relation in cls.extract_related_names(): - if relation in values: - relation_field = cls.Meta.model_fields[relation] - if isinstance(values[relation], list): - relation_value = [ - relation_field.to.construct(**x) if isinstance(x, dict) else x - for x in values[relation] - ] - else: - relation_value = [ - relation_field.to.construct(**values[relation]) - if isinstance(values[relation], dict) - else values[relation] - ] - - for child in relation_value: - m._orm.add( - parent=child, - child=m, - field=relation_field, - ) + object.__setattr__(model, "__dict__", fields_values) + model._initialize_internal_attributes() + cls._construct_relations(model=model, values=values) if _fields_set is None: _fields_set = set(values.keys()) - object.__setattr__(m, "__fields_set__", _fields_set) - return m + object.__setattr__(model, "__fields_set__", _fields_set) + return model + + @classmethod + def _construct_relations(cls: Type["T"], model: "T", values: Dict): + present_relations = [ + relation for relation in cls.extract_related_names() if relation in values + ] + for relation in present_relations: + value_to_set = values[relation] + if not isinstance(value_to_set, list): + value_to_set = [value_to_set] + relation_field = cls.Meta.model_fields[relation] + relation_value = [ + cls.construct_from_dict_if_required(relation_field, value=x) + for x in value_to_set + ] + + for child in relation_value: + model._orm.add( + parent=child, + child=model, + field=relation_field, + ) + + @staticmethod + def construct_from_dict_if_required(relation_field: "BaseField", value: Any): + return ( + relation_field.to.construct(**value) if isinstance(value, dict) else value + ) def update_from_dict(self, value_dict: Dict) -> "NewBaseModel": """ diff --git a/tests/test_model_definition/test_model_construct.py b/tests/test_model_definition/test_model_construct.py index e27b768..87bd4ba 100644 --- a/tests/test_model_definition/test_model_construct.py +++ b/tests/test_model_definition/test_model_construct.py @@ -19,7 +19,6 @@ class NickNames(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100, nullable=False, name="hq_name") - is_lame: bool = ormar.Boolean(nullable=True) class NicksHq(ormar.Model): @@ -69,3 +68,20 @@ async def test_init_and_construct_has_same_effect(): comp = Company(name="Banzai", hq=hq, founded=1988) comp2 = Company.construct(**dict(name="Banzai", hq=hq, founded=1988)) assert comp.dict() == comp2.dict() + + comp3 = Company.construct(**dict(name="Banzai", hq=hq.dict(), founded=1988)) + assert comp.dict() == comp3.dict() + + +@pytest.mark.asyncio +async def test_init_and_construct_has_same_effect_with_m2m(): + async with database: + async with database.transaction(force_rollback=True): + n1 = await NickNames(name="test").save() + n2 = await NickNames(name="test2").save() + hq = HQ(name="Main", nicks=[n1, n2]) + hq2 = HQ.construct(**dict(name="Main", nicks=[n1, n2])) + assert hq.dict() == hq2.dict() + + hq3 = HQ.construct(**dict(name="Main", nicks=[n1.dict(), n2.dict()])) + assert hq.dict() == hq3.dict()