some cleanup of unused relations code, introduced caching of related_names and props on model, set profiling

This commit is contained in:
collerek
2020-11-12 08:11:40 +01:00
parent 1242e5d600
commit e743286008
11 changed files with 95 additions and 466 deletions

1
.gitignore vendored
View File

@ -9,3 +9,4 @@ test.db
dist dist
/ormar.egg-info/ /ormar.egg-info/
site site
profile.py

View File

@ -12,11 +12,17 @@ from typing import (
Union, Union,
) )
import ormar
from ormar.exceptions import RelationshipInstanceError from ormar.exceptions import RelationshipInstanceError
try:
import orjson as json
except ImportError: # pragma: nocover
import json # type: ignore
import ormar
from ormar.fields import BaseField, ManyToManyField from ormar.fields import BaseField, ManyToManyField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta, expand_reverse_relationships from ormar.models.metaclass import ModelMeta
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -30,6 +36,8 @@ Field = TypeVar("Field", bound=BaseField)
class ModelTableProxy: class ModelTableProxy:
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
Meta: ModelMeta Meta: ModelMeta
_related_names: Set
_related_names_hash: Union[str, bytes]
def dict(self): # noqa A003 def dict(self): # noqa A003
raise NotImplementedError # pragma no cover raise NotImplementedError # pragma no cover
@ -88,10 +96,17 @@ class ModelTableProxy:
@classmethod @classmethod
def extract_related_names(cls) -> Set: def extract_related_names(cls) -> Set:
if isinstance(cls._related_names_hash, (str, bytes)):
return cls._related_names
related_names = set() related_names = set()
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field) and issubclass(field, ForeignKeyField): if inspect.isclass(field) and issubclass(field, ForeignKeyField):
related_names.add(name) related_names.add(name)
cls._related_names_hash = json.dumps(list(cls.Meta.model_fields.keys()))
cls._related_names = related_names
return related_names return related_names
@classmethod @classmethod
@ -99,10 +114,10 @@ class ModelTableProxy:
related_names = set() related_names = set()
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if ( if (
inspect.isclass(field) inspect.isclass(field)
and issubclass(field, ForeignKeyField) and issubclass(field, ForeignKeyField)
and not issubclass(field, ManyToManyField) and not issubclass(field, ManyToManyField)
and not field.virtual and not field.virtual
): ):
related_names.add(name) related_names.add(name)
return related_names return related_names
@ -114,9 +129,9 @@ class ModelTableProxy:
related_names = set() related_names = set()
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if ( if (
inspect.isclass(field) inspect.isclass(field)
and issubclass(field, ForeignKeyField) and issubclass(field, ForeignKeyField)
and field.nullable and field.nullable
): ):
related_names.add(name) related_names.add(name)
return related_names return related_names
@ -136,9 +151,8 @@ class ModelTableProxy:
@staticmethod @staticmethod
def resolve_relation_name( # noqa CCR001 def resolve_relation_name( # noqa CCR001
item: Union["NewBaseModel", Type["NewBaseModel"]], item: Union["NewBaseModel", Type["NewBaseModel"]],
related: Union["NewBaseModel", Type["NewBaseModel"]], related: Union["NewBaseModel", Type["NewBaseModel"]]
register_missing: bool = True,
) -> str: ) -> str:
for name, field in item.Meta.model_fields.items(): for name, field in item.Meta.model_fields.items():
if issubclass(field, ForeignKeyField): if issubclass(field, ForeignKeyField):
@ -147,12 +161,6 @@ class ModelTableProxy:
# so we need to compare Meta too as this one is copied as is # so we need to compare Meta too as this one is copied as is
if field.to == related.__class__ or field.to.Meta == related.Meta: if field.to == related.__class__ or field.to.Meta == related.Meta:
return name return name
# fallback for not registered relation
if register_missing: # pragma nocover
expand_reverse_relationships(related.__class__) # type: ignore
return ModelTableProxy.resolve_relation_name(
item, related, register_missing=False
)
raise ValueError( raise ValueError(
f"No relation between {item.get_name()} and {related.get_name()}" f"No relation between {item.get_name()} and {related.get_name()}"
@ -160,8 +168,8 @@ class ModelTableProxy:
@staticmethod @staticmethod
def resolve_relation_field( def resolve_relation_field(
item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]] item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]]
) -> Union[Type[BaseField], Type[ForeignKeyField]]: ) -> Type[BaseField]:
name = ModelTableProxy.resolve_relation_name(item, related) name = ModelTableProxy.resolve_relation_name(item, related)
to_field = item.Meta.model_fields.get(name) to_field = item.Meta.model_fields.get(name)
if not to_field: # pragma no cover if not to_field: # pragma no cover
@ -207,12 +215,12 @@ class ModelTableProxy:
for field in one.Meta.model_fields.keys(): for field in one.Meta.model_fields.keys():
current_field = getattr(one, field) current_field = getattr(one, field)
if isinstance(current_field, list) and not isinstance( if isinstance(current_field, list) and not isinstance(
current_field, ormar.Model current_field, ormar.Model
): ):
setattr(other, field, current_field + getattr(other, field)) setattr(other, field, current_field + getattr(other, field))
elif ( elif (
isinstance(current_field, ormar.Model) isinstance(current_field, ormar.Model)
and current_field.pk == getattr(other, field).pk and current_field.pk == getattr(other, field).pk
): ):
setattr( setattr(
other, other,
@ -223,7 +231,7 @@ class ModelTableProxy:
@staticmethod @staticmethod
def _populate_pk_column( def _populate_pk_column(
model: Type["Model"], columns: List[str], use_alias: bool = False, model: Type["Model"], columns: List[str], use_alias: bool = False,
) -> List[str]: ) -> List[str]:
pk_alias = ( pk_alias = (
model.get_column_alias(model.Meta.pkname) model.get_column_alias(model.Meta.pkname)
@ -236,10 +244,10 @@ class ModelTableProxy:
@staticmethod @staticmethod
def own_table_columns( def own_table_columns(
model: Type["Model"], model: Type["Model"],
fields: Optional[Union[Set, Dict]], fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]], exclude_fields: Optional[Union[Set, Dict]],
use_alias: bool = False, use_alias: bool = False,
) -> List[str]: ) -> List[str]:
columns = [ columns = [
model.get_column_name_from_alias(col.name) if not use_alias else col.name model.get_column_name_from_alias(col.name) if not use_alias else col.name

View File

@ -9,7 +9,7 @@ from typing import (
Mapping, Mapping,
Optional, Optional,
Sequence, Sequence,
TYPE_CHECKING, Set, TYPE_CHECKING,
Type, Type,
TypeVar, TypeVar,
Union, Union,
@ -43,7 +43,7 @@ if TYPE_CHECKING: # pragma no cover
class NewBaseModel( class NewBaseModel(
pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass
): ):
__slots__ = ("_orm_id", "_orm_saved", "_orm") __slots__ = ("_orm_id", "_orm_saved", "_orm", "_related_names", "_related_names_hash", "_props")
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, Type[BaseField]] __model_fields__: Dict[str, Type[BaseField]]
@ -56,6 +56,10 @@ class NewBaseModel(
__database__: databases.Database __database__: databases.Database
_orm_relationship_manager: AliasManager _orm_relationship_manager: AliasManager
_orm: RelationsManager _orm: RelationsManager
_orm_saved: bool
_related_names: Set
_related_names_hash: str
_props: List[str]
Meta: ModelMeta Meta: ModelMeta
# noinspection PyMissingConstructor # noinspection PyMissingConstructor
@ -107,7 +111,7 @@ class NewBaseModel(
) )
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
if name in ("_orm_id", "_orm_saved", "_orm"): if name in ("_orm_id", "_orm_saved", "_orm", "_related_names", "_props"):
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
elif name == "pk": elif name == "pk":
object.__setattr__(self, self.Meta.pkname, value) object.__setattr__(self, self.Meta.pkname, value)
@ -126,12 +130,12 @@ class NewBaseModel(
super().__setattr__(name, value) super().__setattr__(name, value)
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"): if item in ("_orm_id", "_orm_saved", "_orm", "__fields__", "_related_names", "_props"):
return object.__getattribute__(self, item) return object.__getattribute__(self, item)
if item != "extract_related_names" and item in self.extract_related_names():
return self._extract_related_model_instead_of_field(item)
if item == "pk": if item == "pk":
return self.__dict__.get(self.Meta.pkname, None) return self.__dict__.get(self.Meta.pkname, None)
if item != "extract_related_names" and item in self.extract_related_names():
return self._extract_related_model_instead_of_field(item)
if item != "__fields__" and item in self.__fields__: if item != "__fields__" and item in self.__fields__:
value = self.__dict__.get(item, None) value = self.__dict__.get(item, None)
value = self._convert_json(item, value, "loads") value = self._convert_json(item, value, "loads")
@ -186,12 +190,16 @@ class NewBaseModel(
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
) -> List[str]: ) -> List[str]:
props = [ if isinstance(cls._props, list):
prop props = cls._props
for prop in dir(cls) else:
if isinstance(getattr(cls, prop), property) props = [
and prop not in ("__values__", "__fields__", "fields", "pk_column") prop
] for prop in dir(cls)
if isinstance(getattr(cls, prop), property)
and prop not in ("__values__", "__fields__", "fields", "pk_column")
]
cls._props = props
if include: if include:
props = [prop for prop in props if prop in include] props = [prop for prop in props if prop in include]
if exclude: if exclude:

View File

@ -3,7 +3,6 @@ from ormar.relations.relation import Relation, RelationType
from ormar.relations.relation_manager import RelationsManager from ormar.relations.relation_manager import RelationsManager
from ormar.relations.utils import ( from ormar.relations.utils import (
get_relations_sides_and_names, get_relations_sides_and_names,
register_missing_relation,
) )
__all__ = [ __all__ = [
@ -11,6 +10,5 @@ __all__ = [
"Relation", "Relation",
"RelationsManager", "RelationsManager",
"RelationType", "RelationType",
"register_missing_relation",
"get_relations_sides_and_names", "get_relations_sides_and_names",
] ]

View File

@ -7,7 +7,6 @@ from ormar.fields.many_to_many import ManyToManyField
from ormar.relations.relation import Relation, RelationType from ormar.relations.relation import Relation, RelationType
from ormar.relations.utils import ( from ormar.relations.utils import (
get_relations_sides_and_names, get_relations_sides_and_names,
register_missing_relation,
) )
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -42,8 +41,6 @@ class RelationsManager:
to=field.to, to=field.to,
through=getattr(field, "through", None), through=getattr(field, "through", None),
) )
if field.name not in self._related_names:
self._related_names.append(field.name)
def __contains__(self, item: str) -> bool: def __contains__(self, item: str) -> bool:
return item in self._related_names return item in self._related_names
@ -69,9 +66,10 @@ class RelationsManager:
) )
parent_relation = parent._orm._get(child_name) parent_relation = parent._orm._get(child_name)
if not parent_relation: if parent_relation:
parent_relation = register_missing_relation(parent, child, child_name) # print('missing', child_name)
parent_relation.add(child) # type: ignore # parent_relation = register_missing_relation(parent, child, child_name)
parent_relation.add(child) # type: ignore
child_relation = child._orm._get(to_name) child_relation = child._orm._get(to_name)
if child_relation: if child_relation:

View File

@ -72,6 +72,4 @@ class RelationProxy(list):
if self.relation._type == ormar.RelationType.MULTIPLE: if self.relation._type == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item) await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner) rel_name = item.resolve_relation_name(item, self._owner)
if rel_name not in item._orm: # pragma nocover
item._orm._add_relation(item.Meta.model_fields[rel_name])
setattr(item, rel_name, self._owner) setattr(item, rel_name, self._owner)

