fix tests

This commit is contained in:
collerek
2021-03-15 10:37:55 +01:00
parent 6d0a5477cd
commit 5c633d32a8
5 changed files with 100 additions and 55 deletions

View File

@ -1,8 +1,6 @@
import uuid import uuid
from typing import Dict, Optional, Set, TYPE_CHECKING from typing import Dict, Optional, Set, TYPE_CHECKING
import pydantic
import ormar import ormar
from ormar.exceptions import ModelPersistenceError from ormar.exceptions import ModelPersistenceError
from ormar.models.helpers.validation import validate_choices from ormar.models.helpers.validation import validate_choices
@ -53,7 +51,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
pkname = cls.Meta.pkname pkname = cls.Meta.pkname
pk = cls.Meta.model_fields[pkname] pk = cls.Meta.model_fields[pkname]
if new_kwargs.get(pkname, ormar.Undefined) is None and ( if new_kwargs.get(pkname, ormar.Undefined) is None and (
pk.nullable or pk.autoincrement pk.nullable or pk.autoincrement
): ):
del new_kwargs[pkname] del new_kwargs[pkname]
return new_kwargs return new_kwargs
@ -71,10 +69,10 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
""" """
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if field.__type__ == uuid.UUID and name in model_dict: if field.__type__ == uuid.UUID and name in model_dict:
if field.column_type.uuid_format == "string": parsers = {"string": lambda x: str(x), "hex": lambda x: "%.32x" % x.int}
model_dict[name] = str(model_dict[name]) uuid_format = field.column_type.uuid_format
else: parser = parsers.get(uuid_format, lambda x: x)
model_dict[name] = "%.32x" % model_dict[name].int model_dict[name] = parser(model_dict[name])
return model_dict return model_dict
@classmethod @classmethod
@ -126,9 +124,9 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
""" """
for field_name, field in cls.Meta.model_fields.items(): for field_name, field in cls.Meta.model_fields.items():
if ( if (
field_name not in new_kwargs field_name not in new_kwargs
and field.has_default(use_server=False) and field.has_default(use_server=False)
and not field.pydantic_only and not field.pydantic_only
): ):
new_kwargs[field_name] = field.get_default() new_kwargs[field_name] = field.get_default()
# clear fields with server_default set as None # clear fields with server_default set as None

View File

@ -838,7 +838,7 @@ class QuerySet:
model = await self.get(pk=kwargs[pk_name]) model = await self.get(pk=kwargs[pk_name])
return await model.update(**kwargs) return await model.update(**kwargs)
async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003 async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003
""" """
Returns all rows from a database for given model for set filter options. Returns all rows from a database for given model for set filter options.

View File

