fix some complexity issues

This commit is contained in:
collerek
2020-08-09 10:58:36 +02:00
parent 22c4a0619c
commit fb5d03d64c
10 changed files with 244 additions and 134 deletions

BIN
.coverage

Binary file not shown.

View File

@ -1,5 +1,5 @@
[flake8] [flake8]
ignore = ANN101, ANN102, W503 ignore = ANN101, ANN102, W503, S101
max-complexity = 8 max-complexity = 8
max-line-length = 88 max-line-length = 88
exclude = p38venv,.pytest_cache exclude = p38venv,.pytest_cache

View File

@ -18,5 +18,9 @@ class MultipleMatches(AsyncOrmException):
pass pass
class QueryDefinitionError(AsyncOrmException):
pass
class RelationshipInstanceError(AsyncOrmException): class RelationshipInstanceError(AsyncOrmException):
pass pass

View File

@ -46,13 +46,11 @@ def sqlalchemy_columns_from_model_fields(
if field.primary_key: if field.primary_key:
pkname = field_name pkname = field_name
if isinstance(field, ForeignKey): if isinstance(field, ForeignKey):
reverse_name = ( child_relation_name = (
field.related_name field.to.get_name(title=True) + "_" + name.lower() + "s"
or field.to.__name__.lower().title() + "_" + name.lower() + "s"
)
relation_name = (
name.lower().title() + "_" + field.to.__name__.lower()
) )
reverse_name = field.related_name or child_relation_name
relation_name = name.lower().title() + "_" + field.to.get_name()
relationship_manager.add_relation_type( relationship_manager.add_relation_type(
relation_name, reverse_name, field, tablename relation_name, reverse_name, field, tablename
) )
@ -241,6 +239,15 @@ class Model(list, metaclass=ModelMetaclass):
# def schema(cls, by_alias: bool = True): # pragma no cover # def schema(cls, by_alias: bool = True): # pragma no cover
# return cls.__pydantic_model__.schema(by_alias=by_alias) # return cls.__pydantic_model__.schema(by_alias=by_alias)
@classmethod
def get_name(cls, title: bool = False, lower: bool = True) -> str:
name = cls.__name__
if lower:
name = name.lower()
if title:
name = name.title()
return name
def is_conversion_to_json_needed(self, column_name: str) -> bool: def is_conversion_to_json_needed(self, column_name: str) -> bool:
return self.__model_fields__.get(column_name).__type__ == pydantic.Json return self.__model_fields__.get(column_name).__type__ == pydantic.Json
@ -256,7 +263,7 @@ class Model(list, metaclass=ModelMetaclass):
def pk_column(self) -> sqlalchemy.Column: def pk_column(self) -> sqlalchemy.Column:
return self.__table__.primary_key.columns.values()[0] return self.__table__.primary_key.columns.values()[0]
def dict(self) -> Dict: def dict(self) -> Dict: # noqa: A003
dict_instance = self.values.dict() dict_instance = self.values.dict()
for field in self.extract_related_names(): for field in self.extract_related_names():
nested_model = getattr(self, field) nested_model = getattr(self, field)

View File

@ -1,10 +1,20 @@
from typing import Any, List, NamedTuple, TYPE_CHECKING, Tuple, Type, Union from typing import (
Any,
Dict,
List,
NamedTuple,
Optional,
TYPE_CHECKING,
Tuple,
Type,
Union,
)
import databases import databases
import orm import orm
from orm import ForeignKey from orm import ForeignKey
from orm.exceptions import MultipleMatches, NoMatch from orm.exceptions import MultipleMatches, NoMatch, QueryDefinitionError
from orm.fields import BaseField from orm.fields import BaseField
import sqlalchemy import sqlalchemy
@ -80,18 +90,11 @@ class QuerySet:
return text(f"{name} {alias}_{name}") return text(f"{name} {alias}_{name}")
def on_clause( def on_clause(
self, self, previous_alias: str, alias: str, from_clause: str, to_clause: str,
from_table: str,
to_table: str,
previous_alias: str,
alias: str,
to_key: str,
from_key: str,
) -> text: ) -> text:
return text( left_part = f"{alias}_{to_clause}"
f"{alias}_{to_table}.{to_key}=" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}' return text(f"{left_part}={right_part}")
)
def _build_join_parameters( def _build_join_parameters(
self, part: str, join_params: JoinParameters self, part: str, join_params: JoinParameters
@ -118,12 +121,10 @@ class QuerySet:
from_key = part from_key = part
on_clause = self.on_clause( on_clause = self.on_clause(
join_params.from_table, previous_alias=join_params.previous_alias,
to_table, alias=alias,
join_params.previous_alias, from_clause=f"{join_params.from_table}.{from_key}",
alias, to_clause=f"{to_table}.{to_key}",
to_key,
from_key,
) )
target_table = self.prefixed_table_name(alias, to_table) target_table = self.prefixed_table_name(alias, to_table)
self.select_from = sqlalchemy.sql.outerjoin( self.select_from = sqlalchemy.sql.outerjoin(
@ -159,12 +160,12 @@ class QuerySet:
def _extract_auto_required_relations( def _extract_auto_required_relations(
self, self,
join_params: JoinParameters, prev_model: Type["Model"],
rel_part: str = "", rel_part: str = "",
nested: bool = False, nested: bool = False,
parent_virtual: bool = False, parent_virtual: bool = False,
) -> None: ) -> None:
for field_name, field in join_params.prev_model.__model_fields__.items(): for field_name, field in prev_model.__model_fields__.items():
if self._field_is_a_foreign_key_and_no_circular_reference( if self._field_is_a_foreign_key_and_no_circular_reference(
field, field_name, rel_part field, field_name, rel_part
): ):
@ -176,14 +177,8 @@ class QuerySet:
elif self._field_qualifies_to_deeper_search( elif self._field_qualifies_to_deeper_search(
field, parent_virtual, nested, rel_part field, parent_virtual, nested, rel_part
): ):
join_params = JoinParameters(
field.to,
join_params.previous_alias,
join_params.from_table,
join_params.prev_model,
)
self._extract_auto_required_relations( self._extract_auto_required_relations(
join_params=join_params, prev_model=field.to,
rel_part=rel_part, rel_part=rel_part,
nested=True, nested=True,
parent_virtual=field.virtual, parent_virtual=field.virtual,
@ -244,7 +239,7 @@ class QuerySet:
start_params = JoinParameters( start_params = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls self.model_cls, "", self.table.name, self.model_cls
) )
self._extract_auto_required_relations(start_params) self._extract_auto_required_relations(prev_model=start_params.prev_model)
self._include_auto_related_models() self._include_auto_related_models()
self._select_related.sort(key=lambda item: (-len(item), item)) self._select_related.sort(key=lambda item: (-len(item), item))
@ -266,7 +261,90 @@ class QuerySet:
return expr return expr
def filter(self, **kwargs: Any) -> "QuerySet": def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str]
) -> Tuple[List[str], str, "Model"]:
table_prefix = ""
model_cls = self.model_cls
select_related = [relation for relation in select_related]
# Add any implied select_related
related_str = "__".join(related_parts)
if related_str not in select_related:
select_related.append(related_str)
# Walk the relationships to the actual model class
# against which the comparison is being made.
previous_table = model_cls.__tablename__
for part in related_parts:
current_table = model_cls.__model_fields__[part].to.__tablename__
manager = model_cls._orm_relationship_manager
table_prefix = manager.resolve_relation_join(previous_table, current_table)
model_cls = model_cls.__model_fields__[part].to
previous_table = current_table
return select_related, table_prefix, model_cls
def _compile_clause(
self,
clause: sqlalchemy.sql.expression.BinaryExpression,
column: sqlalchemy.Column,
table: sqlalchemy.Table,
table_prefix: str,
modifiers: Dict,
) -> sqlalchemy.sql.expression.TextClause:
for modifier, modifier_value in modifiers.items():
clause.modifiers[modifier] = modifier_value
clause_text = str(
clause.compile(
dialect=self.model_cls.__database__._backend._dialect,
compile_kwargs={"literal_binds": True},
)
)
alias = f"{table_prefix}_" if table_prefix else ""
aliased_name = f"{alias}{table.name}.{column.name}"
clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name)
clause = text(clause_text)
return clause
def _escape_characters_in_clause(
self, op: str, value: Union[str, "Model"]
) -> Tuple[str, bool]:
has_escaped_character = False
if op in ["contains", "icontains"]:
if isinstance(value, orm.Model):
raise QueryDefinitionError(
"You cannot use contains and icontains with instance of the Model"
)
has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS if c in value)
if has_escaped_character:
# enable escape modifier
for char in self.ESCAPE_CHARACTERS:
value = value.replace(char, f"\\{char}")
value = f"%{value}%"
return value, has_escaped_character
@staticmethod
def _extract_operator_field_and_related(
parts: List[str],
) -> Tuple[str, str, Optional[List]]:
if parts[-1] in FILTER_OPERATORS:
op = parts[-1]
field_name = parts[-2]
related_parts = parts[:-2]
else:
op = "exact"
field_name = parts[-1]
related_parts = parts[:-1]
return op, field_name, related_parts
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
filter_clauses = self.filter_clauses filter_clauses = self.filter_clauses
select_related = list(self._select_related) select_related = list(self._select_related)
@ -279,37 +357,21 @@ class QuerySet:
if "__" in key: if "__" in key:
parts = key.split("__") parts = key.split("__")
# Determine if we should treat the final part as a (
# filter operator or as a related field. op,
if parts[-1] in FILTER_OPERATORS: field_name,
op = parts[-1] related_parts,
field_name = parts[-2] ) = self._extract_operator_field_and_related(parts)
related_parts = parts[:-2]
else:
op = "exact"
field_name = parts[-1]
related_parts = parts[:-1]
model_cls = self.model_cls model_cls = self.model_cls
if related_parts: if related_parts:
# Add any implied select_related (
related_str = "__".join(related_parts) select_related,
if related_str not in select_related: table_prefix,
select_related.append(related_str) model_cls,
) = self._determine_filter_target_table(
# Walk the relationships to the actual model class related_parts, select_related
# against which the comparison is being made. )
previous_table = model_cls.__tablename__
for part in related_parts:
current_table = model_cls.__model_fields__[
part
].to.__tablename__
manager = model_cls._orm_relationship_manager
table_prefix = manager.resolve_relation_join(
previous_table, current_table
)
model_cls = model_cls.__model_fields__[part].to
previous_table = current_table
table = model_cls.__table__ table = model_cls.__table__
column = model_cls.__table__.columns[field_name] column = model_cls.__table__.columns[field_name]
@ -319,39 +381,20 @@ class QuerySet:
column = self.table.columns[key] column = self.table.columns[key]
table = self.table table = self.table
# Map the operation code onto SQLAlchemy's ColumnElement value, has_escaped_character = self._escape_characters_in_clause(op, value)
# https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement
op_attr = FILTER_OPERATORS[op]
has_escaped_character = False
if op in ["contains", "icontains"]:
has_escaped_character = any(
c for c in self.ESCAPE_CHARACTERS if c in value
)
if has_escaped_character:
# enable escape modifier
for char in self.ESCAPE_CHARACTERS:
value = value.replace(char, f"\\{char}")
value = f"%{value}%"
if isinstance(value, orm.Model): if isinstance(value, orm.Model):
value = value.pk value = value.pk
op_attr = FILTER_OPERATORS[op]
clause = getattr(column, op_attr)(value) clause = getattr(column, op_attr)(value)
clause.modifiers["escape"] = "\\" if has_escaped_character else None clause = self._compile_clause(
clause,
clause_text = str( column,
clause.compile( table,
dialect=self.model_cls.__database__._backend._dialect, table_prefix,
compile_kwargs={"literal_binds": True}, modifiers={"escape": "\\" if has_escaped_character else None},
)
) )
alias = f"{table_prefix}_" if table_prefix else ""
aliased_name = f"{alias}{table.name}.{column.name}"
clause_text = clause_text.replace(
f"{table.name}.{column.name}", aliased_name
)
clause = text(clause_text)
filter_clauses.append(clause) filter_clauses.append(clause)
@ -425,7 +468,7 @@ class QuerySet:
raise MultipleMatches() raise MultipleMatches()
return self.model_cls.from_row(rows[0], select_related=self._select_related) return self.model_cls.from_row(rows[0], select_related=self._select_related)
async def all(self, **kwargs: Any) -> List["Model"]: async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
if kwargs: if kwargs:
return await self.filter(**kwargs).all() return await self.filter(**kwargs).all()

View File

@ -40,8 +40,14 @@ client = TestClient(app)
def test_read_main(): def test_read_main():
response = client.post("/items/", json={'name': 'test', 'id': 1, 'category': {'name': 'test cat'}}) response = client.post(
"/items/", json={"name": "test", "id": 1, "category": {"name": "test cat"}}
)
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {'category': {'id': None, 'name': 'test cat'}, 'id': 1, 'name': 'test'} assert response.json() == {
"category": {"id": None, "name": "test cat"},
"id": 1,
"name": "test",
}
item = Item(**response.json()) item = Item(**response.json())
assert item.id == 1 assert item.id == 1

View File

@ -88,7 +88,7 @@ async def test_model_crud():
assert len(album.tracks) == 3 assert len(album.tracks) == 3
assert album.tracks[1].title == "Heart don't stand a chance" assert album.tracks[1].title == "Heart don't stand a chance"
album1 = await Album.objects.get(name='Malibu') album1 = await Album.objects.get(name="Malibu")
assert album1.pk == 1 assert album1.pk == 1
assert album1.tracks is None assert album1.tracks is None
@ -127,7 +127,9 @@ async def test_fk_filter():
malibu = Album(name="Malibu%") malibu = Album(name="Malibu%")
await malibu.save() await malibu.save()
await Track.objects.create(album=malibu, title="The Bird", position=1) await Track.objects.create(album=malibu, title="The Bird", position=1)
await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2) await Track.objects.create(
album=malibu, title="Heart don't stand a chance", position=2
)
await Track.objects.create(album=malibu, title="The Waters", position=3) await Track.objects.create(album=malibu, title="The Waters", position=3)
fantasies = await Album.objects.create(name="Fantasies") fantasies = await Album.objects.create(name="Fantasies")
@ -135,12 +137,20 @@ async def test_fk_filter():
await Track.objects.create(album=fantasies, title="Sick Muse", position=2) await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) await Track.objects.create(album=fantasies, title="Satellite Mind", position=3)
tracks = await Track.objects.select_related("album").filter(album__name="Fantasies").all() tracks = (
await Track.objects.select_related("album")
.filter(album__name="Fantasies")
.all()
)
assert len(tracks) == 3 assert len(tracks) == 3
for track in tracks: for track in tracks:
assert track.album.name == "Fantasies" assert track.album.name == "Fantasies"
tracks = await Track.objects.select_related("album").filter(album__name__icontains="fan").all() tracks = (
await Track.objects.select_related("album")
.filter(album__name__icontains="fan")
.all()
)
assert len(tracks) == 3 assert len(tracks) == 3
for track in tracks: for track in tracks:
assert track.album.name == "Fantasies" assert track.album.name == "Fantasies"
@ -179,7 +189,11 @@ async def test_multiple_fk():
team = await Team.objects.create(org=other, name="Green Team") team = await Team.objects.create(org=other, name="Green Team")
await Member.objects.create(team=team, email="e@example.org") await Member.objects.create(team=team, email="e@example.org")
members = await Member.objects.select_related('team__org').filter(team__org__ident="ACME Ltd").all() members = (
await Member.objects.select_related("team__org")
.filter(team__org__ident="ACME Ltd")
.all()
)
assert len(members) == 4 assert len(members) == 4
for member in members: for member in members:
assert member.team.org.ident == "ACME Ltd" assert member.team.org.ident == "ACME Ltd"
@ -195,7 +209,11 @@ async def test_pk_filter():
tracks = await Track.objects.select_related("album").filter(pk=1).all() tracks = await Track.objects.select_related("album").filter(pk=1).all()
assert len(tracks) == 1 assert len(tracks) == 1
tracks = await Track.objects.select_related("album").filter(position=2, album__name='Test').all() tracks = (
await Track.objects.select_related("album")
.filter(position=2, album__name="Test")
.all()
)
assert len(tracks) == 1 assert len(tracks) == 1

View File

@ -1,5 +1,4 @@
import datetime import datetime
from typing import ClassVar
import pydantic import pydantic
import pytest import pytest
@ -17,7 +16,7 @@ class ExampleModel(Model):
__metadata__ = metadata __metadata__ = metadata
test = fields.Integer(primary_key=True) test = fields.Integer(primary_key=True)
test_string = fields.String(length=250) test_string = fields.String(length=250)
test_text = fields.Text(default='') test_text = fields.Text(default="")
test_bool = fields.Boolean(nullable=False) test_bool = fields.Boolean(nullable=False)
test_float = fields.Float() test_float = fields.Float()
test_datetime = fields.DateTime(default=datetime.datetime.now) test_datetime = fields.DateTime(default=datetime.datetime.now)
@ -28,33 +27,42 @@ class ExampleModel(Model):
test_decimal = fields.Decimal(length=10, precision=2) test_decimal = fields.Decimal(length=10, precision=2)
fields_to_check = ['test', 'test_text', 'test_string', 'test_datetime', 'test_date', 'test_text', 'test_float', fields_to_check = [
'test_bigint', 'test_json'] "test",
"test_text",
"test_string",
"test_datetime",
"test_date",
"test_text",
"test_float",
"test_bigint",
"test_json",
]
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example2" __tablename__ = "example2"
__metadata__ = metadata __metadata__ = metadata
test = fields.Integer(name='test12', primary_key=True) test = fields.Integer(name="test12", primary_key=True)
test_string = fields.String('test_string2', length=250) test_string = fields.String("test_string2", length=250)
@pytest.fixture() @pytest.fixture()
def example(): def example():
return ExampleModel(pk=1, test_string='test', test_bool=True) return ExampleModel(pk=1, test_string="test", test_bool=True)
def test_not_nullable_field_is_required(): def test_not_nullable_field_is_required():
with pytest.raises(pydantic.error_wrappers.ValidationError): with pytest.raises(pydantic.error_wrappers.ValidationError):
ExampleModel(test=1, test_string='test') ExampleModel(test=1, test_string="test")
def test_model_attribute_access(example): def test_model_attribute_access(example):
assert example.test == 1 assert example.test == 1
assert example.test_string == 'test' assert example.test_string == "test"
assert example.test_datetime.year == datetime.datetime.now().year assert example.test_datetime.year == datetime.datetime.now().year
assert example.test_date == datetime.date.today() assert example.test_date == datetime.date.today()
assert example.test_text == '' assert example.test_text == ""
assert example.test_float is None assert example.test_float is None
assert example.test_bigint == 0 assert example.test_bigint == 0
assert example.test_json == {} assert example.test_json == {}
@ -63,7 +71,7 @@ def test_model_attribute_access(example):
assert example.test == 12 assert example.test == 12
example.new_attr = 12 example.new_attr = 12
assert 'new_attr' in example.__dict__ assert "new_attr" in example.__dict__
def test_primary_key_access_and_setting(example): def test_primary_key_access_and_setting(example):
@ -87,44 +95,54 @@ def test_sqlalchemy_table_is_created(example):
def test_double_column_name_in_model_definition(): def test_double_column_name_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example3" __tablename__ = "example3"
__metadata__ = metadata __metadata__ = metadata
test_string = fields.String('test_string2', name='test_string2', length=250) test_string = fields.String("test_string2", name="test_string2", length=250)
def test_no_pk_in_model_definition(): def test_no_pk_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example3" __tablename__ = "example3"
__metadata__ = metadata __metadata__ = metadata
test_string = fields.String(name='test_string2', length=250) test_string = fields.String(name="test_string2", length=250)
def test_setting_pk_column_as_pydantic_only_in_model_definition(): def test_setting_pk_column_as_pydantic_only_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example4" __tablename__ = "example4"
__metadata__ = metadata __metadata__ = metadata
test = fields.Integer(name='test12', primary_key=True, pydantic_only=True) test = fields.Integer(name="test12", primary_key=True, pydantic_only=True)
def test_decimal_error_in_model_definition(): def test_decimal_error_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example4" __tablename__ = "example4"
__metadata__ = metadata __metadata__ = metadata
test = fields.Decimal(name='test12', primary_key=True) test = fields.Decimal(name="test12", primary_key=True)
def test_string_error_in_model_definition(): def test_string_error_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
__tablename__ = "example4" __tablename__ = "example4"
__metadata__ = metadata __metadata__ = metadata
test = fields.String(name='test12', primary_key=True) test = fields.String(name="test12", primary_key=True)
def test_json_conversion_in_model(): def test_json_conversion_in_model():
with pytest.raises(pydantic.ValidationError): with pytest.raises(pydantic.ValidationError):
ExampleModel(test_json=datetime.datetime.now(), test=1, test_string='test', test_bool=True) ExampleModel(
test_json=datetime.datetime.now(),
test=1,
test_string="test",
test_bool=True,
)

