add tests for creation from dictionaries and for m2m relations

This commit is contained in:
collerek
2021-10-09 17:19:17 +02:00
parent 6d2712c0f8
commit 4896a3a982
2 changed files with 51 additions and 28 deletions

View File

@ -266,7 +266,6 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
self._convert_json( self._convert_json(
k, k,
model_fields[k].expand_relationship(v, self, to_register=False) model_fields[k].expand_relationship(v, self, to_register=False)
if k in model_fields if k in model_fields
else (v if k in pydantic_fields else model_fields[k]), else (v if k in pydantic_fields else model_fields[k]),
), ),
@ -797,7 +796,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
own_values = { own_values = {
k: v for k, v in values.items() if k not in cls.extract_related_names() k: v for k, v in values.items() if k not in cls.extract_related_names()
} }
m = cls.__new__(cls) model = cls.__new__(cls)
fields_values: Dict[str, Any] = {} fields_values: Dict[str, Any] = {}
for name, field in cls.__fields__.items(): for name, field in cls.__fields__.items():
if name in own_values: if name in own_values:
@ -805,33 +804,41 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
elif not field.required: elif not field.required:
fields_values[name] = field.get_default() fields_values[name] = field.get_default()
fields_values.update(own_values) fields_values.update(own_values)
object.__setattr__(m, "__dict__", fields_values) object.__setattr__(model, "__dict__", fields_values)
m._initialize_internal_attributes() model._initialize_internal_attributes()
for relation in cls.extract_related_names(): cls._construct_relations(model=model, values=values)
if relation in values: if _fields_set is None:
relation_field = cls.Meta.model_fields[relation] _fields_set = set(values.keys())
if isinstance(values[relation], list): object.__setattr__(model, "__fields_set__", _fields_set)
relation_value = [ return model
relation_field.to.construct(**x) if isinstance(x, dict) else x
for x in values[relation] @classmethod
def _construct_relations(cls: Type["T"], model: "T", values: Dict):
present_relations = [
relation for relation in cls.extract_related_names() if relation in values
] ]
else: for relation in present_relations:
value_to_set = values[relation]
if not isinstance(value_to_set, list):
value_to_set = [value_to_set]
relation_field = cls.Meta.model_fields[relation]
relation_value = [ relation_value = [
relation_field.to.construct(**values[relation]) cls.construct_from_dict_if_required(relation_field, value=x)
if isinstance(values[relation], dict) for x in value_to_set
else values[relation]
] ]
for child in relation_value: for child in relation_value:
m._orm.add( model._orm.add(
parent=child, parent=child,
child=m, child=model,
field=relation_field, field=relation_field,
) )
if _fields_set is None:
_fields_set = set(values.keys()) @staticmethod
object.__setattr__(m, "__fields_set__", _fields_set) def construct_from_dict_if_required(relation_field: "BaseField", value: Any):
return m return (
relation_field.to.construct(**value) if isinstance(value, dict) else value
)
def update_from_dict(self, value_dict: Dict) -> "NewBaseModel": def update_from_dict(self, value_dict: Dict) -> "NewBaseModel":
""" """

View File

@ -19,7 +19,6 @@ class NickNames(ormar.Model):
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name") name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
is_lame: bool = ormar.Boolean(nullable=True)
class NicksHq(ormar.Model): class NicksHq(ormar.Model):
@ -69,3 +68,20 @@ async def test_init_and_construct_has_same_effect():
comp = Company(name="Banzai", hq=hq, founded=1988) comp = Company(name="Banzai", hq=hq, founded=1988)
comp2 = Company.construct(**dict(name="Banzai", hq=hq, founded=1988)) comp2 = Company.construct(**dict(name="Banzai", hq=hq, founded=1988))
assert comp.dict() == comp2.dict() assert comp.dict() == comp2.dict()
comp3 = Company.construct(**dict(name="Banzai", hq=hq.dict(), founded=1988))
assert comp.dict() == comp3.dict()
@pytest.mark.asyncio
async def test_init_and_construct_has_same_effect_with_m2m():
async with database:
async with database.transaction(force_rollback=True):
n1 = await NickNames(name="test").save()
n2 = await NickNames(name="test2").save()
hq = HQ(name="Main", nicks=[n1, n2])
hq2 = HQ.construct(**dict(name="Main", nicks=[n1, n2]))
assert hq.dict() == hq2.dict()
hq3 = HQ.construct(**dict(name="Main", nicks=[n1.dict(), n2.dict()]))
assert hq.dict() == hq3.dict()