View File

@ -1,32 +1,19 @@
from typing import Optional, TYPE_CHECKING, Tuple, Type from typing import TYPE_CHECKING, Tuple, Type
from weakref import proxy from weakref import proxy
import ormar
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.many_to_many import ManyToManyField from ormar.fields.many_to_many import ManyToManyField
from ormar.relations import Relation
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
def register_missing_relation(
parent: "Model", child: "Model", child_name: str
) -> Optional[Relation]:
ormar.models.expand_reverse_relationships(child.__class__)
name = parent.resolve_relation_name(parent, child)
field = parent.Meta.model_fields[name]
parent._orm._add_relation(field)
parent_relation = parent._orm._get(child_name)
return parent_relation
def get_relations_sides_and_names( def get_relations_sides_and_names(
to_field: Type[BaseField], to_field: Type[BaseField],
parent: "Model", parent: "Model",
child: "Model", child: "Model",
child_name: str, child_name: str,
virtual: bool, virtual: bool,
) -> Tuple["Model", "Model", str, str]: ) -> Tuple["Model", "Model", str, str]:
to_name = to_field.name to_name = to_field.name
if issubclass(to_field, ManyToManyField): if issubclass(to_field, ManyToManyField):

View File

@ -4,6 +4,7 @@ databases[mysql]
pydantic pydantic
sqlalchemy sqlalchemy
typing_extensions typing_extensions
orjson
# Async database drivers # Async database drivers
aiomysql aiomysql
@ -34,3 +35,6 @@ flake8-variables-names
flake8-cognitive-complexity flake8-cognitive-complexity
flake8-functions flake8-functions
flake8-expression-complexity flake8-expression-complexity
# Performance testing
yappi