View File

@ -3,6 +3,7 @@ import pytest
import sqlalchemy import sqlalchemy
import orm import orm
from orm.exceptions import QueryDefinitionError
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
@ -139,6 +140,13 @@ async def test_model_filter():
assert await products.count() == 3 assert await products.count() == 3
@pytest.mark.asyncio
async def test_wrong_query_contains_model():
with pytest.raises(QueryDefinitionError):
product = Product(name="90%-Cotton", rating=2)
await Product.objects.filter(name__contains=product).count()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_exists(): async def test_model_exists():
async with database: async with database:
@ -175,7 +183,7 @@ async def test_model_limit_with_filter():
await User.objects.create(name="Tom") await User.objects.create(name="Tom")
await User.objects.create(name="Tom") await User.objects.create(name="Tom")
assert len(await User.objects.limit(2).filter(name__iexact='Tom').all()) == 2 assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
@ -185,7 +193,7 @@ async def test_offset():
await User.objects.create(name="Jane") await User.objects.create(name="Jane")
users = await User.objects.offset(1).limit(1).all() users = await User.objects.offset(1).limit(1).all()
assert users[0].name == 'Jane' assert users[0].name == "Jane"
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -69,7 +69,7 @@ def create_test_database():
@pytest.fixture() @pytest.fixture()
async def init_relation(): async def init_relation():
department = await Department.objects.create(id=1, name='Math Department') department = await Department.objects.create(id=1, name="Math Department")
class1 = await SchoolClass.objects.create(name="Math", department=department) class1 = await SchoolClass.objects.create(name="Math", department=department)
category = await Category.objects.create(name="Foreign") category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic") category2 = await Category.objects.create(name="Domestic")
@ -85,35 +85,41 @@ async def init_relation():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema(init_relation): async def test_model_multiple_instances_of_same_table_in_schema(init_relation):
async with database: async with database:
classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() classes = await SchoolClass.objects.select_related(
assert classes[0].name == 'Math' ["teachers__category", "students"]
assert classes[0].students[0].name == 'Jane' ).all()
assert classes[0].name == "Math"
assert classes[0].students[0].name == "Jane"
# related fields of main model are only populated by pk # related fields of main model are only populated by pk
# unless there is a required foreign key somewhere along the way # unless there is a required foreign key somewhere along the way
# since department is required for schoolclass it was pre loaded (again) # since department is required for schoolclass it was pre loaded (again)
# but you can load them anytime # but you can load them anytime
assert classes[0].students[0].schoolclass.name == 'Math' assert classes[0].students[0].schoolclass.name == "Math"
assert classes[0].students[0].schoolclass.department.name is None assert classes[0].students[0].schoolclass.department.name is None
await classes[0].students[0].schoolclass.department.load() await classes[0].students[0].schoolclass.department.load()
assert classes[0].students[0].schoolclass.department.name == 'Math Department' assert classes[0].students[0].schoolclass.department.name == "Math Department"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_right_tables_join(init_relation): async def test_right_tables_join(init_relation):
async with database: async with database:
classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() classes = await SchoolClass.objects.select_related(
assert classes[0].teachers[0].category.name == 'Domestic' ["teachers__category", "students"]
).all()
assert classes[0].teachers[0].category.name == "Domestic"
assert classes[0].students[0].category.name is None assert classes[0].students[0].category.name is None
await classes[0].students[0].category.load() await classes[0].students[0].category.load()
assert classes[0].students[0].category.name == 'Foreign' assert classes[0].students[0].category.name == "Foreign"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_reverse_related_objects(init_relation): async def test_multiple_reverse_related_objects(init_relation):
async with database: async with database:
classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() classes = await SchoolClass.objects.select_related(
assert classes[0].name == 'Math' ["teachers__category", "students"]
assert classes[0].students[1].name == 'Jack' ).all()
assert classes[0].teachers[0].category.name == 'Domestic' assert classes[0].name == "Math"
assert classes[0].students[1].name == "Jack"
assert classes[0].teachers[0].category.name == "Domestic"