fix qryset fields, model update, model delete, model load, qryset update, qruset filter, qryset bulk_load, qryset bulk_update

This commit is contained in:
collerek
2020-10-21 12:14:14 +02:00
parent 64fd9f3cce
commit dbca4367e8
7 changed files with 79 additions and 14 deletions

BIN
.coverage

Binary file not shown.

View File

@ -43,7 +43,6 @@ class Model(NewBaseModel):
if select_related: if select_related:
related_models = group_related_list(select_related) related_models = group_related_list(select_related)
# breakpoint()
if ( if (
previous_table previous_table
and previous_table in cls.Meta.model_fields and previous_table in cls.Meta.model_fields
@ -145,7 +144,8 @@ class Model(NewBaseModel):
self.from_dict(new_values) self.from_dict(new_values)
self_fields = self._extract_model_db_fields() self_fields = self._extract_model_db_fields()
self_fields.pop(self.Meta.pkname) self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname))
self_fields = self.objects._translate_columns_to_aliases(self_fields)
expr = self.Meta.table.update().values(**self_fields) expr = self.Meta.table.update().values(**self_fields)
expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname)) expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname))
@ -165,5 +165,7 @@ class Model(NewBaseModel):
raise ValueError( raise ValueError(
"Instance was deleted from database and cannot be refreshed" "Instance was deleted from database and cannot be refreshed"
) )
self.from_dict(dict(row)) kwargs = dict(row)
kwargs = self.objects._translate_aliases_to_columns(kwargs)
self.from_dict(kwargs)
return self return self

View File

@ -3,7 +3,7 @@ from typing import Dict, List, Set, TYPE_CHECKING, Type, TypeVar, Union
import ormar import ormar
from ormar.exceptions import RelationshipInstanceError from ormar.exceptions import RelationshipInstanceError
from ormar.fields import BaseField from ormar.fields import BaseField, ManyToManyField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta from ormar.models.metaclass import ModelMeta
@ -35,7 +35,7 @@ class ModelTableProxy:
return self_fields return self_fields
@classmethod @classmethod
def substitute_models_with_pks(cls, model_dict: Dict) -> Dict: def substitute_models_with_pks(cls, model_dict: Dict) -> Dict: # noqa CCR001
for field in cls.extract_related_names(): for field in cls.extract_related_names():
field_value = model_dict.get(field, None) field_value = model_dict.get(field, None)
if field_value is not None: if field_value is not None:
@ -43,8 +43,10 @@ class ModelTableProxy:
target_pkname = target_field.to.Meta.pkname target_pkname = target_field.to.Meta.pkname
if isinstance(field_value, ormar.Model): if isinstance(field_value, ormar.Model):
model_dict[field] = getattr(field_value, target_pkname) model_dict[field] = getattr(field_value, target_pkname)
else: elif field_value:
model_dict[field] = field_value.get(target_pkname) model_dict[field] = field_value.get(target_pkname)
else:
model_dict.pop(field, None)
return model_dict return model_dict
@classmethod @classmethod
@ -76,6 +78,7 @@ class ModelTableProxy:
if ( if (
inspect.isclass(field) inspect.isclass(field)
and issubclass(field, ForeignKeyField) and issubclass(field, ForeignKeyField)
and not issubclass(field, ManyToManyField)
and not field.virtual and not field.virtual
): ):
related_names.add(name) related_names.add(name)
@ -98,7 +101,9 @@ class ModelTableProxy:
def _extract_model_db_fields(self) -> Dict: def _extract_model_db_fields(self) -> Dict:
self_fields = self._extract_own_model_fields() self_fields = self._extract_own_model_fields()
self_fields = { self_fields = {
k: v for k, v in self_fields.items() if k in self.Meta.table.columns k: v
for k, v in self_fields.items()
if self.get_column_alias(k) in self.Meta.table.columns
} }
for field in self._extract_db_related_names(): for field in self._extract_db_related_names():
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
@ -139,7 +144,7 @@ class ModelTableProxy:
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
merged_rows: List["Model"] = [] merged_rows: List["Model"] = []
for index, model in enumerate(result_rows): for index, model in enumerate(result_rows):
if index > 0 and model.pk == merged_rows[-1].pk: if index > 0 and model is not None and model.pk == merged_rows[-1].pk:
merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
else: else:
merged_rows.append(model) merged_rows.append(model)
@ -179,6 +184,7 @@ class ModelTableProxy:
return column_names return column_names
if not nested: if not nested:
fields = [model.get_column_alias(k) if not use_alias else k for k in fields]
columns = [ columns = [
name for name in fields if "__" not in name and name in column_names name for name in fields if "__" not in name and name in column_names
] ]
@ -189,6 +195,9 @@ class ModelTableProxy:
for name in fields for name in fields
if f"{model.get_name()}__" in name if f"{model.get_name()}__" in name
] ]
columns = [
model.get_column_alias(k) if not use_alias else k for k in columns
]
# if the model is in select and no columns in fields, all implied # if the model is in select and no columns in fields, all implied
if not columns: if not columns:
@ -197,7 +206,7 @@ class ModelTableProxy:
# always has to return pk column # always has to return pk column
pk_alias = ( pk_alias = (
model.get_column_alias(model.Meta.pkname) model.get_column_alias(model.Meta.pkname)
if use_alias if not use_alias
else model.Meta.pkname else model.Meta.pkname
) )
if pk_alias not in columns: if pk_alias not in columns:

View File

