From dbca4367e8a793264d10c0b5f0c6aa5c70920065 Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 21 Oct 2020 12:14:14 +0200 Subject: [PATCH] fix qryset fields, model update, model delete, model load, qryset update, qruset filter, qryset bulk_load, qryset bulk_update --- .coverage | Bin 53248 -> 53248 bytes ormar/models/model.py | 8 +++++--- ormar/models/modelproxy.py | 21 ++++++++++++++------ ormar/queryset/clause.py | 2 +- ormar/queryset/join.py | 2 +- ormar/queryset/queryset.py | 22 +++++++++++++++++++-- tests/test_aliases.py | 38 ++++++++++++++++++++++++++++++++++++- 7 files changed, 79 insertions(+), 14 deletions(-) diff --git a/.coverage b/.coverage index 0c5ad2c27ffd51e1fce496b83c5eb1ce6a6b9299..55690d1026e0b04eb51cb77f7ea3a791d917dc8f 100644 GIT binary patch delta 274 zcmV+t0qy>PpaX!Q1F$?V3owe#E+rIto{;TippWT1%KkxlnKkGmLxx0Jczq{J@_1k}}&-=5!jwk^b76JqT2`&O| Y^F5#M^PN9GzwLUv^Skv9v-yt)K#F&Hy#N3J delta 266 zcmV+l0rmcXpaX!Q1F$?V3o$exF*rIfI65#jvqvw#P+SrK5BU$~57-aQ55W(w50wvy z4_FUN4?_u5U>lg3XF;YGM_hRo9+LvtNpaE_TPK^9s?g9 z1OW*^9`;2h@9V$)et-Y=zrF8|;Sc@xd$0fPe9SwO$Bn%+G7|&=31$ Dict: + def substitute_models_with_pks(cls, model_dict: Dict) -> Dict: # noqa CCR001 for field in cls.extract_related_names(): field_value = model_dict.get(field, None) if field_value is not None: @@ -43,8 +43,10 @@ class ModelTableProxy: target_pkname = target_field.to.Meta.pkname if isinstance(field_value, ormar.Model): model_dict[field] = getattr(field_value, target_pkname) - else: + elif field_value: model_dict[field] = field_value.get(target_pkname) + else: + model_dict.pop(field, None) return model_dict @classmethod @@ -76,6 +78,7 @@ class ModelTableProxy: if ( inspect.isclass(field) and issubclass(field, ForeignKeyField) + and not issubclass(field, ManyToManyField) and not field.virtual ): related_names.add(name) @@ -98,7 +101,9 @@ class ModelTableProxy: def _extract_model_db_fields(self) -> Dict: self_fields = self._extract_own_model_fields() self_fields = { - k: v for k, v in self_fields.items() if k in self.Meta.table.columns + k: v + for k, v in self_fields.items() + if self.get_column_alias(k) in self.Meta.table.columns } for field in self._extract_db_related_names(): target_pk_name = self.Meta.model_fields[field].to.Meta.pkname @@ -139,7 +144,7 @@ class ModelTableProxy: def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: merged_rows: List["Model"] = [] for index, model in enumerate(result_rows): - if index > 0 and model.pk == merged_rows[-1].pk: + if index > 0 and model is not None and model.pk == merged_rows[-1].pk: merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) else: merged_rows.append(model) @@ -179,6 +184,7 @@ class ModelTableProxy: return column_names if not nested: + fields = [model.get_column_alias(k) if not use_alias else k for k in fields] columns = [ name for name in fields if "__" not in name and name in column_names ] @@ -189,6 +195,9 @@ class ModelTableProxy: for name in fields if f"{model.get_name()}__" in name ] + columns = [ + model.get_column_alias(k) if not use_alias else k for k in columns + ] # if the model is in select and no columns in fields, all implied if not columns: @@ -197,7 +206,7 @@ class ModelTableProxy: # always has to return pk column pk_alias = ( model.get_column_alias(model.Meta.pkname) - if use_alias + if not use_alias else model.Meta.pkname ) if pk_alias not in columns: diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index c5d3cd4..362ba85 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -83,7 +83,7 @@ class QueryClause: else: op = "exact" - column = self.table.columns[key] + column = self.table.columns[self.model_cls.get_column_alias(key)] table = self.table clause = self._process_column_clause_for_operator_and_value( diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index 42803ff..c045e16 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -110,7 +110,7 @@ class SqlJoin: pkname_alias = model_cls.get_column_alias(model_cls.Meta.pkname) self.order_bys.append(text(f"{alias}_{to_table}.{pkname_alias}")) self_related_fields = model_cls.own_table_columns( - model_cls, self.fields, nested=True + model_cls, self.fields, nested=True, ) self.columns.extend( self.relation_manager(model_cls).prefixed_columns( diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 676374e..4d421dd 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -93,6 +93,12 @@ class QuerySet: new_kwargs[field.name] = new_kwargs.pop(field_name) return new_kwargs + def _translate_aliases_to_columns(self, new_kwargs: dict) -> dict: + for field_name, field in self.model_meta.model_fields.items(): + if field.name in new_kwargs and field.name != field_name: + new_kwargs[field_name] = new_kwargs.pop(field.name) + return new_kwargs + def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict: pkname = self.model_meta.pkname pk = self.model_meta.model_fields[pkname] @@ -201,6 +207,7 @@ class QuerySet: async def update(self, each: bool = False, **kwargs: Any) -> int: self_fields = self.model.extract_db_own_fields() updates = {k: v for k, v in kwargs.items() if k in self_fields} + updates = self._translate_columns_to_aliases(updates) if not each and not self.filter_clauses: raise QueryDefinitionError( "You cannot update without filtering the queryset first. " @@ -336,6 +343,8 @@ class QuerySet: if pk_name not in columns: columns.append(pk_name) + columns = [self.model.get_column_alias(k) for k in columns] + for objt in objects: new_kwargs = objt.dict() if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None: @@ -344,13 +353,22 @@ class QuerySet: f"{self.model.__name__} has to have {pk_name} filled." ) new_kwargs = self.model.substitute_models_with_pks(new_kwargs) + new_kwargs = self._translate_columns_to_aliases(new_kwargs) new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns} ready_objects.append(new_kwargs) pk_column = self.model_meta.table.c.get(pk_name) - expr = self.table.update().where(pk_column == bindparam("new_" + pk_name)) + pk_column_name = self.model.get_column_alias(pk_name) + table_columns = [c.name for c in self.model_meta.table.c] + expr = self.table.update().where( + pk_column == bindparam("new_" + pk_column_name) + ) expr = expr.values( - **{k: bindparam("new_" + k) for k in columns if k != pk_name} + **{ + k: bindparam("new_" + k) + for k in columns + if k != pk_column_name and k in table_columns + } ) # databases bind params only where query is passed as string # otherwise it just pases all data to values and results in unconsumed columns diff --git a/tests/test_aliases.py b/tests/test_aliases.py index f8359d9..ab0ecda 100644 --- a/tests/test_aliases.py +++ b/tests/test_aliases.py @@ -18,7 +18,7 @@ class Child(ormar.Model): id: ormar.Integer(name='child_id', primary_key=True) first_name: ormar.String(name='fname', max_length=100) last_name: ormar.String(name='lname', max_length=100) - born_year: ormar.Integer(name='year_born') + born_year: ormar.Integer(name='year_born', nullable=True) class ArtistChildren(ormar.Model): @@ -93,3 +93,39 @@ async def test_working_with_aliases(): assert artist.children[0].first_name == 'Son' assert artist.children[1].last_name == '2' + await artist.update(last_name='Bundy') + await Artist.objects.filter(pk=artist.pk).update(born_year=1974) + + artist = await Artist.objects.select_related('children').get() + assert artist.last_name == 'Bundy' + assert artist.born_year == 1974 + + artist = await Artist.objects.select_related('children').fields( + ['first_name', 'last_name', 'born_year', 'child__first_name', 'child__last_name']).get() + assert artist.children[0].born_year is None + + +@pytest.mark.asyncio +async def test_bulk_operations_and_fields(): + async with database: + d1 = Child(first_name='Daughter', last_name='1', born_year=1990) + d2 = Child(first_name='Daughter', last_name='2', born_year=1991) + await Child.objects.bulk_create([d1, d2]) + + children = await Child.objects.filter(first_name='Daughter').all() + assert len(children) == 2 + assert children[0].last_name == '1' + + for child in children: + child.born_year = child.born_year - 100 + + await Child.objects.bulk_update(children) + + children = await Child.objects.fields(['first_name', 'last_name']).all() + assert len(children) == 2 + for child in children: + assert child.born_year is None + + await children[0].load() + await children[0].delete() + children = await Child.objects.all()