View File

@ -42,7 +42,7 @@ setup(
version=get_version(PACKAGE), version=get_version(PACKAGE),
url=URL, url=URL,
license="MIT", license="MIT",
description="An simple async ORM with fastapi in mind and pydantic validation.", description="A simple async ORM with fastapi in mind and pydantic validation.",
long_description=get_long_description(), long_description=get_long_description(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
keywords=['orm', 'sqlalchemy', 'fastapi', 'pydantic', 'databases', 'async', 'alembic'], keywords=['orm', 'sqlalchemy', 'fastapi', 'pydantic', 'databases', 'async', 'alembic'],
@ -56,6 +56,7 @@ setup(
"postgresql": ["asyncpg", "psycopg2"], "postgresql": ["asyncpg", "psycopg2"],
"mysql": ["aiomysql", "pymysql"], "mysql": ["aiomysql", "pymysql"],
"sqlite": ["aiosqlite"], "sqlite": ["aiosqlite"],
"orjson": ["orjson"]
}, },
classifiers=[ classifiers=[
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",

View File

@ -1,374 +0,0 @@
from typing import Optional
import databases
import pytest
import sqlalchemy
import ormar
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Album(ormar.Model):
class Meta:
tablename = "albums"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Track(ormar.Model):
class Meta:
tablename = "tracks"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album)
title: str = ormar.String(max_length=100)
position: int = ormar.Integer()
class Cover(ormar.Model):
class Meta:
tablename = "covers"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures")
title: str = ormar.String(max_length=100)
class Organisation(ormar.Model):
class Meta:
tablename = "org"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"])
class Team(ormar.Model):
class Meta:
tablename = "teams"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
org: Optional[Organisation] = ormar.ForeignKey(Organisation)
name: str = ormar.String(max_length=100)
class Member(ormar.Model):
class Meta:
tablename = "members"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
team: Optional[Team] = ormar.ForeignKey(Team)
email: str = ormar.String(max_length=100)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_wrong_query_foreign_key_type():
async with database:
with pytest.raises(RelationshipInstanceError):
Track(title="The Error", album="wrong_pk_type")
@pytest.mark.asyncio
async def test_setting_explicitly_empty_relation():
async with database:
track = Track(album=None, title="The Bird", position=1)
assert track.album is None
@pytest.mark.asyncio
async def test_related_name():
async with database:
async with database.transaction(force_rollback=True):
album = await Album.objects.create(name="Vanilla")
await Cover.objects.create(album=album, title="The cover file")
assert len(album.cover_pictures) == 1
@pytest.mark.asyncio
async def test_model_crud():
async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Jamaica")
await album.save()
track1 = Track(album=album, title="The Bird", position=1)
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
track3 = Track(album=album, title="The Waters", position=3)
await track1.save()
await track2.save()
await track3.save()
track = await Track.objects.get(title="The Bird")
assert track.album.pk == album.pk
assert isinstance(track.album, ormar.Model)
assert track.album.name is None
await track.album.load()
assert track.album.name == "Jamaica"
assert len(album.tracks) == 3
assert album.tracks[1].title == "Heart don't stand a chance"
album1 = await Album.objects.get(name="Jamaica")
assert album1.pk == album.pk
assert album1.tracks == []
await Track.objects.create(
album={"id": track.album.pk}, title="The Bird2", position=4
)
@pytest.mark.asyncio
async def test_select_related():
async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Malibu")
await album.save()
track1 = Track(album=album, title="The Bird", position=1)
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
track3 = Track(album=album, title="The Waters", position=3)
await track1.save()
await track2.save()
await track3.save()
fantasies = Album(name="Fantasies")
await fantasies.save()
track4 = Track(album=fantasies, title="Help I'm Alive", position=1)
track5 = Track(album=fantasies, title="Sick Muse", position=2)
track6 = Track(album=fantasies, title="Satellite Mind", position=3)
await track4.save()
await track5.save()
await track6.save()
track = await Track.objects.select_related("album").get(title="The Bird")
assert track.album.name == "Malibu"
tracks = await Track.objects.select_related("album").all()
assert len(tracks) == 6
@pytest.mark.asyncio
async def test_model_removal_from_relations():
async with database:
async with database.transaction(force_rollback=True):
album = Album(name="Chichi")
await album.save()
track1 = Track(album=album, title="The Birdman", position=1)
track2 = Track(album=album, title="Superman", position=2)
track3 = Track(album=album, title="Wonder Woman", position=3)
await track1.save()
await track2.save()
await track3.save()
assert len(album.tracks) == 3
await album.tracks.remove(track1)
assert len(album.tracks) == 2
assert track1.album is None
await track1.update()
track1 = await Track.objects.get(title="The Birdman")
assert track1.album is None
await album.tracks.add(track1)
assert len(album.tracks) == 3
assert track1.album == album
await track1.update()
track1 = await Track.objects.select_related("album__tracks").get(
title="The Birdman"
)
album = await Album.objects.select_related("tracks").get(name="Chichi")
assert track1.album == album
track1.remove(album)
assert track1.album is None
assert len(album.tracks) == 2
track2.remove(album)
assert track2.album is None
assert len(album.tracks) == 1
@pytest.mark.asyncio
async def test_fk_filter():
async with database:
async with database.transaction(force_rollback=True):
malibu = Album(name="Malibu%")
await malibu.save()
await Track.objects.create(album=malibu, title="The Bird", position=1)
await Track.objects.create(
album=malibu, title="Heart don't stand a chance", position=2
)
await Track.objects.create(album=malibu, title="The Waters", position=3)
fantasies = await Album.objects.create(name="Fantasies")
await Track.objects.create(
album=fantasies, title="Help I'm Alive", position=1
)
await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
await Track.objects.create(
album=fantasies, title="Satellite Mind", position=3
)
tracks = (
await Track.objects.select_related("album")
.filter(album__name="Fantasies")
.all()
)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = (
await Track.objects.select_related("album")
.filter(album__name__icontains="fan")
.all()
)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = await Track.objects.filter(album__name__contains="Fan").all()
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = await Track.objects.filter(album__name__contains="Malibu%").all()
assert len(tracks) == 3
tracks = (
await Track.objects.filter(album=malibu).select_related("album").all()
)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Malibu%"
tracks = await Track.objects.select_related("album").all(album=malibu)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Malibu%"
@pytest.mark.asyncio
async def test_multiple_fk():
async with database:
async with database.transaction(force_rollback=True):
acme = await Organisation.objects.create(ident="ACME Ltd")
red_team = await Team.objects.create(org=acme, name="Red Team")
blue_team = await Team.objects.create(org=acme, name="Blue Team")
await Member.objects.create(team=red_team, email="a@example.org")
await Member.objects.create(team=red_team, email="b@example.org")
await Member.objects.create(team=blue_team, email="c@example.org")
await Member.objects.create(team=blue_team, email="d@example.org")
other = await Organisation.objects.create(ident="Other ltd")
team = await Team.objects.create(org=other, name="Green Team")
await Member.objects.create(team=team, email="e@example.org")
members = (
await Member.objects.select_related("team__org")
.filter(team__org__ident="ACME Ltd")
.all()
)
assert len(members) == 4
for member in members:
assert member.team.org.ident == "ACME Ltd"
@pytest.mark.asyncio
async def test_wrong_choices():
async with database:
async with database.transaction(force_rollback=True):
with pytest.raises(ValueError):
await Organisation.objects.create(ident="Test 1")
@pytest.mark.asyncio
async def test_pk_filter():
async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Test")
track = await Track.objects.create(
album=fantasies, title="Test1", position=1
)
await Track.objects.create(album=fantasies, title="Test2", position=2)
await Track.objects.create(album=fantasies, title="Test3", position=3)
tracks = (
await Track.objects.select_related("album").filter(pk=track.pk).all()
)
assert len(tracks) == 1
tracks = (
await Track.objects.select_related("album")
.filter(position=2, album__name="Test")
.all()
)
assert len(tracks) == 1
@pytest.mark.asyncio
async def test_limit_and_offset():
async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Limitless")
await Track.objects.create(
id=None, album=fantasies, title="Sample", position=1
)
await Track.objects.create(album=fantasies, title="Sample2", position=2)
await Track.objects.create(album=fantasies, title="Sample3", position=3)
tracks = await Track.objects.limit(1).all()
assert len(tracks) == 1
assert tracks[0].title == "Sample"
tracks = await Track.objects.limit(1).offset(1).all()
assert len(tracks) == 1
assert tracks[0].title == "Sample2"
@pytest.mark.asyncio
async def test_get_exceptions():
async with database:
async with database.transaction(force_rollback=True):
fantasies = await Album.objects.create(name="Test")
with pytest.raises(NoMatch):
await Album.objects.get(name="Test2")
await Track.objects.create(album=fantasies, title="Test1", position=1)
await Track.objects.create(album=fantasies, title="Test2", position=2)
await Track.objects.create(album=fantasies, title="Test3", position=3)
with pytest.raises(MultipleMatches):
await Track.objects.select_related("album").get(album=fantasies)
@pytest.mark.asyncio
async def test_wrong_model_passed_as_fk():
async with database:
async with database.transaction(force_rollback=True):
with pytest.raises(RelationshipInstanceError):
org = await Organisation.objects.create(ident="ACME Ltd")
await Track.objects.create(album=org, title="Test1", position=1)

View File

@ -176,27 +176,27 @@ async def test_sort_order_on_related_model():
owner = ( owner = (
await Owner.objects.select_related("toys") await Owner.objects.select_related("toys")
.order_by("toys__name") .order_by("toys__name")
.filter(name="Zeus") .filter(name="Zeus")
.get() .get()
) )
assert owner.toys[0].name == "Toy 1" assert owner.toys[0].name == "Toy 1"
assert owner.toys[1].name == "Toy 4" assert owner.toys[1].name == "Toy 4"
owner = ( owner = (
await Owner.objects.select_related("toys") await Owner.objects.select_related("toys")
.order_by("-toys__name") .order_by("-toys__name")
.filter(name="Zeus") .filter(name="Zeus")
.get() .get()
) )
assert owner.toys[0].name == "Toy 4" assert owner.toys[0].name == "Toy 4"
assert owner.toys[1].name == "Toy 1" assert owner.toys[1].name == "Toy 1"
owners = ( owners = (
await Owner.objects.select_related("toys") await Owner.objects.select_related("toys")
.order_by("-toys__name") .order_by("-toys__name")
.filter(name__in=["Zeus", "Hermes"]) .filter(name__in=["Zeus", "Hermes"])
.all() .all()
) )
assert owners[0].toys[0].name == "Toy 6" assert owners[0].toys[0].name == "Toy 6"
assert owners[0].toys[1].name == "Toy 5" assert owners[0].toys[1].name == "Toy 5"
@ -210,9 +210,9 @@ async def test_sort_order_on_related_model():
owners = ( owners = (
await Owner.objects.select_related("toys") await Owner.objects.select_related("toys")
.order_by("-toys__name") .order_by("-toys__name")
.filter(name__in=["Zeus", "Hermes"]) .filter(name__in=["Zeus", "Hermes"])
.all() .all()
) )
assert owners[0].toys[0].name == "Toy 7" assert owners[0].toys[0].name == "Toy 7"
assert owners[0].toys[1].name == "Toy 4" assert owners[0].toys[1].name == "Toy 4"
@ -252,9 +252,9 @@ async def test_sort_order_on_many_to_many():
user = ( user = (
await User.objects.select_related("cars") await User.objects.select_related("cars")
.filter(name="Mark") .filter(name="Mark")
.order_by("cars__name") .order_by("cars__name")
.get() .get()
) )
assert user.cars[0].name == "Buggy" assert user.cars[0].name == "Buggy"
assert user.cars[1].name == "Ferrari" assert user.cars[1].name == "Ferrari"
@ -263,9 +263,9 @@ async def test_sort_order_on_many_to_many():
user = ( user = (
await User.objects.select_related("cars") await User.objects.select_related("cars")
.filter(name="Mark") .filter(name="Mark")
.order_by("-cars__name") .order_by("-cars__name")
.get() .get()
) )
assert user.cars[3].name == "Buggy" assert user.cars[3].name == "Buggy"
assert user.cars[2].name == "Ferrari" assert user.cars[2].name == "Ferrari"
@ -281,8 +281,8 @@ async def test_sort_order_on_many_to_many():
users = ( users = (
await User.objects.select_related(["cars", "cars__factory"]) await User.objects.select_related(["cars", "cars__factory"])
.order_by(["-cars__factory__name", "cars__name"]) .order_by(["-cars__factory__name", "cars__name"])
.all() .all()
) )
assert users[0].name == "Julie" assert users[0].name == "Julie"
@ -328,8 +328,8 @@ async def test_sort_order_with_aliases():
aliases = ( aliases = (
await AliasTest.objects.select_related("nested") await AliasTest.objects.select_related("nested")
.order_by("-nested__name") .order_by("-nested__name")
.all() .all()
) )
assert aliases[0].nested.name == "Try4" assert aliases[0].nested.name == "Try4"
assert aliases[1].nested.name == "Try3" assert aliases[1].nested.name == "Try3"