@ -127,7 +127,6 @@ class QuerysetProxy:
f"model without primary key set! \n" f"model without primary key set! \n"
f"Save the child model first." f"Save the child model first."
) )
print('final kwargs', final_kwargs)
await model_cls(**final_kwargs).save() await model_cls(**final_kwargs).save()
async def update_through_instance(self, child: "Model", **kwargs: Any) -> None: async def update_through_instance(self, child: "Model", **kwargs: Any) -> None:
@ -144,7 +143,6 @@ class QuerysetProxy:
child_column = self.related_field.default_source_field_name() # type: ignore child_column = self.related_field.default_source_field_name() # type: ignore
rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk} rel_kwargs = {owner_column: self._owner.pk, child_column: child.pk}
through_model = await model_cls.objects.get(**rel_kwargs) through_model = await model_cls.objects.get(**rel_kwargs)
print('update kwargs', kwargs)
await through_model.update(**kwargs) await through_model.update(**kwargs)
async def delete_through_instance(self, child: "Model") -> None: async def delete_through_instance(self, child: "Model") -> None:

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Type from typing import Any, Dict, List, Type, cast
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import databases import databases
@ -7,7 +7,7 @@ import sqlalchemy
import ormar import ormar
from ormar import ModelDefinitionError, Model, QuerySet, pre_update from ormar import ModelDefinitionError, Model, QuerySet, pre_update
from ormar import pre_save, pre_relation_add from ormar import pre_save
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL) database = databases.Database(DATABASE_URL)
@ -82,81 +82,133 @@ async def test_ordering_by_through_fail():
await alice.load_all() await alice.load_all()
def get_filtered_query( def _get_filtered_query(
sender: Type[Model], instance: Model, to_class: Type[Model] sender: Type[Model], instance: Model, to_class: Type[Model]
) -> QuerySet: ) -> QuerySet:
"""
Helper function.
Gets the query filtered by the appropriate class name.
"""
pk = getattr(instance, f"{to_class.get_name()}").pk pk = getattr(instance, f"{to_class.get_name()}").pk
filter_kwargs = {f"{to_class.get_name()}": pk} filter_kwargs = {f"{to_class.get_name()}": pk}
query = sender.objects.filter(**filter_kwargs) query = sender.objects.filter(**filter_kwargs)
return query return query
async def populate_order_on_insert( async def _populate_order_on_insert(
sender: Type[Model], instance: Model, from_class: Type[Model], sender: Type[Model], instance: Model, from_class: Type[Model], to_class: Type[Model]
to_class: Type[Model]
): ):
"""
Helper function.
Get max values from database for both orders and adds 1 (0 if max is None) if the
order is not provided. If the order is provided it reorders the existing links
to match the newly defined order.
Assumes names f"{model.get_name()}_order" like for Animal: animal_order.
"""
order_column = f"{from_class.get_name()}_order" order_column = f"{from_class.get_name()}_order"
if getattr(instance, order_column) is None: if getattr(instance, order_column) is None:
query = get_filtered_query(sender, instance, to_class) query = _get_filtered_query(sender, instance, to_class)
max_order = await query.max(order_column) max_order = await query.max(order_column)
max_order = max_order + 1 if max_order is not None else 0 max_order = max_order + 1 if max_order is not None else 0
setattr(instance, order_column, max_order) setattr(instance, order_column, max_order)
else: else:
await reorder_on_update(sender, instance, from_class, to_class, await _reorder_on_update(
passed_args={ sender=sender,
order_column: getattr(instance, order_column)}) instance=instance,
from_class=from_class,
to_class=to_class,
passed_args={order_column: getattr(instance, order_column)},
)
async def reorder_on_update( async def _reorder_on_update(
sender: Type[Model], instance: Model, from_class: Type[Model], sender: Type[Model],
to_class: Type[Model], passed_args: Dict instance: Model,
from_class: Type[Model],
to_class: Type[Model],
passed_args: Dict,
): ):
"""
Helper function.
Actually reorders links by given order passed in add/update query to the link
model.
Assumes names f"{model.get_name()}_order" like for Animal: animal_order.
"""
order = f"{from_class.get_name()}_order" order = f"{from_class.get_name()}_order"
if order in passed_args: if order in passed_args:
query = get_filtered_query(sender, instance, to_class) query = _get_filtered_query(sender, instance, to_class)
to_reorder = await query.exclude(pk=instance.pk).order_by(order).all() to_reorder = await query.exclude(pk=instance.pk).order_by(order).all()
old_order = getattr(instance, order)
new_order = passed_args.get(order) new_order = passed_args.get(order)
if to_reorder: if to_reorder and new_order is not None:
for link in to_reorder: # can be more efficient - here we renumber all even if not needed.
setattr(link, order, getattr(link, order) + 1) for ind, link in enumerate(to_reorder):
await sender.objects.bulk_update(to_reorder, columns=[order]) if ind < new_order:
check = await get_filtered_query(sender, instance, to_class).all() setattr(link, order, ind)
print('reordered', check) else:
setattr(link, order, ind + 1)
await sender.objects.bulk_update(
cast(List[Model], to_reorder), columns=[order]
)
@pre_save(Link) @pre_save(Link)
async def order_link_on_insert(sender: Type[Model], instance: Model, **kwargs: Any): async def order_link_on_insert(sender: Type[Model], instance: Model, **kwargs: Any):
"""
Signal receiver registered on Link model, triggered every time before one is created
by calling save() on a model. Note that signal functions for pre_save signal accepts
sender class, instance and have to accept **kwargs even if it's empty as of now.
"""
relations = list(instance.extract_related_names()) relations = list(instance.extract_related_names())
rel_one = sender.Meta.model_fields[relations[0]].to rel_one = sender.Meta.model_fields[relations[0]].to
rel_two = sender.Meta.model_fields[relations[1]].to rel_two = sender.Meta.model_fields[relations[1]].to
await populate_order_on_insert(sender, instance, from_class=rel_one, await _populate_order_on_insert(
to_class=rel_two) sender=sender, instance=instance, from_class=rel_one, to_class=rel_two
await populate_order_on_insert(sender, instance, from_class=rel_two, )
to_class=rel_one) await _populate_order_on_insert(
sender=sender, instance=instance, from_class=rel_two, to_class=rel_one
)
@pre_update(Link) @pre_update(Link)
async def reorder_links_on_update( async def reorder_links_on_update(
sender: Type[ormar.Model], instance: ormar.Model, passed_args: Dict, sender: Type[ormar.Model], instance: ormar.Model, passed_args: Dict, **kwargs: Any
**kwargs: Any
): ):
"""
Signal receiver registered on Link model, triggered every time before one is updated
by calling update() on a model. Note that signal functions for pre_update signal
accepts sender class, instance, passed_args which is a dict of kwargs passed to
update and have to accept **kwargs even if it's empty as of now.
"""
relations = list(instance.extract_related_names()) relations = list(instance.extract_related_names())
rel_one = sender.Meta.model_fields[relations[0]].to rel_one = sender.Meta.model_fields[relations[0]].to
rel_two = sender.Meta.model_fields[relations[1]].to rel_two = sender.Meta.model_fields[relations[1]].to
await reorder_on_update(sender, instance, from_class=rel_one, to_class=rel_two, await _reorder_on_update(
passed_args=passed_args) sender=sender,
await reorder_on_update(sender, instance, from_class=rel_two, to_class=rel_one, instance=instance,
passed_args=passed_args) from_class=rel_one,
to_class=rel_two,
passed_args=passed_args,
)
await _reorder_on_update(
sender=sender,
instance=instance,
from_class=rel_two,
to_class=rel_one,
passed_args=passed_args,
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ordering_by_through_on_m2m_field(): async def test_ordering_by_through_on_m2m_field():
async with database: async with database:
def verify_order(instance, expected): def verify_order(instance, expected):
field_name = ( field_name = (
"favoriteAnimals" if isinstance(instance, "favoriteAnimals" if isinstance(instance, Human) else "favoriteHumans"
Human) else "favoriteHumans"
) )
assert [x.name for x in getattr(instance, field_name)] == expected assert [x.name for x in getattr(instance, field_name)] == expected
@ -196,25 +248,22 @@ async def test_ordering_by_through_on_m2m_field():
zack = await Human(name="Zack").save() zack = await Human(name="Zack").save()
await noodle.favoriteHumans.add(zack, animal_order=0, human_order=0) await noodle.favoriteHumans.add(zack, human_order=0)
await noodle.load_all() await noodle.load_all()
verify_order(noodle, ["Zack", "Alice", "Bob", "Charlie"]) verify_order(noodle, ["Zack", "Alice", "Bob", "Charlie"])
await zack.load_all() await zack.load_all()
verify_order(zack, ["Noodle"]) verify_order(zack, ["Noodle"])
await noodle.favoriteHumans.filter(name='Zack').update( await noodle.favoriteHumans.filter(name="Zack").update(link=dict(human_order=1))
link=dict(human_order=1))
await noodle.load_all() await noodle.load_all()
verify_order(noodle, ["Alice", "Zack", "Bob", "Charlie"]) verify_order(noodle, ["Alice", "Zack", "Bob", "Charlie"])
await noodle.favoriteHumans.filter(name='Zack').update( await noodle.favoriteHumans.filter(name="Zack").update(link=dict(human_order=2))
link=dict(human_order=2))
await noodle.load_all() await noodle.load_all()
verify_order(noodle, ["Alice", "Bob", "Zack", "Charlie"]) verify_order(noodle, ["Alice", "Bob", "Zack", "Charlie"])
await noodle.favoriteHumans.filter(name='Zack').update( await noodle.favoriteHumans.filter(name="Zack").update(link=dict(human_order=3))
link=dict(human_order=3))
await noodle.load_all() await noodle.load_all()
verify_order(noodle, ["Alice", "Bob", "Charlie", "Zack"]) verify_order(noodle, ["Alice", "Bob", "Charlie", "Zack"])

View File

@ -161,7 +161,7 @@ async def test_only_one_side_has_through() -> Any:
assert post2.categories[0].postcategory is not None assert post2.categories[0].postcategory is not None
categories = await Category.objects.select_related("posts").all() categories = await Category.objects.select_related("posts").all()
categories = cast(Sequence[Category], categories) assert isinstance(categories[0], Category)
assert categories[0].postcategory is None assert categories[0].postcategory is None
assert categories[0].posts[0].postcategory is not None assert categories[0].posts[0].postcategory is not None