sloppy work on passing all of the test and reimplementing most of the features from encode

This commit is contained in:
collerek
2020-08-04 18:44:17 +02:00
parent a6f8fc6d7e
commit 345fd227d1
9 changed files with 648 additions and 24 deletions

BIN
.coverage

Binary file not shown.

View File

@ -158,6 +158,7 @@ The following keyword arguments are supported on all field types.
* `primary_key` * `primary_key`
* `nullable` * `nullable`
* `default` * `default`
* `server_default`
* `index` * `index`
* `unique` * `unique`
@ -165,6 +166,9 @@ All fields are required unless one of the following is set:
* `nullable` - Creates a nullable column. Sets the default to `None`. * `nullable` - Creates a nullable column. Sets the default to `None`.
* `default` - Set a default value for the field. * `default` - Set a default value for the field.
* `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`).
* `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column.
Autoincrement is set by default on int primary keys.
Available Model Fields: Available Model Fields:
* `orm.String(length)` * `orm.String(length)`

View File

@ -1,4 +1,5 @@
from orm.fields import Integer, BigInteger, Boolean, Time, Text, String, JSON, DateTime, Date, Decimal, Float from orm.fields import Integer, BigInteger, Boolean, Time, Text, String, JSON, DateTime, Date, Decimal, Float, \
ForeignKey
from orm.models import Model from orm.models import Model
__all__ = [ __all__ = [
@ -13,5 +14,6 @@ __all__ = [
"Date", "Date",
"Decimal", "Decimal",
"Float", "Float",
"ForeignKey",
"Model" "Model"
] ]

View File

@ -10,7 +10,11 @@ class ModelNotSet(AsyncOrmException):
pass pass
class MultipleResults(AsyncOrmException): class NoMatch(AsyncOrmException):
pass
class MultipleMatches(AsyncOrmException):
pass pass

View File

@ -6,6 +6,7 @@ import pydantic
import sqlalchemy import sqlalchemy
from orm.exceptions import ModelDefinitionError from orm.exceptions import ModelDefinitionError
from orm.relations import Relationship
class BaseField: class BaseField:
@ -24,7 +25,7 @@ class BaseField:
self.name = name self.name = name
self.primary_key = kwargs.pop('primary_key', False) self.primary_key = kwargs.pop('primary_key', False)
self.autoincrement = kwargs.pop('autoincrement', self.primary_key) self.autoincrement = kwargs.pop('autoincrement', self.primary_key and self.__type__ == int)
self.nullable = kwargs.pop('nullable', not self.primary_key) self.nullable = kwargs.pop('nullable', not self.primary_key)
self.default = kwargs.pop('default', None) self.default = kwargs.pop('default', None)
@ -37,11 +38,30 @@ class BaseField:
if self.pydantic_only and self.primary_key: if self.pydantic_only and self.primary_key:
raise ModelDefinitionError('Primary key column cannot be pydantic only.') raise ModelDefinitionError('Primary key column cannot be pydantic only.')
@property
def is_required(self):
return not self.nullable and not self.has_default and not self.is_auto_primary_key
@property
def default_value(self):
default = self.default if self.default is not None else self.server_default
return default() if callable(default) else default
@property
def has_default(self):
return self.default is not None or self.server_default is not None
@property
def is_auto_primary_key(self):
if self.primary_key:
return self.autoincrement
return False
def get_column(self, name: str = None) -> sqlalchemy.Column: def get_column(self, name: str = None) -> sqlalchemy.Column:
name = self.name or name self.name = self.name or name
constraints = self.get_constraints() constraints = self.get_constraints()
return sqlalchemy.Column( return sqlalchemy.Column(
name, self.name,
self.get_column_type(), self.get_column_type(),
*constraints, *constraints,
primary_key=self.primary_key, primary_key=self.primary_key,
@ -59,6 +79,9 @@ class BaseField:
def get_constraints(self) -> Optional[List]: def get_constraints(self) -> Optional[List]:
return [] return []
def expand_relationship(self, value, parent):
return value
class String(BaseField): class String(BaseField):
__type__ = str __type__ = str
@ -147,3 +170,37 @@ class Decimal(BaseField):
def get_column_type(self): def get_column_type(self):
return sqlalchemy.DECIMAL(self.length, self.precision) return sqlalchemy.DECIMAL(self.length, self.precision)
class ForeignKey(BaseField):
def __init__(self, to, related_name: str = None, nullable: bool = False):
super().__init__(nullable=nullable)
self.related_name = related_name
self.to = to
@property
def __type__(self):
return self.to.__pydantic_model__
def get_constraints(self):
fk_string = self.to.__tablename__ + "." + self.to.__pkname__
return [sqlalchemy.schema.ForeignKey(fk_string)]
def get_column_type(self):
to_column = self.to.__model_fields__[self.to.__pkname__]
return to_column.get_column_type()
def expand_relationship(self, value, child):
if isinstance(value, self.to):
model = value
else:
model = self.to(**{self.to.__pkname__: value})
child_model_name = self.related_name or child.__class__.__name__.lower() + 's'
model._orm_relationship_manager.add(
Relationship(name=child_model_name, child=child, parent=model, fk_side='child'))
model.__fields__[child_model_name] = pydantic.fields.ModelField(name=child_model_name,
type_=child.__pydantic_model__,
model_config=child.__pydantic_model__.__config__,
class_validators=child.__pydantic_model__.__validators__)
return model

View File

@ -1,26 +1,250 @@
import copy
import inspect
import json import json
from typing import Any, Type import uuid
from typing import Any, List, Type
from typing import Set, Dict from typing import Set, Dict
import pydantic import pydantic
import sqlalchemy import sqlalchemy
from pydantic import BaseConfig, create_model from pydantic import BaseConfig, create_model
from orm.exceptions import ModelDefinitionError from orm.exceptions import ModelDefinitionError, MultipleMatches, NoMatch
from orm.fields import BaseField from orm.fields import BaseField
from orm.relations import RelationshipManager
def parse_pydantic_field_from_model_fields(object_dict: dict): def parse_pydantic_field_from_model_fields(object_dict: dict):
pydantic_fields = {field_name: ( pydantic_fields = {field_name: (
base_field.__type__, base_field.__type__,
... if (not base_field.nullable and not base_field.default and not base_field.primary_key) else ( ... if base_field.is_required else base_field.default_value
base_field.default() if callable(base_field.default) else base_field.default)
) )
for field_name, base_field in object_dict.items() for field_name, base_field in object_dict.items()
if isinstance(base_field, BaseField)} if isinstance(base_field, BaseField)}
return pydantic_fields return pydantic_fields
FILTER_OPERATORS = {
"exact": "__eq__",
"iexact": "ilike",
"contains": "like",
"icontains": "ilike",
"in": "in_",
"gt": "__gt__",
"gte": "__ge__",
"lt": "__lt__",
"lte": "__le__",
}
class QuerySet:
ESCAPE_CHARACTERS = ['%', '_']
def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None,
limit_count: int = None, offset: int = None):
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
self._select_related = [] if select_related is None else select_related
self.limit_count = limit_count
self.query_offset = offset
def __get__(self, instance, owner):
return self.__class__(model_cls=owner)
@property
def database(self):
return self.model_cls.__database__
@property
def table(self):
return self.model_cls.__table__
def build_select_expression(self):
tables = [self.table]
select_from = self.table
for item in self._select_related:
model_cls = self.model_cls
select_from = self.table
for part in item.split("__"):
model_cls = model_cls.__model_fields__[part].to
select_from = sqlalchemy.sql.join(select_from, model_cls.__table__)
tables.append(model_cls.__table__)
expr = sqlalchemy.sql.select(tables)
expr = expr.select_from(select_from)
if self.filter_clauses:
if len(self.filter_clauses) == 1:
clause = self.filter_clauses[0]
else:
clause = sqlalchemy.sql.and_(*self.filter_clauses)
expr = expr.where(clause)
if self.limit_count:
expr = expr.limit(self.limit_count)
if self.query_offset:
expr = expr.offset(self.query_offset)
# print(expr.compile(compile_kwargs={"literal_binds": True}))
return expr
def filter(self, **kwargs):
filter_clauses = self.filter_clauses
select_related = list(self._select_related)
if kwargs.get("pk"):
pk_name = self.model_cls.__pkname__
kwargs[pk_name] = kwargs.pop("pk")
for key, value in kwargs.items():
if "__" in key:
parts = key.split("__")
# Determine if we should treat the final part as a
# filter operator or as a related field.
if parts[-1] in FILTER_OPERATORS:
op = parts[-1]
field_name = parts[-2]
related_parts = parts[:-2]
else:
op = "exact"
field_name = parts[-1]
related_parts = parts[:-1]
model_cls = self.model_cls
if related_parts:
# Add any implied select_related
related_str = "__".join(related_parts)
if related_str not in select_related:
select_related.append(related_str)
# Walk the relationships to the actual model class
# against which the comparison is being made.
for part in related_parts:
model_cls = model_cls.__model_fields__[part].to
column = model_cls.__table__.columns[field_name]
else:
op = "exact"
column = self.table.columns[key]
# Map the operation code onto SQLAlchemy's ColumnElement
# https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement
op_attr = FILTER_OPERATORS[op]
has_escaped_character = False
if op in ["contains", "icontains"]:
has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS
if c in value)
if has_escaped_character:
# enable escape modifier
for char in self.ESCAPE_CHARACTERS:
value = value.replace(char, f'\\{char}')
value = f"%{value}%"
if isinstance(value, Model):
value = value.pk
clause = getattr(column, op_attr)(value)
clause.modifiers['escape'] = '\\' if has_escaped_character else None
filter_clauses.append(clause)
return self.__class__(
model_cls=self.model_cls,
filter_clauses=filter_clauses,
select_related=select_related,
limit_count=self.limit_count,
offset=self.query_offset
)
def select_related(self, related):
if not isinstance(related, (list, tuple)):
related = [related]
related = list(self._select_related) + related
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=related,
limit_count=self.limit_count,
offset=self.query_offset
)
# async def exists(self) -> bool:
# expr = self.build_select_expression()
# expr = sqlalchemy.exists(expr).select()
# return await self.database.fetch_val(expr)
def limit(self, limit_count: int):
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=self._select_related,
limit_count=limit_count,
offset=self.query_offset
)
def offset(self, offset: int):
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=self._select_related,
limit_count=self.limit_count,
offset=offset
)
async def get(self, **kwargs):
if kwargs:
return await self.filter(**kwargs).get()
expr = self.build_select_expression().limit(2)
rows = await self.database.fetch_all(expr)
if not rows:
raise NoMatch()
if len(rows) > 1:
raise MultipleMatches()
return self.model_cls.from_row(rows[0], select_related=self._select_related)
async def all(self, **kwargs):
if kwargs:
return await self.filter(**kwargs).all()
expr = self.build_select_expression()
rows = await self.database.fetch_all(expr)
return [
self.model_cls.from_row(row, select_related=self._select_related)
for row in rows
]
async def create(self, **kwargs):
new_kwargs = dict(**kwargs)
# Remove primary key when None to prevent not null constraint in postgresql.
pkname = self.model_cls.__pkname__
pk = self.model_cls.__model_fields__[pkname]
if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement):
del new_kwargs[pkname]
# substitute related models with their pk
for field in self.model_cls.extract_related_names():
if field in new_kwargs and new_kwargs.get(field) is not None:
new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__)
# Build the insert expression.
expr = self.table.insert()
expr = expr.values(**new_kwargs)
# Execute the insert, and return a new model instance.
instance = self.model_cls(**kwargs)
instance.pk = await self.database.execute(expr)
return instance
class ModelMetaclass(type): class ModelMetaclass(type):
def __new__( def __new__(
mcs: type, name: str, bases: Any, attrs: dict mcs: type, name: str, bases: Any, attrs: dict
@ -52,9 +276,7 @@ class ModelMetaclass(type):
attrs['__pkname__'] = pkname attrs['__pkname__'] = pkname
if not pkname: if not pkname:
raise ModelDefinitionError( raise ModelDefinitionError('Table has to have a primary key.')
'Table has to have a primary key.'
)
# pydantic model creation # pydantic model creation
pydantic_fields = parse_pydantic_field_from_model_fields(attrs) pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
@ -62,9 +284,9 @@ class ModelMetaclass(type):
pydantic_model = create_model(name, __config__=config, **pydantic_fields) pydantic_model = create_model(name, __config__=config, **pydantic_fields)
attrs['__pydantic_fields__'] = pydantic_fields attrs['__pydantic_fields__'] = pydantic_fields
attrs['__pydantic_model__'] = pydantic_model attrs['__pydantic_model__'] = pydantic_model
attrs['__fields__'] = pydantic_model.__fields__ attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__)
attrs['__signature__'] = pydantic_model.__signature__ attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__)
attrs['__annotations__'] = pydantic_model.__annotations__ attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__)
attrs['__model_fields__'] = model_fields attrs['__model_fields__'] = model_fields
@ -78,9 +300,17 @@ class ModelMetaclass(type):
class Model(metaclass=ModelMetaclass): class Model(metaclass=ModelMetaclass):
__abstract__ = True __abstract__ = True
def __init__(self, *args, **kwargs) -> None: objects = QuerySet()
def __init__(self, **kwargs) -> None:
self._orm_id = uuid.uuid4().hex
self._orm_saved = False
self._orm_relationship_manager = RelationshipManager(self)
self._orm_observers = []
if "pk" in kwargs: if "pk" in kwargs:
kwargs[self.__pkname__] = kwargs.pop("pk") kwargs[self.__pkname__] = kwargs.pop("pk")
kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()}
self.values = self.__pydantic_model__(**kwargs) self.values = self.__pydantic_model__(**kwargs)
def __setattr__(self, key: str, value: Any) -> None: def __setattr__(self, key: str, value: Any) -> None:
@ -90,14 +320,19 @@ class Model(metaclass=ModelMetaclass):
value = json.dumps(value) value = json.dumps(value)
except TypeError: # pragma no cover except TypeError: # pragma no cover
pass pass
value = self.__model_fields__[key].expand_relationship(value, self)
setattr(self.values, key, value) setattr(self.values, key, value)
else: else:
super().__setattr__(key, value) super().__setattr__(key, value)
def __getattribute__(self, key: str) -> Any: def __getattribute__(self, key: str) -> Any:
if key != '__fields__' and key in self.__fields__: if key != '__fields__' and key in self.__fields__:
item = getattr(self.values, key) if key in self._orm_relationship_manager:
if self.is_conversion_to_json_needed(key) and isinstance(item, str): parent_item = self._orm_relationship_manager.get(key)
return parent_item
item = getattr(self.values, key, None)
if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str):
try: try:
item = json.loads(item) item = json.loads(item)
except TypeError: # pragma no cover except TypeError: # pragma no cover
@ -106,6 +341,45 @@ class Model(metaclass=ModelMetaclass):
return super().__getattribute__(key) return super().__getattribute__(key)
def __repr__(self): # pragma no cover
return self.values.__repr__()
# def attach(self, observer: 'Model'):
# if all([obs._orm_id != observer._orm_id for obs in self._orm_observers]):
# self._orm_observers.append(observer)
#
# def detach(self, observer: 'Model'):
# for ind, obs in enumerate(self._orm_observers):
# if obs._orm_id == observer._orm_id:
# del self._orm_observers[ind]
# break
#
def notify(self):
for obs in self._orm_observers: # pragma no cover
obs.orm_update(self)
def orm_update(self, subject: 'Model') -> None: # pragma no cover
print('should be updated here')
@classmethod
def from_row(cls, row, select_related: List = None) -> 'Model':
item = {}
select_related = select_related or []
for related in select_related:
if "__" in related:
first_part, remainder = related.split("__", 1)
model_cls = cls.__model_fields__[first_part].to
item[first_part] = model_cls.from_row(row, select_related=[remainder])
else:
model_cls = cls.__model_fields__[related].to
item[related] = model_cls.from_row(row)
for column in cls.__table__.columns:
if column.name not in item:
item[column.name] = row[column]
return cls(**item)
def is_conversion_to_json_needed(self, column_name: str) -> bool: def is_conversion_to_json_needed(self, column_name: str) -> bool:
return self.__model_fields__.get(column_name).__type__ == pydantic.Json return self.__model_fields__.get(column_name).__type__ == pydantic.Json
@ -136,17 +410,20 @@ class Model(metaclass=ModelMetaclass):
@classmethod @classmethod
def extract_related_names(cls) -> Set: def extract_related_names(cls) -> Set:
related_names = set() related_names = set()
# for name, field in cls.__fields__.items(): for name, field in cls.__fields__.items():
# if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel): if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel):
# related_names.add(name) related_names.add(name)
# elif field.sub_fields and any( # elif field.sub_fields and any(
# [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]): # [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]):
# related_names.add(name) # related_names.add(name)
return related_names return related_names
def extract_model_db_fields(self) -> Dict: def extract_model_db_fields(self) -> Dict:
self_fields = self.extract_own_model_fields() self_fields = self.extract_own_model_fields()
self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns} self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns}
for field in self.extract_related_names():
if getattr(self, field) is not None:
self_fields[field] = getattr(getattr(self, field), self.__model_fields__[field].to.__pkname__)
return self_fields return self_fields
async def save(self) -> int: async def save(self) -> int:
@ -157,6 +434,7 @@ class Model(metaclass=ModelMetaclass):
expr = expr.values(**self_fields) expr = expr.values(**self_fields)
item_id = await self.__database__.execute(expr) item_id = await self.__database__.execute(expr)
setattr(self, 'pk', item_id) setattr(self, 'pk', item_id)
self.notify()
return item_id return item_id
async def update(self, **kwargs: Any) -> int: async def update(self, **kwargs: Any) -> int:
@ -169,16 +447,19 @@ class Model(metaclass=ModelMetaclass):
expr = self.__table__.update().values(**self_fields).where( expr = self.__table__.update().values(**self_fields).where(
self.pk_column == getattr(self, self.__pkname__)) self.pk_column == getattr(self, self.__pkname__))
result = await self.__database__.execute(expr) result = await self.__database__.execute(expr)
self.notify()
return result return result
async def delete(self) -> int: async def delete(self) -> int:
expr = self.__table__.delete() expr = self.__table__.delete()
expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) expr = expr.where(self.pk_column == (getattr(self, self.__pkname__)))
result = await self.__database__.execute(expr) result = await self.__database__.execute(expr)
self.notify()
return result return result
async def load(self) -> 'Model': async def load(self) -> 'Model':
expr = self.__table__.select().where(self.pk_column == self.pk) expr = self.__table__.select().where(self.pk_column == self.pk)
row = await self.__database__.fetch_one(expr) row = await self.__database__.fetch_one(expr)
self.from_dict(dict(row)) self.from_dict(dict(row))
self.notify()
return self return self

45
orm/relations.py Normal file
View File

@ -0,0 +1,45 @@
from typing import Dict, Union, List
from sqlalchemy import text
class Relationship:
def __init__(self, name: str, parent: 'Model', child: 'Model', fk_side: str = 'child'):
self.fk_side = fk_side
self.child = child
self.parent = parent
self.name = name
class RelationshipManager:
def __init__(self, model: 'Model'):
self._orm_id: str = model._orm_id
self._relations: Dict[str, Union[Relationship, List[Relationship]]] = dict()
def __contains__(self, item):
return item in self._relations
def add_related(self, relation: Relationship):
if relation.fk_side == 'child' and relation.parent._orm_id == self._orm_id:
new_relation = Relationship(name=relation.parent.__class__.__name__.lower(),
child=relation.parent,
parent=relation.child,
fk_side='parent')
relation.child._orm_relationship_manager.add(new_relation)
def add(self, relation: Relationship):
if relation.name in self._relations:
self._relations[relation.name].append(relation)
else:
self._relations[relation.name] = [relation]
self.add_related(relation)
def get(self, name: str):
for rel, relations in self._relations.items():
if rel == name:
if relations and relations[0].fk_side == 'parent':
return relations[0].child
else:
return [rela.child for rela in relations]

231
tests/test_foreign_keys.py Normal file
View File

@ -0,0 +1,231 @@
import databases
import pytest
import sqlalchemy
import orm
from orm.exceptions import NoMatch, MultipleMatches
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Album(orm.Model):
__tablename__ = "album"
__metadata__ = metadata
__database__ = database
id = orm.Integer(primary_key=True)
name = orm.String(length=100)
class Track(orm.Model):
__tablename__ = "track"
__metadata__ = metadata
__database__ = database
id = orm.Integer(primary_key=True)
album = orm.ForeignKey(Album)
title = orm.String(length=100)
position = orm.Integer()
class Organisation(orm.Model):
__tablename__ = "org"
__metadata__ = metadata
__database__ = database
id = orm.Integer(primary_key=True)
ident = orm.String(length=100)
class Team(orm.Model):
__tablename__ = "team"
__metadata__ = metadata
__database__ = database
id = orm.Integer(primary_key=True)
org = orm.ForeignKey(Organisation)
name = orm.String(length=100)
class Member(orm.Model):
__tablename__ = "member"
__metadata__ = metadata
__database__ = database
id = orm.Integer(primary_key=True)
team = orm.ForeignKey(Team)
email = orm.String(length=100)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_model_crud():
async with database:
album = Album(name="Malibu")
await album.save()
track1 = Track(album=album, title="The Bird", position=1)
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
track3 = Track(album=album, title="The Waters", position=3)
await track1.save()
await track2.save()
await track3.save()
track = await Track.objects.get(title="The Bird")
assert track.album.pk == album.pk
assert track.album.name is None
await track.album.load()
assert track.album.name == "Malibu"
assert len(album.tracks) == 3
assert album.tracks[1].title == "Heart don't stand a chance"
album1 = await Album.objects.get(name='Malibu')
assert album1.pk == 1
assert album1.tracks is None
@pytest.mark.asyncio
async def test_select_related():
async with database:
album = Album(name="Malibu")
await album.save()
track1 = Track(album=album, title="The Bird", position=1)
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
track3 = Track(album=album, title="The Waters", position=3)
await track1.save()
await track2.save()
await track3.save()
fantasies = Album(name="Fantasies")
await fantasies.save()
track4 = Track(album=fantasies, title="Help I'm Alive", position=1)
track5 = Track(album=fantasies, title="Sick Muse", position=2)
track6 = Track(album=fantasies, title="Satellite Mind", position=3)
await track4.save()
await track5.save()
await track6.save()
track = await Track.objects.select_related("album").get(title="The Bird")
assert track.album.name == "Malibu"
tracks = await Track.objects.select_related("album").all()
assert len(tracks) == 6
@pytest.mark.asyncio
async def test_fk_filter():
async with database:
malibu = Album(name="Malibu%")
await malibu.save()
await Track.objects.create(album=malibu, title="The Bird", position=1)
await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2)
await Track.objects.create(album=malibu, title="The Waters", position=3)
fantasies = await Album.objects.create(name="Fantasies")
await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1)
await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
await Track.objects.create(album=fantasies, title="Satellite Mind", position=3)
tracks = await Track.objects.select_related("album").filter(album__name="Fantasies").all()
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = await Track.objects.select_related("album").filter(album__name__icontains="fan").all()
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = await Track.objects.filter(album__name__contains="fan").all()
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Fantasies"
tracks = await Track.objects.filter(album__name__contains="Malibu%").all()
assert len(tracks) == 3
tracks = await Track.objects.filter(album=malibu).select_related("album").all()
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Malibu%"
tracks = await Track.objects.select_related("album").all(album=malibu)
assert len(tracks) == 3
for track in tracks:
assert track.album.name == "Malibu%"
@pytest.mark.asyncio
async def test_multiple_fk():
async with database:
acme = await Organisation.objects.create(ident="ACME Ltd")
red_team = await Team.objects.create(org=acme, name="Red Team")
blue_team = await Team.objects.create(org=acme, name="Blue Team")
await Member.objects.create(team=red_team, email="a@example.org")
await Member.objects.create(team=red_team, email="b@example.org")
await Member.objects.create(team=blue_team, email="c@example.org")
await Member.objects.create(team=blue_team, email="d@example.org")
other = await Organisation.objects.create(ident="Other ltd")
team = await Team.objects.create(org=other, name="Green Team")
await Member.objects.create(team=team, email="e@example.org")
members = await Member.objects.select_related('team__org').filter(team__org__ident="ACME Ltd").all()
assert len(members) == 4
for member in members:
assert member.team.org.ident == "ACME Ltd"
@pytest.mark.asyncio
async def test_pk_filter():
async with database:
fantasies = await Album.objects.create(name="Test")
await Track.objects.create(album=fantasies, title="Test1", position=1)
await Track.objects.create(album=fantasies, title="Test2", position=2)
await Track.objects.create(album=fantasies, title="Test3", position=3)
tracks = await Track.objects.select_related("album").filter(pk=1).all()
assert len(tracks) == 1
tracks = await Track.objects.select_related("album").filter(position=2, album__name='Test').all()
assert len(tracks) == 1
@pytest.mark.asyncio
async def test_limit_and_offset():
async with database:
fantasies = await Album.objects.create(name="Limitless")
await Track.objects.create(id=None, album=fantasies, title="Sample", position=1)
await Track.objects.create(album=fantasies, title="Sample2", position=2)
await Track.objects.create(album=fantasies, title="Sample3", position=3)
tracks = await Track.objects.limit(1).all()
assert len(tracks) == 1
assert tracks[0].title == "Sample"
tracks = await Track.objects.limit(1).offset(1).all()
assert len(tracks) == 1
assert tracks[0].title == "Sample2"
@pytest.mark.asyncio
async def test_get_exceptions():
async with database:
fantasies = await Album.objects.create(name="Test")
with pytest.raises(NoMatch):
await Album.objects.get(name="Test2")
await Track.objects.create(album=fantasies, title="Test1", position=1)
await Track.objects.create(album=fantasies, title="Test2", position=2)
await Track.objects.create(album=fantasies, title="Test3", position=3)
with pytest.raises(MultipleMatches):
await Track.objects.select_related("album").get(album=fantasies)

0
tests/test_models.py Normal file
View File