@ -83,7 +83,7 @@ class QueryClause:
else: else:
op = "exact" op = "exact"
column = self.table.columns[key] column = self.table.columns[self.model_cls.get_column_alias(key)]
table = self.table table = self.table
clause = self._process_column_clause_for_operator_and_value( clause = self._process_column_clause_for_operator_and_value(

View File

@ -110,7 +110,7 @@ class SqlJoin:
pkname_alias = model_cls.get_column_alias(model_cls.Meta.pkname) pkname_alias = model_cls.get_column_alias(model_cls.Meta.pkname)
self.order_bys.append(text(f"{alias}_{to_table}.{pkname_alias}")) self.order_bys.append(text(f"{alias}_{to_table}.{pkname_alias}"))
self_related_fields = model_cls.own_table_columns( self_related_fields = model_cls.own_table_columns(
model_cls, self.fields, nested=True model_cls, self.fields, nested=True,
) )
self.columns.extend( self.columns.extend(
self.relation_manager(model_cls).prefixed_columns( self.relation_manager(model_cls).prefixed_columns(

View File

@ -93,6 +93,12 @@ class QuerySet:
new_kwargs[field.name] = new_kwargs.pop(field_name) new_kwargs[field.name] = new_kwargs.pop(field_name)
return new_kwargs return new_kwargs
def _translate_aliases_to_columns(self, new_kwargs: dict) -> dict:
for field_name, field in self.model_meta.model_fields.items():
if field.name in new_kwargs and field.name != field_name:
new_kwargs[field_name] = new_kwargs.pop(field.name)
return new_kwargs
def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict: def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict:
pkname = self.model_meta.pkname pkname = self.model_meta.pkname
pk = self.model_meta.model_fields[pkname] pk = self.model_meta.model_fields[pkname]
@ -201,6 +207,7 @@ class QuerySet:
async def update(self, each: bool = False, **kwargs: Any) -> int: async def update(self, each: bool = False, **kwargs: Any) -> int:
self_fields = self.model.extract_db_own_fields() self_fields = self.model.extract_db_own_fields()
updates = {k: v for k, v in kwargs.items() if k in self_fields} updates = {k: v for k, v in kwargs.items() if k in self_fields}
updates = self._translate_columns_to_aliases(updates)
if not each and not self.filter_clauses: if not each and not self.filter_clauses:
raise QueryDefinitionError( raise QueryDefinitionError(
"You cannot update without filtering the queryset first. " "You cannot update without filtering the queryset first. "
@ -336,6 +343,8 @@ class QuerySet:
if pk_name not in columns: if pk_name not in columns:
columns.append(pk_name) columns.append(pk_name)
columns = [self.model.get_column_alias(k) for k in columns]
for objt in objects: for objt in objects:
new_kwargs = objt.dict() new_kwargs = objt.dict()
if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None: if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None:
@ -344,13 +353,22 @@ class QuerySet:
f"{self.model.__name__} has to have {pk_name} filled." f"{self.model.__name__} has to have {pk_name} filled."
) )
new_kwargs = self.model.substitute_models_with_pks(new_kwargs) new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._translate_columns_to_aliases(new_kwargs)
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns} new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
ready_objects.append(new_kwargs) ready_objects.append(new_kwargs)
pk_column = self.model_meta.table.c.get(pk_name) pk_column = self.model_meta.table.c.get(pk_name)
expr = self.table.update().where(pk_column == bindparam("new_" + pk_name)) pk_column_name = self.model.get_column_alias(pk_name)
table_columns = [c.name for c in self.model_meta.table.c]
expr = self.table.update().where(
pk_column == bindparam("new_" + pk_column_name)
)
expr = expr.values( expr = expr.values(
**{k: bindparam("new_" + k) for k in columns if k != pk_name} **{
k: bindparam("new_" + k)
for k in columns
if k != pk_column_name and k in table_columns
}
) )
# databases bind params only where query is passed as string # databases bind params only where query is passed as string
# otherwise it just pases all data to values and results in unconsumed columns # otherwise it just pases all data to values and results in unconsumed columns

View File

@ -18,7 +18,7 @@ class Child(ormar.Model):
id: ormar.Integer(name='child_id', primary_key=True) id: ormar.Integer(name='child_id', primary_key=True)
first_name: ormar.String(name='fname', max_length=100) first_name: ormar.String(name='fname', max_length=100)
last_name: ormar.String(name='lname', max_length=100) last_name: ormar.String(name='lname', max_length=100)
born_year: ormar.Integer(name='year_born') born_year: ormar.Integer(name='year_born', nullable=True)
class ArtistChildren(ormar.Model): class ArtistChildren(ormar.Model):
@ -93,3 +93,39 @@ async def test_working_with_aliases():
assert artist.children[0].first_name == 'Son' assert artist.children[0].first_name == 'Son'
assert artist.children[1].last_name == '2' assert artist.children[1].last_name == '2'
await artist.update(last_name='Bundy')
await Artist.objects.filter(pk=artist.pk).update(born_year=1974)
artist = await Artist.objects.select_related('children').get()
assert artist.last_name == 'Bundy'
assert artist.born_year == 1974
artist = await Artist.objects.select_related('children').fields(
['first_name', 'last_name', 'born_year', 'child__first_name', 'child__last_name']).get()
assert artist.children[0].born_year is None
@pytest.mark.asyncio
async def test_bulk_operations_and_fields():
async with database:
d1 = Child(first_name='Daughter', last_name='1', born_year=1990)
d2 = Child(first_name='Daughter', last_name='2', born_year=1991)
await Child.objects.bulk_create([d1, d2])
children = await Child.objects.filter(first_name='Daughter').all()
assert len(children) == 2
assert children[0].last_name == '1'
for child in children:
child.born_year = child.born_year - 100
await Child.objects.bulk_update(children)
children = await Child.objects.fields(['first_name', 'last_name']).all()
assert len(children) == 2
for child in children:
assert child.born_year is None
await children[0].load()
await children[0].delete()
children = await Child.objects.all()