From 4c1acc09ea3c1f3e04ad378f8fbb88818da8fdb9 Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 11 Dec 2020 16:28:30 +0100 Subject: [PATCH] add test for select_related with limit --- ormar/models/metaclass.py | 103 +++++++++++++---------- tests/test_select_related_with_limit.py | 107 ++++++++++++++++++++++++ 2 files changed, 165 insertions(+), 45 deletions(-) create mode 100644 tests/test_select_related_with_limit.py diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 879221e..d6fded6 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -1,6 +1,17 @@ import logging import warnings -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Type, Union, cast +from typing import ( + Any, + Dict, + List, + Optional, + Set, + TYPE_CHECKING, + Tuple, + Type, + Union, + cast, +) import databases import pydantic @@ -50,7 +61,7 @@ def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> def register_many_to_many_relation_on_build( - table_name: str, field: Type[ManyToManyField] + table_name: str, field: Type[ManyToManyField] ) -> None: alias_manager.add_relation_type(field.through.Meta.tablename, table_name) alias_manager.add_relation_type( @@ -59,11 +70,11 @@ def register_many_to_many_relation_on_build( def reverse_field_not_already_registered( - child: Type["Model"], child_model_name: str, parent_model: Type["Model"] + child: Type["Model"], child_model_name: str, parent_model: Type["Model"] ) -> bool: return ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ) @@ -74,7 +85,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if reverse_field_not_already_registered( - child, child_model_name, parent_model + child, child_model_name, parent_model ): register_reverse_model_fields( parent_model, child, child_model_name, model_field @@ -82,10 +93,10 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: def register_reverse_model_fields( - model: Type["Model"], - child: Type["Model"], - child_model_name: str, - model_field: Type["ForeignKeyField"], + model: Type["Model"], + child: Type["Model"], + child_model_name: str, + model_field: Type["ForeignKeyField"], ) -> None: if issubclass(model_field, ManyToManyField): model.Meta.model_fields[child_model_name] = ManyToMany( @@ -100,7 +111,7 @@ def register_reverse_model_fields( def adjust_through_many_to_many_model( - model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( model, real_name=model.get_name(), ondelete="CASCADE" @@ -117,7 +128,7 @@ def adjust_through_many_to_many_model( def create_pydantic_field( - field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] + field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.__fields__[field_name] = ModelField( name=field_name, @@ -139,7 +150,7 @@ def get_pydantic_field(field_name: str, model: Type["Model"]) -> "ModelField": def create_and_append_m2m_fk( - model: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: column = sqlalchemy.Column( model.get_name(), @@ -155,7 +166,7 @@ def create_and_append_m2m_fk( def check_pk_column_validity( - field_name: str, field: BaseField, pkname: Optional[str] + field_name: str, field: BaseField, pkname: Optional[str] ) -> Optional[str]: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") @@ -165,7 +176,7 @@ def check_pk_column_validity( def sqlalchemy_columns_from_model_fields( - model_fields: Dict, table_name: str + model_fields: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column]]: columns = [] pkname = None @@ -179,9 +190,9 @@ def sqlalchemy_columns_from_model_fields( if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) if ( - not field.pydantic_only - and not field.virtual - and not issubclass(field, ManyToManyField) + not field.pydantic_only + and not field.virtual + and not issubclass(field, ManyToManyField) ): columns.append(field.get_column(field.get_alias())) register_relation_in_alias_manager(table_name, field) @@ -189,7 +200,7 @@ def sqlalchemy_columns_from_model_fields( def register_relation_in_alias_manager( - table_name: str, field: Type[ForeignKeyField] + table_name: str, field: Type[ForeignKeyField] ) -> None: if issubclass(field, ManyToManyField): register_many_to_many_relation_on_build(table_name, field) @@ -198,7 +209,7 @@ def register_relation_in_alias_manager( def populate_default_pydantic_field_value( - ormar_field: Type[BaseField], field_name: str, attrs: dict + ormar_field: Type[BaseField], field_name: str, attrs: dict ) -> dict: curr_def_value = attrs.get(field_name, ormar.Undefined) if lenient_issubclass(curr_def_value, ormar.fields.BaseField): @@ -243,7 +254,7 @@ def extract_annotations_and_default_vals(attrs: dict) -> Tuple[Dict, Dict]: def populate_meta_tablename_columns_and_pk( - name: str, new_model: Type["Model"] + name: str, new_model: Type["Model"] ) -> Type["Model"]: tablename = name.lower() + "s" new_model.Meta.tablename = ( @@ -269,7 +280,7 @@ def populate_meta_tablename_columns_and_pk( def populate_meta_sqlalchemy_table_if_required( - new_model: Type["Model"], + new_model: Type["Model"], ) -> Type["Model"]: """ Constructs sqlalchemy table out of columns and parameters set on Meta class. @@ -360,7 +371,7 @@ def populate_choices_validators(model: Type["Model"]) -> None: # noqa CCR001 def populate_default_options_values( - new_model: Type["Model"], model_fields: Dict + new_model: Type["Model"], model_fields: Dict ) -> None: """ Sets all optional Meta values to it's defaults @@ -479,11 +490,11 @@ def get_potential_fields(attrs: Dict) -> Dict: def check_conflicting_fields( - new_fields: Set, - attrs: Dict, - base_class: type, - curr_class: type, - previous_fields: Set = None, + new_fields: Set, + attrs: Dict, + base_class: type, + curr_class: type, + previous_fields: Set = None, ) -> None: """ You cannot redefine fields with same names in inherited classes. @@ -513,11 +524,11 @@ def check_conflicting_fields( def update_attrs_and_fields( - attrs: Dict, - new_attrs: Dict, - model_fields: Dict, - new_model_fields: Dict, - new_fields: Set, + attrs: Dict, + new_attrs: Dict, + model_fields: Dict, + new_model_fields: Dict, + new_fields: Set, ) -> None: """ Updates __annotations__, values of model fields (so pydantic FieldInfos) @@ -540,9 +551,7 @@ def update_attrs_and_fields( model_fields.update(new_model_fields) -def update_attrs_from_base_meta( - base_class: "Model", - attrs: Dict, ) -> None: +def update_attrs_from_base_meta(base_class: "Model", attrs: Dict,) -> None: """ Updates Meta parameters in child from parent if needed. @@ -564,18 +573,20 @@ def update_attrs_from_base_meta( curr_value.union(getattr(base_class.Meta, param)) else: # overwrite with child value if both set and its param / object - setattr(attrs["Meta"], param, getattr(base_class.Meta, param)) # pragma: no cover + setattr( + attrs["Meta"], param, getattr(base_class.Meta, param) + ) # pragma: no cover else: setattr(attrs["Meta"], param, getattr(base_class.Meta, param)) def extract_mixin_fields_from_dict( - base_class: type, - curr_class: type, - attrs: Dict, - model_fields: Dict[ - str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] - ], + base_class: type, + curr_class: type, + attrs: Dict, + model_fields: Dict[ + str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] + ], ) -> Tuple[Dict, Dict]: """ Extracts fields from base classes if they have valid oramr fields. @@ -603,7 +614,9 @@ def extract_mixin_fields_from_dict( if hasattr(base_class, "Meta"): if attrs.get("Meta"): new_fields = set(base_class.Meta.model_fields.keys()) # type: ignore - previous_fields = set({k for k, v in attrs.items() if isinstance(v, FieldInfo)}) + previous_fields = set( + {k for k, v in attrs.items() if isinstance(v, FieldInfo)} + ) check_conflicting_fields( new_fields=new_fields, attrs=attrs, @@ -675,7 +688,7 @@ def extract_mixin_fields_from_dict( class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__( # type: ignore - mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict + mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict ) -> "ModelMetaclass": attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name diff --git a/tests/test_select_related_with_limit.py b/tests/test_select_related_with_limit.py new file mode 100644 index 0000000..27e4460 --- /dev/null +++ b/tests/test_select_related_with_limit.py @@ -0,0 +1,107 @@ +from typing import List, Optional + +import databases +import sqlalchemy +from sqlalchemy import create_engine + +import ormar +import pytest + +from tests.settings import DATABASE_URL + +db = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Keyword(ormar.Model): + class Meta: + metadata = metadata + database = db + tablename = "keywords" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=50) + + +class KeywordPrimaryModel(ormar.Model): + class Meta: + metadata = metadata + database = db + tablename = "primary_models_keywords" + + id: int = ormar.Integer(primary_key=True) + + +class PrimaryModel(ormar.Model): + class Meta: + metadata = metadata + database = db + tablename = "primary_models" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=255, index=True) + some_text: str = ormar.Text() + some_other_text: Optional[str] = ormar.Text(nullable=True) + keywords: Optional[List[Keyword]] = ormar.ManyToMany( + Keyword, through=KeywordPrimaryModel + ) + + +class SecondaryModel(ormar.Model): + class Meta: + metadata = metadata + database = db + tablename = "secondary_models" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + primary_model: PrimaryModel = ormar.ForeignKey( + PrimaryModel, + related_name="secondary_models", + ) + + +@pytest.mark.asyncio +async def test_create_primary_models(): + async with db: + for name, some_text, some_other_text in [ + ("Primary 1", "Some text 1", "Some other text 1"), + ("Primary 2", "Some text 2", "Some other text 2"), + ("Primary 3", "Some text 3", "Some other text 3"), + ("Primary 4", "Some text 4", "Some other text 4"), + ("Primary 5", "Some text 5", "Some other text 5"), + ("Primary 6", "Some text 6", "Some other text 6"), + ("Primary 7", "Some text 7", "Some other text 7"), + ("Primary 8", "Some text 8", "Some other text 8"), + ("Primary 9", "Some text 9", "Some other text 9"), + ("Primary 10", "Some text 10", "Some other text 10")]: + await PrimaryModel( + name=name, some_text=some_text, some_other_text=some_other_text + ).save() + + for tag_id in [1, 2, 3, 4, 5]: + await Keyword.objects.create(name=f"Tag {tag_id}") + + p1 = await PrimaryModel.objects.get(pk=1) + p2 = await PrimaryModel.objects.get(pk=2) + for i in range(1, 6): + keyword = await Keyword.objects.get(pk=i) + if i % 2 == 0: + await p1.keywords.add(keyword) + else: + await p2.keywords.add(keyword) + models = await PrimaryModel.objects.prefetch_related("keywords").limit(5).all() + + # This test fails, because of the keywords relation. + assert len(models) == 5 + assert len(models[0].keywords) == 2 + assert len(models[1].keywords) == 3 + assert len(models[2].keywords) == 0 + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine)