some cleanup of unused relations code, introduced caching of related_names and props on model, set profiling
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -9,3 +9,4 @@ test.db
|
|||||||
dist
|
dist
|
||||||
/ormar.egg-info/
|
/ormar.egg-info/
|
||||||
site
|
site
|
||||||
|
profile.py
|
||||||
|
|||||||
@ -12,11 +12,17 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
import ormar
|
|
||||||
from ormar.exceptions import RelationshipInstanceError
|
from ormar.exceptions import RelationshipInstanceError
|
||||||
|
|
||||||
|
try:
|
||||||
|
import orjson as json
|
||||||
|
except ImportError: # pragma: nocover
|
||||||
|
import json # type: ignore
|
||||||
|
|
||||||
|
import ormar
|
||||||
from ormar.fields import BaseField, ManyToManyField
|
from ormar.fields import BaseField, ManyToManyField
|
||||||
from ormar.fields.foreign_key import ForeignKeyField
|
from ormar.fields.foreign_key import ForeignKeyField
|
||||||
from ormar.models.metaclass import ModelMeta, expand_reverse_relationships
|
from ormar.models.metaclass import ModelMeta
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
@ -30,6 +36,8 @@ Field = TypeVar("Field", bound=BaseField)
|
|||||||
class ModelTableProxy:
|
class ModelTableProxy:
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
Meta: ModelMeta
|
Meta: ModelMeta
|
||||||
|
_related_names: Set
|
||||||
|
_related_names_hash: Union[str, bytes]
|
||||||
|
|
||||||
def dict(self): # noqa A003
|
def dict(self): # noqa A003
|
||||||
raise NotImplementedError # pragma no cover
|
raise NotImplementedError # pragma no cover
|
||||||
@ -88,10 +96,17 @@ class ModelTableProxy:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_related_names(cls) -> Set:
|
def extract_related_names(cls) -> Set:
|
||||||
|
|
||||||
|
if isinstance(cls._related_names_hash, (str, bytes)):
|
||||||
|
return cls._related_names
|
||||||
|
|
||||||
related_names = set()
|
related_names = set()
|
||||||
for name, field in cls.Meta.model_fields.items():
|
for name, field in cls.Meta.model_fields.items():
|
||||||
if inspect.isclass(field) and issubclass(field, ForeignKeyField):
|
if inspect.isclass(field) and issubclass(field, ForeignKeyField):
|
||||||
related_names.add(name)
|
related_names.add(name)
|
||||||
|
cls._related_names_hash = json.dumps(list(cls.Meta.model_fields.keys()))
|
||||||
|
cls._related_names = related_names
|
||||||
|
|
||||||
return related_names
|
return related_names
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -99,10 +114,10 @@ class ModelTableProxy:
|
|||||||
related_names = set()
|
related_names = set()
|
||||||
for name, field in cls.Meta.model_fields.items():
|
for name, field in cls.Meta.model_fields.items():
|
||||||
if (
|
if (
|
||||||
inspect.isclass(field)
|
inspect.isclass(field)
|
||||||
and issubclass(field, ForeignKeyField)
|
and issubclass(field, ForeignKeyField)
|
||||||
and not issubclass(field, ManyToManyField)
|
and not issubclass(field, ManyToManyField)
|
||||||
and not field.virtual
|
and not field.virtual
|
||||||
):
|
):
|
||||||
related_names.add(name)
|
related_names.add(name)
|
||||||
return related_names
|
return related_names
|
||||||
@ -114,9 +129,9 @@ class ModelTableProxy:
|
|||||||
related_names = set()
|
related_names = set()
|
||||||
for name, field in cls.Meta.model_fields.items():
|
for name, field in cls.Meta.model_fields.items():
|
||||||
if (
|
if (
|
||||||
inspect.isclass(field)
|
inspect.isclass(field)
|
||||||
and issubclass(field, ForeignKeyField)
|
and issubclass(field, ForeignKeyField)
|
||||||
and field.nullable
|
and field.nullable
|
||||||
):
|
):
|
||||||
related_names.add(name)
|
related_names.add(name)
|
||||||
return related_names
|
return related_names
|
||||||
@ -136,9 +151,8 @@ class ModelTableProxy:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resolve_relation_name( # noqa CCR001
|
def resolve_relation_name( # noqa CCR001
|
||||||
item: Union["NewBaseModel", Type["NewBaseModel"]],
|
item: Union["NewBaseModel", Type["NewBaseModel"]],
|
||||||
related: Union["NewBaseModel", Type["NewBaseModel"]],
|
related: Union["NewBaseModel", Type["NewBaseModel"]]
|
||||||
register_missing: bool = True,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
for name, field in item.Meta.model_fields.items():
|
for name, field in item.Meta.model_fields.items():
|
||||||
if issubclass(field, ForeignKeyField):
|
if issubclass(field, ForeignKeyField):
|
||||||
@ -147,12 +161,6 @@ class ModelTableProxy:
|
|||||||
# so we need to compare Meta too as this one is copied as is
|
# so we need to compare Meta too as this one is copied as is
|
||||||
if field.to == related.__class__ or field.to.Meta == related.Meta:
|
if field.to == related.__class__ or field.to.Meta == related.Meta:
|
||||||
return name
|
return name
|
||||||
# fallback for not registered relation
|
|
||||||
if register_missing: # pragma nocover
|
|
||||||
expand_reverse_relationships(related.__class__) # type: ignore
|
|
||||||
return ModelTableProxy.resolve_relation_name(
|
|
||||||
item, related, register_missing=False
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No relation between {item.get_name()} and {related.get_name()}"
|
f"No relation between {item.get_name()} and {related.get_name()}"
|
||||||
@ -160,8 +168,8 @@ class ModelTableProxy:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resolve_relation_field(
|
def resolve_relation_field(
|
||||||
item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]]
|
item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]]
|
||||||
) -> Union[Type[BaseField], Type[ForeignKeyField]]:
|
) -> Type[BaseField]:
|
||||||
name = ModelTableProxy.resolve_relation_name(item, related)
|
name = ModelTableProxy.resolve_relation_name(item, related)
|
||||||
to_field = item.Meta.model_fields.get(name)
|
to_field = item.Meta.model_fields.get(name)
|
||||||
if not to_field: # pragma no cover
|
if not to_field: # pragma no cover
|
||||||
@ -207,12 +215,12 @@ class ModelTableProxy:
|
|||||||
for field in one.Meta.model_fields.keys():
|
for field in one.Meta.model_fields.keys():
|
||||||
current_field = getattr(one, field)
|
current_field = getattr(one, field)
|
||||||
if isinstance(current_field, list) and not isinstance(
|
if isinstance(current_field, list) and not isinstance(
|
||||||
current_field, ormar.Model
|
current_field, ormar.Model
|
||||||
):
|
):
|
||||||
setattr(other, field, current_field + getattr(other, field))
|
setattr(other, field, current_field + getattr(other, field))
|
||||||
elif (
|
elif (
|
||||||
isinstance(current_field, ormar.Model)
|
isinstance(current_field, ormar.Model)
|
||||||
and current_field.pk == getattr(other, field).pk
|
and current_field.pk == getattr(other, field).pk
|
||||||
):
|
):
|
||||||
setattr(
|
setattr(
|
||||||
other,
|
other,
|
||||||
@ -223,7 +231,7 @@ class ModelTableProxy:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _populate_pk_column(
|
def _populate_pk_column(
|
||||||
model: Type["Model"], columns: List[str], use_alias: bool = False,
|
model: Type["Model"], columns: List[str], use_alias: bool = False,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
pk_alias = (
|
pk_alias = (
|
||||||
model.get_column_alias(model.Meta.pkname)
|
model.get_column_alias(model.Meta.pkname)
|
||||||
@ -236,10 +244,10 @@ class ModelTableProxy:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def own_table_columns(
|
def own_table_columns(
|
||||||
model: Type["Model"],
|
model: Type["Model"],
|
||||||
fields: Optional[Union[Set, Dict]],
|
fields: Optional[Union[Set, Dict]],
|
||||||
exclude_fields: Optional[Union[Set, Dict]],
|
exclude_fields: Optional[Union[Set, Dict]],
|
||||||
use_alias: bool = False,
|
use_alias: bool = False,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
columns = [
|
columns = [
|
||||||
model.get_column_name_from_alias(col.name) if not use_alias else col.name
|
model.get_column_name_from_alias(col.name) if not use_alias else col.name
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from typing import (
|
|||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
TYPE_CHECKING,
|
Set, TYPE_CHECKING,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
@ -43,7 +43,7 @@ if TYPE_CHECKING: # pragma no cover
|
|||||||
class NewBaseModel(
|
class NewBaseModel(
|
||||||
pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass
|
pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass
|
||||||
):
|
):
|
||||||
__slots__ = ("_orm_id", "_orm_saved", "_orm")
|
__slots__ = ("_orm_id", "_orm_saved", "_orm", "_related_names", "_related_names_hash", "_props")
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
__model_fields__: Dict[str, Type[BaseField]]
|
__model_fields__: Dict[str, Type[BaseField]]
|
||||||
@ -56,6 +56,10 @@ class NewBaseModel(
|
|||||||
__database__: databases.Database
|
__database__: databases.Database
|
||||||
_orm_relationship_manager: AliasManager
|
_orm_relationship_manager: AliasManager
|
||||||
_orm: RelationsManager
|
_orm: RelationsManager
|
||||||
|
_orm_saved: bool
|
||||||
|
_related_names: Set
|
||||||
|
_related_names_hash: str
|
||||||
|
_props: List[str]
|
||||||
Meta: ModelMeta
|
Meta: ModelMeta
|
||||||
|
|
||||||
# noinspection PyMissingConstructor
|
# noinspection PyMissingConstructor
|
||||||
@ -107,7 +111,7 @@ class NewBaseModel(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
|
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
|
||||||
if name in ("_orm_id", "_orm_saved", "_orm"):
|
if name in ("_orm_id", "_orm_saved", "_orm", "_related_names", "_props"):
|
||||||
object.__setattr__(self, name, value)
|
object.__setattr__(self, name, value)
|
||||||
elif name == "pk":
|
elif name == "pk":
|
||||||
object.__setattr__(self, self.Meta.pkname, value)
|
object.__setattr__(self, self.Meta.pkname, value)
|
||||||
@ -126,12 +130,12 @@ class NewBaseModel(
|
|||||||
super().__setattr__(name, value)
|
super().__setattr__(name, value)
|
||||||
|
|
||||||
def __getattribute__(self, item: str) -> Any:
|
def __getattribute__(self, item: str) -> Any:
|
||||||
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"):
|
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__", "_related_names", "_props"):
|
||||||
return object.__getattribute__(self, item)
|
return object.__getattribute__(self, item)
|
||||||
if item != "extract_related_names" and item in self.extract_related_names():
|
|
||||||
return self._extract_related_model_instead_of_field(item)
|
|
||||||
if item == "pk":
|
if item == "pk":
|
||||||
return self.__dict__.get(self.Meta.pkname, None)
|
return self.__dict__.get(self.Meta.pkname, None)
|
||||||
|
if item != "extract_related_names" and item in self.extract_related_names():
|
||||||
|
return self._extract_related_model_instead_of_field(item)
|
||||||
if item != "__fields__" and item in self.__fields__:
|
if item != "__fields__" and item in self.__fields__:
|
||||||
value = self.__dict__.get(item, None)
|
value = self.__dict__.get(item, None)
|
||||||
value = self._convert_json(item, value, "loads")
|
value = self._convert_json(item, value, "loads")
|
||||||
@ -186,12 +190,16 @@ class NewBaseModel(
|
|||||||
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
|
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
|
||||||
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
|
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
props = [
|
if isinstance(cls._props, list):
|
||||||
prop
|
props = cls._props
|
||||||
for prop in dir(cls)
|
else:
|
||||||
if isinstance(getattr(cls, prop), property)
|
props = [
|
||||||
and prop not in ("__values__", "__fields__", "fields", "pk_column")
|
prop
|
||||||
]
|
for prop in dir(cls)
|
||||||
|
if isinstance(getattr(cls, prop), property)
|
||||||
|
and prop not in ("__values__", "__fields__", "fields", "pk_column")
|
||||||
|
]
|
||||||
|
cls._props = props
|
||||||
if include:
|
if include:
|
||||||
props = [prop for prop in props if prop in include]
|
props = [prop for prop in props if prop in include]
|
||||||
if exclude:
|
if exclude:
|
||||||
|
|||||||
@ -3,7 +3,6 @@ from ormar.relations.relation import Relation, RelationType
|
|||||||
from ormar.relations.relation_manager import RelationsManager
|
from ormar.relations.relation_manager import RelationsManager
|
||||||
from ormar.relations.utils import (
|
from ormar.relations.utils import (
|
||||||
get_relations_sides_and_names,
|
get_relations_sides_and_names,
|
||||||
register_missing_relation,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -11,6 +10,5 @@ __all__ = [
|
|||||||
"Relation",
|
"Relation",
|
||||||
"RelationsManager",
|
"RelationsManager",
|
||||||
"RelationType",
|
"RelationType",
|
||||||
"register_missing_relation",
|
|
||||||
"get_relations_sides_and_names",
|
"get_relations_sides_and_names",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -7,7 +7,6 @@ from ormar.fields.many_to_many import ManyToManyField
|
|||||||
from ormar.relations.relation import Relation, RelationType
|
from ormar.relations.relation import Relation, RelationType
|
||||||
from ormar.relations.utils import (
|
from ormar.relations.utils import (
|
||||||
get_relations_sides_and_names,
|
get_relations_sides_and_names,
|
||||||
register_missing_relation,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
@ -42,8 +41,6 @@ class RelationsManager:
|
|||||||
to=field.to,
|
to=field.to,
|
||||||
through=getattr(field, "through", None),
|
through=getattr(field, "through", None),
|
||||||
)
|
)
|
||||||
if field.name not in self._related_names:
|
|
||||||
self._related_names.append(field.name)
|
|
||||||
|
|
||||||
def __contains__(self, item: str) -> bool:
|
def __contains__(self, item: str) -> bool:
|
||||||
return item in self._related_names
|
return item in self._related_names
|
||||||
@ -69,9 +66,10 @@ class RelationsManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
parent_relation = parent._orm._get(child_name)
|
parent_relation = parent._orm._get(child_name)
|
||||||
if not parent_relation:
|
if parent_relation:
|
||||||
parent_relation = register_missing_relation(parent, child, child_name)
|
# print('missing', child_name)
|
||||||
parent_relation.add(child) # type: ignore
|
# parent_relation = register_missing_relation(parent, child, child_name)
|
||||||
|
parent_relation.add(child) # type: ignore
|
||||||
|
|
||||||
child_relation = child._orm._get(to_name)
|
child_relation = child._orm._get(to_name)
|
||||||
if child_relation:
|
if child_relation:
|
||||||
|
|||||||
@ -72,6 +72,4 @@ class RelationProxy(list):
|
|||||||
if self.relation._type == ormar.RelationType.MULTIPLE:
|
if self.relation._type == ormar.RelationType.MULTIPLE:
|
||||||
await self.queryset_proxy.create_through_instance(item)
|
await self.queryset_proxy.create_through_instance(item)
|
||||||
rel_name = item.resolve_relation_name(item, self._owner)
|
rel_name = item.resolve_relation_name(item, self._owner)
|
||||||
if rel_name not in item._orm: # pragma nocover
|
|
||||||
item._orm._add_relation(item.Meta.model_fields[rel_name])
|
|
||||||
setattr(item, rel_name, self._owner)
|
setattr(item, rel_name, self._owner)
|
||||||
|
|||||||
@ -1,32 +1,19 @@
|
|||||||
from typing import Optional, TYPE_CHECKING, Tuple, Type
|
from typing import TYPE_CHECKING, Tuple, Type
|
||||||
from weakref import proxy
|
from weakref import proxy
|
||||||
|
|
||||||
import ormar
|
|
||||||
from ormar.fields import BaseField
|
from ormar.fields import BaseField
|
||||||
from ormar.fields.many_to_many import ManyToManyField
|
from ormar.fields.many_to_many import ManyToManyField
|
||||||
from ormar.relations import Relation
|
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma no cover
|
if TYPE_CHECKING: # pragma no cover
|
||||||
from ormar import Model
|
from ormar import Model
|
||||||
|
|
||||||
|
|
||||||
def register_missing_relation(
|
|
||||||
parent: "Model", child: "Model", child_name: str
|
|
||||||
) -> Optional[Relation]:
|
|
||||||
ormar.models.expand_reverse_relationships(child.__class__)
|
|
||||||
name = parent.resolve_relation_name(parent, child)
|
|
||||||
field = parent.Meta.model_fields[name]
|
|
||||||
parent._orm._add_relation(field)
|
|
||||||
parent_relation = parent._orm._get(child_name)
|
|
||||||
return parent_relation
|
|
||||||
|
|
||||||
|
|
||||||
def get_relations_sides_and_names(
|
def get_relations_sides_and_names(
|
||||||
to_field: Type[BaseField],
|
to_field: Type[BaseField],
|
||||||
parent: "Model",
|
parent: "Model",
|
||||||
child: "Model",
|
child: "Model",
|
||||||
child_name: str,
|
child_name: str,
|
||||||
virtual: bool,
|
virtual: bool,
|
||||||
) -> Tuple["Model", "Model", str, str]:
|
) -> Tuple["Model", "Model", str, str]:
|
||||||
to_name = to_field.name
|
to_name = to_field.name
|
||||||
if issubclass(to_field, ManyToManyField):
|
if issubclass(to_field, ManyToManyField):
|
||||||
|
|||||||
@ -4,6 +4,7 @@ databases[mysql]
|
|||||||
pydantic
|
pydantic
|
||||||
sqlalchemy
|
sqlalchemy
|
||||||
typing_extensions
|
typing_extensions
|
||||||
|
orjson
|
||||||
|
|
||||||
# Async database drivers
|
# Async database drivers
|
||||||
aiomysql
|
aiomysql
|
||||||
@ -34,3 +35,6 @@ flake8-variables-names
|
|||||||
flake8-cognitive-complexity
|
flake8-cognitive-complexity
|
||||||
flake8-functions
|
flake8-functions
|
||||||
flake8-expression-complexity
|
flake8-expression-complexity
|
||||||
|
|
||||||
|
# Performance testing
|
||||||
|
yappi
|
||||||
|
|||||||
3
setup.py
3
setup.py
@ -42,7 +42,7 @@ setup(
|
|||||||
version=get_version(PACKAGE),
|
version=get_version(PACKAGE),
|
||||||
url=URL,
|
url=URL,
|
||||||
license="MIT",
|
license="MIT",
|
||||||
description="An simple async ORM with fastapi in mind and pydantic validation.",
|
description="A simple async ORM with fastapi in mind and pydantic validation.",
|
||||||
long_description=get_long_description(),
|
long_description=get_long_description(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
keywords=['orm', 'sqlalchemy', 'fastapi', 'pydantic', 'databases', 'async', 'alembic'],
|
keywords=['orm', 'sqlalchemy', 'fastapi', 'pydantic', 'databases', 'async', 'alembic'],
|
||||||
@ -56,6 +56,7 @@ setup(
|
|||||||
"postgresql": ["asyncpg", "psycopg2"],
|
"postgresql": ["asyncpg", "psycopg2"],
|
||||||
"mysql": ["aiomysql", "pymysql"],
|
"mysql": ["aiomysql", "pymysql"],
|
||||||
"sqlite": ["aiosqlite"],
|
"sqlite": ["aiosqlite"],
|
||||||
|
"orjson": ["orjson"]
|
||||||
},
|
},
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 3 - Alpha",
|
"Development Status :: 3 - Alpha",
|
||||||
|
|||||||
@ -1,374 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import databases
|
|
||||||
import pytest
|
|
||||||
import sqlalchemy
|
|
||||||
|
|
||||||
import ormar
|
|
||||||
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
|
|
||||||
from tests.settings import DATABASE_URL
|
|
||||||
|
|
||||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
|
||||||
metadata = sqlalchemy.MetaData()
|
|
||||||
|
|
||||||
|
|
||||||
class Album(ormar.Model):
|
|
||||||
class Meta:
|
|
||||||
tablename = "albums"
|
|
||||||
metadata = metadata
|
|
||||||
database = database
|
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
|
||||||
name: str = ormar.String(max_length=100)
|
|
||||||
|
|
||||||
|
|
||||||
class Track(ormar.Model):
|
|
||||||
class Meta:
|
|
||||||
tablename = "tracks"
|
|
||||||
metadata = metadata
|
|
||||||
database = database
|
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
|
||||||
album: Optional[Album] = ormar.ForeignKey(Album)
|
|
||||||
title: str = ormar.String(max_length=100)
|
|
||||||
position: int = ormar.Integer()
|
|
||||||
|
|
||||||
|
|
||||||
class Cover(ormar.Model):
|
|
||||||
class Meta:
|
|
||||||
tablename = "covers"
|
|
||||||
metadata = metadata
|
|
||||||
database = database
|
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
|
||||||
album: Optional[Album] = ormar.ForeignKey(Album, related_name="cover_pictures")
|
|
||||||
title: str = ormar.String(max_length=100)
|
|
||||||
|
|
||||||
|
|
||||||
class Organisation(ormar.Model):
|
|
||||||
class Meta:
|
|
||||||
tablename = "org"
|
|
||||||
metadata = metadata
|
|
||||||
database = database
|
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
|
||||||
ident: str = ormar.String(max_length=100, choices=["ACME Ltd", "Other ltd"])
|
|
||||||
|
|
||||||
|
|
||||||
class Team(ormar.Model):
|
|
||||||
class Meta:
|
|
||||||
tablename = "teams"
|
|
||||||
metadata = metadata
|
|
||||||
database = database
|
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
|
||||||
org: Optional[Organisation] = ormar.ForeignKey(Organisation)
|
|
||||||
name: str = ormar.String(max_length=100)
|
|
||||||
|
|
||||||
|
|
||||||
class Member(ormar.Model):
|
|
||||||
class Meta:
|
|
||||||
tablename = "members"
|
|
||||||
metadata = metadata
|
|
||||||
database = database
|
|
||||||
|
|
||||||
id: int = ormar.Integer(primary_key=True)
|
|
||||||
team: Optional[Team] = ormar.ForeignKey(Team)
|
|
||||||
email: str = ormar.String(max_length=100)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True, scope="module")
|
|
||||||
def create_test_database():
|
|
||||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
|
||||||
metadata.drop_all(engine)
|
|
||||||
metadata.create_all(engine)
|
|
||||||
yield
|
|
||||||
metadata.drop_all(engine)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_wrong_query_foreign_key_type():
|
|
||||||
async with database:
|
|
||||||
with pytest.raises(RelationshipInstanceError):
|
|
||||||
Track(title="The Error", album="wrong_pk_type")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_setting_explicitly_empty_relation():
|
|
||||||
async with database:
|
|
||||||
track = Track(album=None, title="The Bird", position=1)
|
|
||||||
assert track.album is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_related_name():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
album = await Album.objects.create(name="Vanilla")
|
|
||||||
await Cover.objects.create(album=album, title="The cover file")
|
|
||||||
assert len(album.cover_pictures) == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_model_crud():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
album = Album(name="Jamaica")
|
|
||||||
await album.save()
|
|
||||||
track1 = Track(album=album, title="The Bird", position=1)
|
|
||||||
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
|
|
||||||
track3 = Track(album=album, title="The Waters", position=3)
|
|
||||||
await track1.save()
|
|
||||||
await track2.save()
|
|
||||||
await track3.save()
|
|
||||||
|
|
||||||
track = await Track.objects.get(title="The Bird")
|
|
||||||
assert track.album.pk == album.pk
|
|
||||||
assert isinstance(track.album, ormar.Model)
|
|
||||||
assert track.album.name is None
|
|
||||||
await track.album.load()
|
|
||||||
assert track.album.name == "Jamaica"
|
|
||||||
|
|
||||||
assert len(album.tracks) == 3
|
|
||||||
assert album.tracks[1].title == "Heart don't stand a chance"
|
|
||||||
|
|
||||||
album1 = await Album.objects.get(name="Jamaica")
|
|
||||||
assert album1.pk == album.pk
|
|
||||||
assert album1.tracks == []
|
|
||||||
|
|
||||||
await Track.objects.create(
|
|
||||||
album={"id": track.album.pk}, title="The Bird2", position=4
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_select_related():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
album = Album(name="Malibu")
|
|
||||||
await album.save()
|
|
||||||
track1 = Track(album=album, title="The Bird", position=1)
|
|
||||||
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
|
|
||||||
track3 = Track(album=album, title="The Waters", position=3)
|
|
||||||
await track1.save()
|
|
||||||
await track2.save()
|
|
||||||
await track3.save()
|
|
||||||
|
|
||||||
fantasies = Album(name="Fantasies")
|
|
||||||
await fantasies.save()
|
|
||||||
track4 = Track(album=fantasies, title="Help I'm Alive", position=1)
|
|
||||||
track5 = Track(album=fantasies, title="Sick Muse", position=2)
|
|
||||||
track6 = Track(album=fantasies, title="Satellite Mind", position=3)
|
|
||||||
await track4.save()
|
|
||||||
await track5.save()
|
|
||||||
await track6.save()
|
|
||||||
|
|
||||||
track = await Track.objects.select_related("album").get(title="The Bird")
|
|
||||||
assert track.album.name == "Malibu"
|
|
||||||
|
|
||||||
tracks = await Track.objects.select_related("album").all()
|
|
||||||
assert len(tracks) == 6
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_model_removal_from_relations():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
album = Album(name="Chichi")
|
|
||||||
await album.save()
|
|
||||||
track1 = Track(album=album, title="The Birdman", position=1)
|
|
||||||
track2 = Track(album=album, title="Superman", position=2)
|
|
||||||
track3 = Track(album=album, title="Wonder Woman", position=3)
|
|
||||||
await track1.save()
|
|
||||||
await track2.save()
|
|
||||||
await track3.save()
|
|
||||||
|
|
||||||
assert len(album.tracks) == 3
|
|
||||||
await album.tracks.remove(track1)
|
|
||||||
assert len(album.tracks) == 2
|
|
||||||
assert track1.album is None
|
|
||||||
|
|
||||||
await track1.update()
|
|
||||||
track1 = await Track.objects.get(title="The Birdman")
|
|
||||||
assert track1.album is None
|
|
||||||
|
|
||||||
await album.tracks.add(track1)
|
|
||||||
assert len(album.tracks) == 3
|
|
||||||
assert track1.album == album
|
|
||||||
|
|
||||||
await track1.update()
|
|
||||||
track1 = await Track.objects.select_related("album__tracks").get(
|
|
||||||
title="The Birdman"
|
|
||||||
)
|
|
||||||
album = await Album.objects.select_related("tracks").get(name="Chichi")
|
|
||||||
assert track1.album == album
|
|
||||||
|
|
||||||
track1.remove(album)
|
|
||||||
assert track1.album is None
|
|
||||||
assert len(album.tracks) == 2
|
|
||||||
|
|
||||||
track2.remove(album)
|
|
||||||
assert track2.album is None
|
|
||||||
assert len(album.tracks) == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_fk_filter():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
malibu = Album(name="Malibu%")
|
|
||||||
await malibu.save()
|
|
||||||
await Track.objects.create(album=malibu, title="The Bird", position=1)
|
|
||||||
await Track.objects.create(
|
|
||||||
album=malibu, title="Heart don't stand a chance", position=2
|
|
||||||
)
|
|
||||||
await Track.objects.create(album=malibu, title="The Waters", position=3)
|
|
||||||
|
|
||||||
fantasies = await Album.objects.create(name="Fantasies")
|
|
||||||
await Track.objects.create(
|
|
||||||
album=fantasies, title="Help I'm Alive", position=1
|
|
||||||
)
|
|
||||||
await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
|
|
||||||
await Track.objects.create(
|
|
||||||
album=fantasies, title="Satellite Mind", position=3
|
|
||||||
)
|
|
||||||
|
|
||||||
tracks = (
|
|
||||||
await Track.objects.select_related("album")
|
|
||||||
.filter(album__name="Fantasies")
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
assert len(tracks) == 3
|
|
||||||
for track in tracks:
|
|
||||||
assert track.album.name == "Fantasies"
|
|
||||||
|
|
||||||
tracks = (
|
|
||||||
await Track.objects.select_related("album")
|
|
||||||
.filter(album__name__icontains="fan")
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
assert len(tracks) == 3
|
|
||||||
for track in tracks:
|
|
||||||
assert track.album.name == "Fantasies"
|
|
||||||
|
|
||||||
tracks = await Track.objects.filter(album__name__contains="Fan").all()
|
|
||||||
assert len(tracks) == 3
|
|
||||||
for track in tracks:
|
|
||||||
assert track.album.name == "Fantasies"
|
|
||||||
|
|
||||||
tracks = await Track.objects.filter(album__name__contains="Malibu%").all()
|
|
||||||
assert len(tracks) == 3
|
|
||||||
|
|
||||||
tracks = (
|
|
||||||
await Track.objects.filter(album=malibu).select_related("album").all()
|
|
||||||
)
|
|
||||||
assert len(tracks) == 3
|
|
||||||
for track in tracks:
|
|
||||||
assert track.album.name == "Malibu%"
|
|
||||||
|
|
||||||
tracks = await Track.objects.select_related("album").all(album=malibu)
|
|
||||||
assert len(tracks) == 3
|
|
||||||
for track in tracks:
|
|
||||||
assert track.album.name == "Malibu%"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_multiple_fk():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
acme = await Organisation.objects.create(ident="ACME Ltd")
|
|
||||||
red_team = await Team.objects.create(org=acme, name="Red Team")
|
|
||||||
blue_team = await Team.objects.create(org=acme, name="Blue Team")
|
|
||||||
await Member.objects.create(team=red_team, email="a@example.org")
|
|
||||||
await Member.objects.create(team=red_team, email="b@example.org")
|
|
||||||
await Member.objects.create(team=blue_team, email="c@example.org")
|
|
||||||
await Member.objects.create(team=blue_team, email="d@example.org")
|
|
||||||
|
|
||||||
other = await Organisation.objects.create(ident="Other ltd")
|
|
||||||
team = await Team.objects.create(org=other, name="Green Team")
|
|
||||||
await Member.objects.create(team=team, email="e@example.org")
|
|
||||||
|
|
||||||
members = (
|
|
||||||
await Member.objects.select_related("team__org")
|
|
||||||
.filter(team__org__ident="ACME Ltd")
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
assert len(members) == 4
|
|
||||||
for member in members:
|
|
||||||
assert member.team.org.ident == "ACME Ltd"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_wrong_choices():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
await Organisation.objects.create(ident="Test 1")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_pk_filter():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
fantasies = await Album.objects.create(name="Test")
|
|
||||||
track = await Track.objects.create(
|
|
||||||
album=fantasies, title="Test1", position=1
|
|
||||||
)
|
|
||||||
await Track.objects.create(album=fantasies, title="Test2", position=2)
|
|
||||||
await Track.objects.create(album=fantasies, title="Test3", position=3)
|
|
||||||
tracks = (
|
|
||||||
await Track.objects.select_related("album").filter(pk=track.pk).all()
|
|
||||||
)
|
|
||||||
assert len(tracks) == 1
|
|
||||||
|
|
||||||
tracks = (
|
|
||||||
await Track.objects.select_related("album")
|
|
||||||
.filter(position=2, album__name="Test")
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
assert len(tracks) == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_limit_and_offset():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
fantasies = await Album.objects.create(name="Limitless")
|
|
||||||
await Track.objects.create(
|
|
||||||
id=None, album=fantasies, title="Sample", position=1
|
|
||||||
)
|
|
||||||
await Track.objects.create(album=fantasies, title="Sample2", position=2)
|
|
||||||
await Track.objects.create(album=fantasies, title="Sample3", position=3)
|
|
||||||
|
|
||||||
tracks = await Track.objects.limit(1).all()
|
|
||||||
assert len(tracks) == 1
|
|
||||||
assert tracks[0].title == "Sample"
|
|
||||||
|
|
||||||
tracks = await Track.objects.limit(1).offset(1).all()
|
|
||||||
assert len(tracks) == 1
|
|
||||||
assert tracks[0].title == "Sample2"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_exceptions():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
fantasies = await Album.objects.create(name="Test")
|
|
||||||
|
|
||||||
with pytest.raises(NoMatch):
|
|
||||||
await Album.objects.get(name="Test2")
|
|
||||||
|
|
||||||
await Track.objects.create(album=fantasies, title="Test1", position=1)
|
|
||||||
await Track.objects.create(album=fantasies, title="Test2", position=2)
|
|
||||||
await Track.objects.create(album=fantasies, title="Test3", position=3)
|
|
||||||
with pytest.raises(MultipleMatches):
|
|
||||||
await Track.objects.select_related("album").get(album=fantasies)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_wrong_model_passed_as_fk():
|
|
||||||
async with database:
|
|
||||||
async with database.transaction(force_rollback=True):
|
|
||||||
with pytest.raises(RelationshipInstanceError):
|
|
||||||
org = await Organisation.objects.create(ident="ACME Ltd")
|
|
||||||
await Track.objects.create(album=org, title="Test1", position=1)
|
|
||||||
@ -176,27 +176,27 @@ async def test_sort_order_on_related_model():
|
|||||||
|
|
||||||
owner = (
|
owner = (
|
||||||
await Owner.objects.select_related("toys")
|
await Owner.objects.select_related("toys")
|
||||||
.order_by("toys__name")
|
.order_by("toys__name")
|
||||||
.filter(name="Zeus")
|
.filter(name="Zeus")
|
||||||
.get()
|
.get()
|
||||||
)
|
)
|
||||||
assert owner.toys[0].name == "Toy 1"
|
assert owner.toys[0].name == "Toy 1"
|
||||||
assert owner.toys[1].name == "Toy 4"
|
assert owner.toys[1].name == "Toy 4"
|
||||||
|
|
||||||
owner = (
|
owner = (
|
||||||
await Owner.objects.select_related("toys")
|
await Owner.objects.select_related("toys")
|
||||||
.order_by("-toys__name")
|
.order_by("-toys__name")
|
||||||
.filter(name="Zeus")
|
.filter(name="Zeus")
|
||||||
.get()
|
.get()
|
||||||
)
|
)
|
||||||
assert owner.toys[0].name == "Toy 4"
|
assert owner.toys[0].name == "Toy 4"
|
||||||
assert owner.toys[1].name == "Toy 1"
|
assert owner.toys[1].name == "Toy 1"
|
||||||
|
|
||||||
owners = (
|
owners = (
|
||||||
await Owner.objects.select_related("toys")
|
await Owner.objects.select_related("toys")
|
||||||
.order_by("-toys__name")
|
.order_by("-toys__name")
|
||||||
.filter(name__in=["Zeus", "Hermes"])
|
.filter(name__in=["Zeus", "Hermes"])
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
assert owners[0].toys[0].name == "Toy 6"
|
assert owners[0].toys[0].name == "Toy 6"
|
||||||
assert owners[0].toys[1].name == "Toy 5"
|
assert owners[0].toys[1].name == "Toy 5"
|
||||||
@ -210,9 +210,9 @@ async def test_sort_order_on_related_model():
|
|||||||
|
|
||||||
owners = (
|
owners = (
|
||||||
await Owner.objects.select_related("toys")
|
await Owner.objects.select_related("toys")
|
||||||
.order_by("-toys__name")
|
.order_by("-toys__name")
|
||||||
.filter(name__in=["Zeus", "Hermes"])
|
.filter(name__in=["Zeus", "Hermes"])
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
assert owners[0].toys[0].name == "Toy 7"
|
assert owners[0].toys[0].name == "Toy 7"
|
||||||
assert owners[0].toys[1].name == "Toy 4"
|
assert owners[0].toys[1].name == "Toy 4"
|
||||||
@ -252,9 +252,9 @@ async def test_sort_order_on_many_to_many():
|
|||||||
|
|
||||||
user = (
|
user = (
|
||||||
await User.objects.select_related("cars")
|
await User.objects.select_related("cars")
|
||||||
.filter(name="Mark")
|
.filter(name="Mark")
|
||||||
.order_by("cars__name")
|
.order_by("cars__name")
|
||||||
.get()
|
.get()
|
||||||
)
|
)
|
||||||
assert user.cars[0].name == "Buggy"
|
assert user.cars[0].name == "Buggy"
|
||||||
assert user.cars[1].name == "Ferrari"
|
assert user.cars[1].name == "Ferrari"
|
||||||
@ -263,9 +263,9 @@ async def test_sort_order_on_many_to_many():
|
|||||||
|
|
||||||
user = (
|
user = (
|
||||||
await User.objects.select_related("cars")
|
await User.objects.select_related("cars")
|
||||||
.filter(name="Mark")
|
.filter(name="Mark")
|
||||||
.order_by("-cars__name")
|
.order_by("-cars__name")
|
||||||
.get()
|
.get()
|
||||||
)
|
)
|
||||||
assert user.cars[3].name == "Buggy"
|
assert user.cars[3].name == "Buggy"
|
||||||
assert user.cars[2].name == "Ferrari"
|
assert user.cars[2].name == "Ferrari"
|
||||||
@ -281,8 +281,8 @@ async def test_sort_order_on_many_to_many():
|
|||||||
|
|
||||||
users = (
|
users = (
|
||||||
await User.objects.select_related(["cars", "cars__factory"])
|
await User.objects.select_related(["cars", "cars__factory"])
|
||||||
.order_by(["-cars__factory__name", "cars__name"])
|
.order_by(["-cars__factory__name", "cars__name"])
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
assert users[0].name == "Julie"
|
assert users[0].name == "Julie"
|
||||||
@ -328,8 +328,8 @@ async def test_sort_order_with_aliases():
|
|||||||
|
|
||||||
aliases = (
|
aliases = (
|
||||||
await AliasTest.objects.select_related("nested")
|
await AliasTest.objects.select_related("nested")
|
||||||
.order_by("-nested__name")
|
.order_by("-nested__name")
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
assert aliases[0].nested.name == "Try4"
|
assert aliases[0].nested.name == "Try4"
|
||||||
assert aliases[1].nested.name == "Try3"
|
assert aliases[1].nested.name == "Try3"
|
||||||
|
|||||||
Reference in New Issue
Block a user