diff --git a/.coverage b/.coverage index 458f78c..5c9ea5d 100644 Binary files a/.coverage and b/.coverage differ diff --git a/ormar/models/model.py b/ormar/models/model.py index eeff399..fd61ba6 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -89,7 +89,7 @@ class Model(NewBaseModel): return item @classmethod - def extract_prefixed_table_columns( # noqa CCR001 + def extract_prefixed_table_columns( # noqa CCR001 cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str ) -> dict: for column in cls.Meta.table.columns: diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 2cd86a2..9052b71 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -1,5 +1,5 @@ import inspect -from typing import List, Optional, Set, TYPE_CHECKING, Type, TypeVar +from typing import List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union import ormar from ormar.exceptions import RelationshipInstanceError @@ -94,7 +94,9 @@ class ModelTableProxy: return name @staticmethod - def resolve_relation_field(item: "Model", related: "Model") -> Type[Field]: + def resolve_relation_field( + item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]] + ) -> Type[Field]: name = ModelTableProxy.resolve_relation_name(item, related) to_field = item.Meta.model_fields.get(name) if not to_field: # pragma no cover diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index b1e8830..22578a7 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -4,8 +4,6 @@ import sqlalchemy from sqlalchemy import text import ormar # noqa I100 -from ormar.fields import BaseField -from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.many_to_many import ManyToManyField from ormar.queryset import FilterQuery, LimitQuery, OffsetQuery, OrderQuery from ormar.relations.alias_manager import AliasManager @@ -95,11 +93,6 @@ class Query: right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" return text(f"{left_part}={right_part}") - def _is_target_relation_key( - self, field: BaseField, target_model: Type["Model"] - ) -> bool: - return issubclass(field, ForeignKeyField) and field.to.Meta == target_model.Meta - def _build_join_parameters( self, part: str, join_params: JoinParameters, is_multi: bool = False ) -> JoinParameters: @@ -157,14 +150,10 @@ class Query: part: str, ) -> Tuple[str, str]: if join_params.prev_model.Meta.model_fields[part].virtual or is_multi: - to_key = next( - ( - v - for k, v in model_cls.Meta.model_fields.items() - if self._is_target_relation_key(v, join_params.prev_model) - ), - None, - ).name + to_field = model_cls.resolve_relation_field( + model_cls, join_params.prev_model + ) + to_key = to_field.name from_key = model_cls.Meta.pkname else: to_key = model_cls.Meta.pkname diff --git a/tests/settings.py b/tests/settings.py index 4cf7350..6d89f4e 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -2,10 +2,7 @@ import os import databases -assert "DATABASE_URL" in os.environ, "DATABASE_URL is not set." - -DATABASE_URL = os.environ['DATABASE_URL'] +DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///test.db") database_url = databases.DatabaseURL(DATABASE_URL) if database_url.scheme == "postgresql+aiopg": # pragma no cover DATABASE_URL = str(database_url.replace(driver=None)) -DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///test.db")