diff --git a/.coverage b/.coverage index 0de5a48..6db7a66 100644 Binary files a/.coverage and b/.coverage differ diff --git a/.flake8 b/.flake8 index 9976335..af173f2 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] -ignore = ANN101, ANN102, W503 +ignore = ANN101, ANN102, W503, S101 max-complexity = 8 max-line-length = 88 exclude = p38venv,.pytest_cache diff --git a/orm/exceptions.py b/orm/exceptions.py index cb2100e..40cfd26 100644 --- a/orm/exceptions.py +++ b/orm/exceptions.py @@ -18,5 +18,9 @@ class MultipleMatches(AsyncOrmException): pass +class QueryDefinitionError(AsyncOrmException): + pass + + class RelationshipInstanceError(AsyncOrmException): pass diff --git a/orm/models.py b/orm/models.py index e0aaa25..c16b21d 100644 --- a/orm/models.py +++ b/orm/models.py @@ -46,13 +46,11 @@ def sqlalchemy_columns_from_model_fields( if field.primary_key: pkname = field_name if isinstance(field, ForeignKey): - reverse_name = ( - field.related_name - or field.to.__name__.lower().title() + "_" + name.lower() + "s" - ) - relation_name = ( - name.lower().title() + "_" + field.to.__name__.lower() + child_relation_name = ( + field.to.get_name(title=True) + "_" + name.lower() + "s" ) + reverse_name = field.related_name or child_relation_name + relation_name = name.lower().title() + "_" + field.to.get_name() relationship_manager.add_relation_type( relation_name, reverse_name, field, tablename ) @@ -241,6 +239,15 @@ class Model(list, metaclass=ModelMetaclass): # def schema(cls, by_alias: bool = True): # pragma no cover # return cls.__pydantic_model__.schema(by_alias=by_alias) + @classmethod + def get_name(cls, title: bool = False, lower: bool = True) -> str: + name = cls.__name__ + if lower: + name = name.lower() + if title: + name = name.title() + return name + def is_conversion_to_json_needed(self, column_name: str) -> bool: return self.__model_fields__.get(column_name).__type__ == pydantic.Json @@ -256,7 +263,7 @@ class Model(list, metaclass=ModelMetaclass): def pk_column(self) -> sqlalchemy.Column: return self.__table__.primary_key.columns.values()[0] - def dict(self) -> Dict: + def dict(self) -> Dict: # noqa: A003 dict_instance = self.values.dict() for field in self.extract_related_names(): nested_model = getattr(self, field) diff --git a/orm/queryset.py b/orm/queryset.py index d66d50a..1b5a5dc 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -1,10 +1,20 @@ -from typing import Any, List, NamedTuple, TYPE_CHECKING, Tuple, Type, Union +from typing import ( + Any, + Dict, + List, + NamedTuple, + Optional, + TYPE_CHECKING, + Tuple, + Type, + Union, +) import databases import orm from orm import ForeignKey -from orm.exceptions import MultipleMatches, NoMatch +from orm.exceptions import MultipleMatches, NoMatch, QueryDefinitionError from orm.fields import BaseField import sqlalchemy @@ -80,18 +90,11 @@ class QuerySet: return text(f"{name} {alias}_{name}") def on_clause( - self, - from_table: str, - to_table: str, - previous_alias: str, - alias: str, - to_key: str, - from_key: str, + self, previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: - return text( - f"{alias}_{to_table}.{to_key}=" - f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}' - ) + left_part = f"{alias}_{to_clause}" + right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" + return text(f"{left_part}={right_part}") def _build_join_parameters( self, part: str, join_params: JoinParameters @@ -118,12 +121,10 @@ class QuerySet: from_key = part on_clause = self.on_clause( - join_params.from_table, - to_table, - join_params.previous_alias, - alias, - to_key, - from_key, + previous_alias=join_params.previous_alias, + alias=alias, + from_clause=f"{join_params.from_table}.{from_key}", + to_clause=f"{to_table}.{to_key}", ) target_table = self.prefixed_table_name(alias, to_table) self.select_from = sqlalchemy.sql.outerjoin( @@ -159,12 +160,12 @@ class QuerySet: def _extract_auto_required_relations( self, - join_params: JoinParameters, + prev_model: Type["Model"], rel_part: str = "", nested: bool = False, parent_virtual: bool = False, ) -> None: - for field_name, field in join_params.prev_model.__model_fields__.items(): + for field_name, field in prev_model.__model_fields__.items(): if self._field_is_a_foreign_key_and_no_circular_reference( field, field_name, rel_part ): @@ -176,14 +177,8 @@ class QuerySet: elif self._field_qualifies_to_deeper_search( field, parent_virtual, nested, rel_part ): - join_params = JoinParameters( - field.to, - join_params.previous_alias, - join_params.from_table, - join_params.prev_model, - ) self._extract_auto_required_relations( - join_params=join_params, + prev_model=field.to, rel_part=rel_part, nested=True, parent_virtual=field.virtual, @@ -244,7 +239,7 @@ class QuerySet: start_params = JoinParameters( self.model_cls, "", self.table.name, self.model_cls ) - self._extract_auto_required_relations(start_params) + self._extract_auto_required_relations(prev_model=start_params.prev_model) self._include_auto_related_models() self._select_related.sort(key=lambda item: (-len(item), item)) @@ -266,7 +261,90 @@ class QuerySet: return expr - def filter(self, **kwargs: Any) -> "QuerySet": + def _determine_filter_target_table( + self, related_parts: List[str], select_related: List[str] + ) -> Tuple[List[str], str, "Model"]: + + table_prefix = "" + model_cls = self.model_cls + select_related = [relation for relation in select_related] + + # Add any implied select_related + related_str = "__".join(related_parts) + if related_str not in select_related: + select_related.append(related_str) + + # Walk the relationships to the actual model class + # against which the comparison is being made. + previous_table = model_cls.__tablename__ + for part in related_parts: + current_table = model_cls.__model_fields__[part].to.__tablename__ + manager = model_cls._orm_relationship_manager + table_prefix = manager.resolve_relation_join(previous_table, current_table) + model_cls = model_cls.__model_fields__[part].to + previous_table = current_table + return select_related, table_prefix, model_cls + + def _compile_clause( + self, + clause: sqlalchemy.sql.expression.BinaryExpression, + column: sqlalchemy.Column, + table: sqlalchemy.Table, + table_prefix: str, + modifiers: Dict, + ) -> sqlalchemy.sql.expression.TextClause: + for modifier, modifier_value in modifiers.items(): + clause.modifiers[modifier] = modifier_value + + clause_text = str( + clause.compile( + dialect=self.model_cls.__database__._backend._dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + alias = f"{table_prefix}_" if table_prefix else "" + aliased_name = f"{alias}{table.name}.{column.name}" + clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name) + clause = text(clause_text) + return clause + + def _escape_characters_in_clause( + self, op: str, value: Union[str, "Model"] + ) -> Tuple[str, bool]: + has_escaped_character = False + + if op in ["contains", "icontains"]: + if isinstance(value, orm.Model): + raise QueryDefinitionError( + "You cannot use contains and icontains with instance of the Model" + ) + + has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS if c in value) + + if has_escaped_character: + # enable escape modifier + for char in self.ESCAPE_CHARACTERS: + value = value.replace(char, f"\\{char}") + value = f"%{value}%" + + return value, has_escaped_character + + @staticmethod + def _extract_operator_field_and_related( + parts: List[str], + ) -> Tuple[str, str, Optional[List]]: + if parts[-1] in FILTER_OPERATORS: + op = parts[-1] + field_name = parts[-2] + related_parts = parts[:-2] + else: + op = "exact" + field_name = parts[-1] + related_parts = parts[:-1] + + return op, field_name, related_parts + + def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 filter_clauses = self.filter_clauses select_related = list(self._select_related) @@ -279,37 +357,21 @@ class QuerySet: if "__" in key: parts = key.split("__") - # Determine if we should treat the final part as a - # filter operator or as a related field. - if parts[-1] in FILTER_OPERATORS: - op = parts[-1] - field_name = parts[-2] - related_parts = parts[:-2] - else: - op = "exact" - field_name = parts[-1] - related_parts = parts[:-1] + ( + op, + field_name, + related_parts, + ) = self._extract_operator_field_and_related(parts) model_cls = self.model_cls if related_parts: - # Add any implied select_related - related_str = "__".join(related_parts) - if related_str not in select_related: - select_related.append(related_str) - - # Walk the relationships to the actual model class - # against which the comparison is being made. - previous_table = model_cls.__tablename__ - for part in related_parts: - current_table = model_cls.__model_fields__[ - part - ].to.__tablename__ - manager = model_cls._orm_relationship_manager - table_prefix = manager.resolve_relation_join( - previous_table, current_table - ) - model_cls = model_cls.__model_fields__[part].to - previous_table = current_table + ( + select_related, + table_prefix, + model_cls, + ) = self._determine_filter_target_table( + related_parts, select_related + ) table = model_cls.__table__ column = model_cls.__table__.columns[field_name] @@ -319,39 +381,20 @@ class QuerySet: column = self.table.columns[key] table = self.table - # Map the operation code onto SQLAlchemy's ColumnElement - # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement - op_attr = FILTER_OPERATORS[op] - has_escaped_character = False - - if op in ["contains", "icontains"]: - has_escaped_character = any( - c for c in self.ESCAPE_CHARACTERS if c in value - ) - if has_escaped_character: - # enable escape modifier - for char in self.ESCAPE_CHARACTERS: - value = value.replace(char, f"\\{char}") - value = f"%{value}%" + value, has_escaped_character = self._escape_characters_in_clause(op, value) if isinstance(value, orm.Model): value = value.pk + op_attr = FILTER_OPERATORS[op] clause = getattr(column, op_attr)(value) - clause.modifiers["escape"] = "\\" if has_escaped_character else None - - clause_text = str( - clause.compile( - dialect=self.model_cls.__database__._backend._dialect, - compile_kwargs={"literal_binds": True}, - ) + clause = self._compile_clause( + clause, + column, + table, + table_prefix, + modifiers={"escape": "\\" if has_escaped_character else None}, ) - alias = f"{table_prefix}_" if table_prefix else "" - aliased_name = f"{alias}{table.name}.{column.name}" - clause_text = clause_text.replace( - f"{table.name}.{column.name}", aliased_name - ) - clause = text(clause_text) filter_clauses.append(clause) @@ -425,7 +468,7 @@ class QuerySet: raise MultipleMatches() return self.model_cls.from_row(rows[0], select_related=self._select_related) - async def all(self, **kwargs: Any) -> List["Model"]: + async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 if kwargs: return await self.filter(**kwargs).all() diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index 8889064..1c3d5dd 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -40,8 +40,14 @@ client = TestClient(app) def test_read_main(): - response = client.post("/items/", json={'name': 'test', 'id': 1, 'category': {'name': 'test cat'}}) + response = client.post( + "/items/", json={"name": "test", "id": 1, "category": {"name": "test cat"}} + ) assert response.status_code == 200 - assert response.json() == {'category': {'id': None, 'name': 'test cat'}, 'id': 1, 'name': 'test'} + assert response.json() == { + "category": {"id": None, "name": "test cat"}, + "id": 1, + "name": "test", + } item = Item(**response.json()) assert item.id == 1 diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index c222cfa..dfba2da 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -88,7 +88,7 @@ async def test_model_crud(): assert len(album.tracks) == 3 assert album.tracks[1].title == "Heart don't stand a chance" - album1 = await Album.objects.get(name='Malibu') + album1 = await Album.objects.get(name="Malibu") assert album1.pk == 1 assert album1.tracks is None @@ -127,7 +127,9 @@ async def test_fk_filter(): malibu = Album(name="Malibu%") await malibu.save() await Track.objects.create(album=malibu, title="The Bird", position=1) - await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2) + await Track.objects.create( + album=malibu, title="Heart don't stand a chance", position=2 + ) await Track.objects.create(album=malibu, title="The Waters", position=3) fantasies = await Album.objects.create(name="Fantasies") @@ -135,12 +137,20 @@ async def test_fk_filter(): await Track.objects.create(album=fantasies, title="Sick Muse", position=2) await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) - tracks = await Track.objects.select_related("album").filter(album__name="Fantasies").all() + tracks = ( + await Track.objects.select_related("album") + .filter(album__name="Fantasies") + .all() + ) assert len(tracks) == 3 for track in tracks: assert track.album.name == "Fantasies" - tracks = await Track.objects.select_related("album").filter(album__name__icontains="fan").all() + tracks = ( + await Track.objects.select_related("album") + .filter(album__name__icontains="fan") + .all() + ) assert len(tracks) == 3 for track in tracks: assert track.album.name == "Fantasies" @@ -179,7 +189,11 @@ async def test_multiple_fk(): team = await Team.objects.create(org=other, name="Green Team") await Member.objects.create(team=team, email="e@example.org") - members = await Member.objects.select_related('team__org').filter(team__org__ident="ACME Ltd").all() + members = ( + await Member.objects.select_related("team__org") + .filter(team__org__ident="ACME Ltd") + .all() + ) assert len(members) == 4 for member in members: assert member.team.org.ident == "ACME Ltd" @@ -195,7 +209,11 @@ async def test_pk_filter(): tracks = await Track.objects.select_related("album").filter(pk=1).all() assert len(tracks) == 1 - tracks = await Track.objects.select_related("album").filter(position=2, album__name='Test').all() + tracks = ( + await Track.objects.select_related("album") + .filter(position=2, album__name="Test") + .all() + ) assert len(tracks) == 1 diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index 1b17457..e7f8e0b 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -1,5 +1,4 @@ import datetime -from typing import ClassVar import pydantic import pytest @@ -17,7 +16,7 @@ class ExampleModel(Model): __metadata__ = metadata test = fields.Integer(primary_key=True) test_string = fields.String(length=250) - test_text = fields.Text(default='') + test_text = fields.Text(default="") test_bool = fields.Boolean(nullable=False) test_float = fields.Float() test_datetime = fields.DateTime(default=datetime.datetime.now) @@ -28,33 +27,42 @@ class ExampleModel(Model): test_decimal = fields.Decimal(length=10, precision=2) -fields_to_check = ['test', 'test_text', 'test_string', 'test_datetime', 'test_date', 'test_text', 'test_float', - 'test_bigint', 'test_json'] +fields_to_check = [ + "test", + "test_text", + "test_string", + "test_datetime", + "test_date", + "test_text", + "test_float", + "test_bigint", + "test_json", +] class ExampleModel2(Model): __tablename__ = "example2" __metadata__ = metadata - test = fields.Integer(name='test12', primary_key=True) - test_string = fields.String('test_string2', length=250) + test = fields.Integer(name="test12", primary_key=True) + test_string = fields.String("test_string2", length=250) @pytest.fixture() def example(): - return ExampleModel(pk=1, test_string='test', test_bool=True) + return ExampleModel(pk=1, test_string="test", test_bool=True) def test_not_nullable_field_is_required(): with pytest.raises(pydantic.error_wrappers.ValidationError): - ExampleModel(test=1, test_string='test') + ExampleModel(test=1, test_string="test") def test_model_attribute_access(example): assert example.test == 1 - assert example.test_string == 'test' + assert example.test_string == "test" assert example.test_datetime.year == datetime.datetime.now().year assert example.test_date == datetime.date.today() - assert example.test_text == '' + assert example.test_text == "" assert example.test_float is None assert example.test_bigint == 0 assert example.test_json == {} @@ -63,7 +71,7 @@ def test_model_attribute_access(example): assert example.test == 12 example.new_attr = 12 - assert 'new_attr' in example.__dict__ + assert "new_attr" in example.__dict__ def test_primary_key_access_and_setting(example): @@ -87,44 +95,54 @@ def test_sqlalchemy_table_is_created(example): def test_double_column_name_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example3" __metadata__ = metadata - test_string = fields.String('test_string2', name='test_string2', length=250) + test_string = fields.String("test_string2", name="test_string2", length=250) def test_no_pk_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example3" __metadata__ = metadata - test_string = fields.String(name='test_string2', length=250) + test_string = fields.String(name="test_string2", length=250) def test_setting_pk_column_as_pydantic_only_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example4" __metadata__ = metadata - test = fields.Integer(name='test12', primary_key=True, pydantic_only=True) + test = fields.Integer(name="test12", primary_key=True, pydantic_only=True) def test_decimal_error_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example4" __metadata__ = metadata - test = fields.Decimal(name='test12', primary_key=True) + test = fields.Decimal(name="test12", primary_key=True) def test_string_error_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example4" __metadata__ = metadata - test = fields.String(name='test12', primary_key=True) + test = fields.String(name="test12", primary_key=True) def test_json_conversion_in_model(): with pytest.raises(pydantic.ValidationError): - ExampleModel(test_json=datetime.datetime.now(), test=1, test_string='test', test_bool=True) + ExampleModel( + test_json=datetime.datetime.now(), + test=1, + test_string="test", + test_bool=True, + ) diff --git a/tests/test_models.py b/tests/test_models.py index cf12c3e..8b0bf37 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -3,6 +3,7 @@ import pytest import sqlalchemy import orm +from orm.exceptions import QueryDefinitionError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -139,6 +140,13 @@ async def test_model_filter(): assert await products.count() == 3 +@pytest.mark.asyncio +async def test_wrong_query_contains_model(): + with pytest.raises(QueryDefinitionError): + product = Product(name="90%-Cotton", rating=2) + await Product.objects.filter(name__contains=product).count() + + @pytest.mark.asyncio async def test_model_exists(): async with database: @@ -175,7 +183,7 @@ async def test_model_limit_with_filter(): await User.objects.create(name="Tom") await User.objects.create(name="Tom") - assert len(await User.objects.limit(2).filter(name__iexact='Tom').all()) == 2 + assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2 @pytest.mark.asyncio @@ -185,7 +193,7 @@ async def test_offset(): await User.objects.create(name="Jane") users = await User.objects.offset(1).limit(1).all() - assert users[0].name == 'Jane' + assert users[0].name == "Jane" @pytest.mark.asyncio diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index ea85409..66f6ad1 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -69,7 +69,7 @@ def create_test_database(): @pytest.fixture() async def init_relation(): - department = await Department.objects.create(id=1, name='Math Department') + department = await Department.objects.create(id=1, name="Math Department") class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") @@ -85,35 +85,41 @@ async def init_relation(): @pytest.mark.asyncio async def test_model_multiple_instances_of_same_table_in_schema(init_relation): async with database: - classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() - assert classes[0].name == 'Math' - assert classes[0].students[0].name == 'Jane' + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students"] + ).all() + assert classes[0].name == "Math" + assert classes[0].students[0].name == "Jane" # related fields of main model are only populated by pk # unless there is a required foreign key somewhere along the way # since department is required for schoolclass it was pre loaded (again) # but you can load them anytime - assert classes[0].students[0].schoolclass.name == 'Math' + assert classes[0].students[0].schoolclass.name == "Math" assert classes[0].students[0].schoolclass.department.name is None await classes[0].students[0].schoolclass.department.load() - assert classes[0].students[0].schoolclass.department.name == 'Math Department' + assert classes[0].students[0].schoolclass.department.name == "Math Department" @pytest.mark.asyncio async def test_right_tables_join(init_relation): async with database: - classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() - assert classes[0].teachers[0].category.name == 'Domestic' + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students"] + ).all() + assert classes[0].teachers[0].category.name == "Domestic" assert classes[0].students[0].category.name is None await classes[0].students[0].category.load() - assert classes[0].students[0].category.name == 'Foreign' + assert classes[0].students[0].category.name == "Foreign" @pytest.mark.asyncio async def test_multiple_reverse_related_objects(init_relation): async with database: - classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() - assert classes[0].name == 'Math' - assert classes[0].students[1].name == 'Jack' - assert classes[0].teachers[0].category.name == 'Domestic' + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students"] + ).all() + assert classes[0].name == "Math" + assert classes[0].students[1].name == "Jack" + assert classes[0].teachers[0].category.name == "Domestic"