From 345fd227d1076958426cd4a32c327de601350834 Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 4 Aug 2020 18:44:17 +0200 Subject: [PATCH] sloppy work on passing all of the test and reimplementing most of the features from encode --- .coverage | Bin 53248 -> 53248 bytes README.md | 4 + orm/__init__.py | 4 +- orm/exceptions.py | 6 +- orm/fields.py | 63 +++++++- orm/models.py | 319 ++++++++++++++++++++++++++++++++++--- orm/relations.py | 45 ++++++ tests/test_foreign_keys.py | 231 +++++++++++++++++++++++++++ tests/test_models.py | 0 9 files changed, 648 insertions(+), 24 deletions(-) create mode 100644 orm/relations.py create mode 100644 tests/test_foreign_keys.py create mode 100644 tests/test_models.py diff --git a/.coverage b/.coverage index 0c9f6db6adb3236f1925954b21a689c4260c5301..ba7b5f4a6b067d03e366db7864d285fedeb1619e 100644 GIT binary patch delta 742 zcmZ9K(MwZN9LMiGyLY?y?w+5fMNL5|aJ5h?CEa>hJ=jAJE|MNZ(j+z+&aF1hOltNQ z2;B!Cl;(qxPZEtaB5PE>I8cch6n35Zuu9rSE*xW%+quh$^>ohf`~7{-`TXEFI)z53 z(DO45TvN^27IRah`CN0r+z=_P(g_wG9k7T)m1>t9fj#rB5-4>L9kh%wu!*3^?4bLWuqw(u=LuFyY>$U<3@% zE^OizoQF?%2A*OsHe;S}zZb+iT_J?c>vDfVgk7BqNJEvLCA)W5I%Ps^T0okf$; z9iY{Ua+iVD7*U5}|4Ym4h4g{e?r@Eshe`^vum$T7h9Jy<4Q_x5PCzAr-NinJSF2de zmsD`%G;(*xWq)=u9gD|J7K=r5oF(dIHW^jm1TX|zdGdKo-|or%U|<-FWzQ4 zV(qG%SGQxcl9b&2p7Nzq8N@T>t)t8pzsnbD(9)zhMBDB5YBjlCRJ#H`DUp;ivbebB zJU_8$4}5X%Or(;tJ}pvO(UzR4f~a_{}v{I+$0YJMasB!Z+9iiNvF@1Pd@v!qWhVtfhd+HVTL=pCCq;>;3>4 Cc?atN delta 483 zcmZozz}&Ead4qvIm$8D8ft7)&m5JeIEB#XfVw`*h4E$I4m+)8d2l1=%{o}jDH2nT;X1OHF{SNsq7kMZx}pUXd$zaD65I)A+g2MZ&o7Dw6i|MP#&`F#J`GfR7W zdv-Q+l*d$nh zEF+e@oBQ9D{kNaF_hydS?6YZ`Z@vKrA4m=JxBra1OhA?l)2_37?CtNrXJut%;pCkB zub&O*x>F4NzxluLzvX|%f1m#b|3&^&K-X>J-+XGmm4X;2|7DQ$H~x?OZ}^|{KjOa& PRB{=pU?>0N%jb0g#WkQw diff --git a/README.md b/README.md index 6145f7f..5d99739 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,7 @@ The following keyword arguments are supported on all field types. * `primary_key` * `nullable` * `default` +* `server_default` * `index` * `unique` @@ -165,6 +166,9 @@ All fields are required unless one of the following is set: * `nullable` - Creates a nullable column. Sets the default to `None`. * `default` - Set a default value for the field. +* `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`). +* `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column. +Autoincrement is set by default on int primary keys. Available Model Fields: * `orm.String(length)` diff --git a/orm/__init__.py b/orm/__init__.py index 5270355..9c652ae 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -1,4 +1,5 @@ -from orm.fields import Integer, BigInteger, Boolean, Time, Text, String, JSON, DateTime, Date, Decimal, Float +from orm.fields import Integer, BigInteger, Boolean, Time, Text, String, JSON, DateTime, Date, Decimal, Float, \ + ForeignKey from orm.models import Model __all__ = [ @@ -13,5 +14,6 @@ __all__ = [ "Date", "Decimal", "Float", + "ForeignKey", "Model" ] diff --git a/orm/exceptions.py b/orm/exceptions.py index 7321d99..1a8c6d0 100644 --- a/orm/exceptions.py +++ b/orm/exceptions.py @@ -10,7 +10,11 @@ class ModelNotSet(AsyncOrmException): pass -class MultipleResults(AsyncOrmException): +class NoMatch(AsyncOrmException): + pass + + +class MultipleMatches(AsyncOrmException): pass diff --git a/orm/fields.py b/orm/fields.py index b42a901..4393bd1 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -6,6 +6,7 @@ import pydantic import sqlalchemy from orm.exceptions import ModelDefinitionError +from orm.relations import Relationship class BaseField: @@ -24,7 +25,7 @@ class BaseField: self.name = name self.primary_key = kwargs.pop('primary_key', False) - self.autoincrement = kwargs.pop('autoincrement', self.primary_key) + self.autoincrement = kwargs.pop('autoincrement', self.primary_key and self.__type__ == int) self.nullable = kwargs.pop('nullable', not self.primary_key) self.default = kwargs.pop('default', None) @@ -37,11 +38,30 @@ class BaseField: if self.pydantic_only and self.primary_key: raise ModelDefinitionError('Primary key column cannot be pydantic only.') + @property + def is_required(self): + return not self.nullable and not self.has_default and not self.is_auto_primary_key + + @property + def default_value(self): + default = self.default if self.default is not None else self.server_default + return default() if callable(default) else default + + @property + def has_default(self): + return self.default is not None or self.server_default is not None + + @property + def is_auto_primary_key(self): + if self.primary_key: + return self.autoincrement + return False + def get_column(self, name: str = None) -> sqlalchemy.Column: - name = self.name or name + self.name = self.name or name constraints = self.get_constraints() return sqlalchemy.Column( - name, + self.name, self.get_column_type(), *constraints, primary_key=self.primary_key, @@ -59,6 +79,9 @@ class BaseField: def get_constraints(self) -> Optional[List]: return [] + def expand_relationship(self, value, parent): + return value + class String(BaseField): __type__ = str @@ -147,3 +170,37 @@ class Decimal(BaseField): def get_column_type(self): return sqlalchemy.DECIMAL(self.length, self.precision) + + +class ForeignKey(BaseField): + def __init__(self, to, related_name: str = None, nullable: bool = False): + super().__init__(nullable=nullable) + self.related_name = related_name + self.to = to + + @property + def __type__(self): + return self.to.__pydantic_model__ + + def get_constraints(self): + fk_string = self.to.__tablename__ + "." + self.to.__pkname__ + return [sqlalchemy.schema.ForeignKey(fk_string)] + + def get_column_type(self): + to_column = self.to.__model_fields__[self.to.__pkname__] + return to_column.get_column_type() + + def expand_relationship(self, value, child): + if isinstance(value, self.to): + model = value + else: + model = self.to(**{self.to.__pkname__: value}) + + child_model_name = self.related_name or child.__class__.__name__.lower() + 's' + model._orm_relationship_manager.add( + Relationship(name=child_model_name, child=child, parent=model, fk_side='child')) + model.__fields__[child_model_name] = pydantic.fields.ModelField(name=child_model_name, + type_=child.__pydantic_model__, + model_config=child.__pydantic_model__.__config__, + class_validators=child.__pydantic_model__.__validators__) + return model diff --git a/orm/models.py b/orm/models.py index 096415e..e19474c 100644 --- a/orm/models.py +++ b/orm/models.py @@ -1,26 +1,250 @@ +import copy +import inspect import json -from typing import Any, Type +import uuid +from typing import Any, List, Type from typing import Set, Dict import pydantic import sqlalchemy from pydantic import BaseConfig, create_model -from orm.exceptions import ModelDefinitionError +from orm.exceptions import ModelDefinitionError, MultipleMatches, NoMatch from orm.fields import BaseField +from orm.relations import RelationshipManager def parse_pydantic_field_from_model_fields(object_dict: dict): pydantic_fields = {field_name: ( base_field.__type__, - ... if (not base_field.nullable and not base_field.default and not base_field.primary_key) else ( - base_field.default() if callable(base_field.default) else base_field.default) + ... if base_field.is_required else base_field.default_value ) for field_name, base_field in object_dict.items() if isinstance(base_field, BaseField)} return pydantic_fields +FILTER_OPERATORS = { + "exact": "__eq__", + "iexact": "ilike", + "contains": "like", + "icontains": "ilike", + "in": "in_", + "gt": "__gt__", + "gte": "__ge__", + "lt": "__lt__", + "lte": "__le__", +} + + +class QuerySet: + ESCAPE_CHARACTERS = ['%', '_'] + + def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None, + limit_count: int = None, offset: int = None): + self.model_cls = model_cls + self.filter_clauses = [] if filter_clauses is None else filter_clauses + self._select_related = [] if select_related is None else select_related + self.limit_count = limit_count + self.query_offset = offset + + def __get__(self, instance, owner): + return self.__class__(model_cls=owner) + + @property + def database(self): + return self.model_cls.__database__ + + @property + def table(self): + return self.model_cls.__table__ + + def build_select_expression(self): + tables = [self.table] + select_from = self.table + + for item in self._select_related: + model_cls = self.model_cls + select_from = self.table + for part in item.split("__"): + model_cls = model_cls.__model_fields__[part].to + select_from = sqlalchemy.sql.join(select_from, model_cls.__table__) + tables.append(model_cls.__table__) + + expr = sqlalchemy.sql.select(tables) + expr = expr.select_from(select_from) + + if self.filter_clauses: + if len(self.filter_clauses) == 1: + clause = self.filter_clauses[0] + else: + clause = sqlalchemy.sql.and_(*self.filter_clauses) + expr = expr.where(clause) + + if self.limit_count: + expr = expr.limit(self.limit_count) + + if self.query_offset: + expr = expr.offset(self.query_offset) + + # print(expr.compile(compile_kwargs={"literal_binds": True})) + return expr + + def filter(self, **kwargs): + filter_clauses = self.filter_clauses + select_related = list(self._select_related) + + if kwargs.get("pk"): + pk_name = self.model_cls.__pkname__ + kwargs[pk_name] = kwargs.pop("pk") + + for key, value in kwargs.items(): + if "__" in key: + parts = key.split("__") + + # Determine if we should treat the final part as a + # filter operator or as a related field. + if parts[-1] in FILTER_OPERATORS: + op = parts[-1] + field_name = parts[-2] + related_parts = parts[:-2] + else: + op = "exact" + field_name = parts[-1] + related_parts = parts[:-1] + + model_cls = self.model_cls + if related_parts: + # Add any implied select_related + related_str = "__".join(related_parts) + if related_str not in select_related: + select_related.append(related_str) + + # Walk the relationships to the actual model class + # against which the comparison is being made. + for part in related_parts: + model_cls = model_cls.__model_fields__[part].to + + column = model_cls.__table__.columns[field_name] + + else: + op = "exact" + column = self.table.columns[key] + + # Map the operation code onto SQLAlchemy's ColumnElement + # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement + op_attr = FILTER_OPERATORS[op] + has_escaped_character = False + + if op in ["contains", "icontains"]: + has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS + if c in value) + if has_escaped_character: + # enable escape modifier + for char in self.ESCAPE_CHARACTERS: + value = value.replace(char, f'\\{char}') + value = f"%{value}%" + + if isinstance(value, Model): + value = value.pk + + clause = getattr(column, op_attr)(value) + clause.modifiers['escape'] = '\\' if has_escaped_character else None + filter_clauses.append(clause) + + return self.__class__( + model_cls=self.model_cls, + filter_clauses=filter_clauses, + select_related=select_related, + limit_count=self.limit_count, + offset=self.query_offset + ) + + def select_related(self, related): + if not isinstance(related, (list, tuple)): + related = [related] + + related = list(self._select_related) + related + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=related, + limit_count=self.limit_count, + offset=self.query_offset + ) + + # async def exists(self) -> bool: + # expr = self.build_select_expression() + # expr = sqlalchemy.exists(expr).select() + # return await self.database.fetch_val(expr) + + def limit(self, limit_count: int): + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + limit_count=limit_count, + offset=self.query_offset + ) + + def offset(self, offset: int): + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + limit_count=self.limit_count, + offset=offset + ) + + async def get(self, **kwargs): + if kwargs: + return await self.filter(**kwargs).get() + + expr = self.build_select_expression().limit(2) + rows = await self.database.fetch_all(expr) + + if not rows: + raise NoMatch() + if len(rows) > 1: + raise MultipleMatches() + return self.model_cls.from_row(rows[0], select_related=self._select_related) + + async def all(self, **kwargs): + if kwargs: + return await self.filter(**kwargs).all() + + expr = self.build_select_expression() + rows = await self.database.fetch_all(expr) + return [ + self.model_cls.from_row(row, select_related=self._select_related) + for row in rows + ] + + async def create(self, **kwargs): + + new_kwargs = dict(**kwargs) + + # Remove primary key when None to prevent not null constraint in postgresql. + pkname = self.model_cls.__pkname__ + pk = self.model_cls.__model_fields__[pkname] + if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement): + del new_kwargs[pkname] + + # substitute related models with their pk + for field in self.model_cls.extract_related_names(): + if field in new_kwargs and new_kwargs.get(field) is not None: + new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__) + + # Build the insert expression. + expr = self.table.insert() + expr = expr.values(**new_kwargs) + + # Execute the insert, and return a new model instance. + instance = self.model_cls(**kwargs) + instance.pk = await self.database.execute(expr) + return instance + + class ModelMetaclass(type): def __new__( mcs: type, name: str, bases: Any, attrs: dict @@ -52,9 +276,7 @@ class ModelMetaclass(type): attrs['__pkname__'] = pkname if not pkname: - raise ModelDefinitionError( - 'Table has to have a primary key.' - ) + raise ModelDefinitionError('Table has to have a primary key.') # pydantic model creation pydantic_fields = parse_pydantic_field_from_model_fields(attrs) @@ -62,9 +284,9 @@ class ModelMetaclass(type): pydantic_model = create_model(name, __config__=config, **pydantic_fields) attrs['__pydantic_fields__'] = pydantic_fields attrs['__pydantic_model__'] = pydantic_model - attrs['__fields__'] = pydantic_model.__fields__ - attrs['__signature__'] = pydantic_model.__signature__ - attrs['__annotations__'] = pydantic_model.__annotations__ + attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__) + attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__) + attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__) attrs['__model_fields__'] = model_fields @@ -78,9 +300,17 @@ class ModelMetaclass(type): class Model(metaclass=ModelMetaclass): __abstract__ = True - def __init__(self, *args, **kwargs) -> None: + objects = QuerySet() + + def __init__(self, **kwargs) -> None: + self._orm_id = uuid.uuid4().hex + self._orm_saved = False + self._orm_relationship_manager = RelationshipManager(self) + self._orm_observers = [] + if "pk" in kwargs: kwargs[self.__pkname__] = kwargs.pop("pk") + kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()} self.values = self.__pydantic_model__(**kwargs) def __setattr__(self, key: str, value: Any) -> None: @@ -90,14 +320,19 @@ class Model(metaclass=ModelMetaclass): value = json.dumps(value) except TypeError: # pragma no cover pass + value = self.__model_fields__[key].expand_relationship(value, self) setattr(self.values, key, value) else: super().__setattr__(key, value) def __getattribute__(self, key: str) -> Any: if key != '__fields__' and key in self.__fields__: - item = getattr(self.values, key) - if self.is_conversion_to_json_needed(key) and isinstance(item, str): + if key in self._orm_relationship_manager: + parent_item = self._orm_relationship_manager.get(key) + return parent_item + + item = getattr(self.values, key, None) + if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str): try: item = json.loads(item) except TypeError: # pragma no cover @@ -106,6 +341,45 @@ class Model(metaclass=ModelMetaclass): return super().__getattribute__(key) + def __repr__(self): # pragma no cover + return self.values.__repr__() + + # def attach(self, observer: 'Model'): + # if all([obs._orm_id != observer._orm_id for obs in self._orm_observers]): + # self._orm_observers.append(observer) + # + # def detach(self, observer: 'Model'): + # for ind, obs in enumerate(self._orm_observers): + # if obs._orm_id == observer._orm_id: + # del self._orm_observers[ind] + # break + # + def notify(self): + for obs in self._orm_observers: # pragma no cover + obs.orm_update(self) + + def orm_update(self, subject: 'Model') -> None: # pragma no cover + print('should be updated here') + + @classmethod + def from_row(cls, row, select_related: List = None) -> 'Model': + item = {} + select_related = select_related or [] + for related in select_related: + if "__" in related: + first_part, remainder = related.split("__", 1) + model_cls = cls.__model_fields__[first_part].to + item[first_part] = model_cls.from_row(row, select_related=[remainder]) + else: + model_cls = cls.__model_fields__[related].to + item[related] = model_cls.from_row(row) + + for column in cls.__table__.columns: + if column.name not in item: + item[column.name] = row[column] + + return cls(**item) + def is_conversion_to_json_needed(self, column_name: str) -> bool: return self.__model_fields__.get(column_name).__type__ == pydantic.Json @@ -136,17 +410,20 @@ class Model(metaclass=ModelMetaclass): @classmethod def extract_related_names(cls) -> Set: related_names = set() - # for name, field in cls.__fields__.items(): - # if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel): - # related_names.add(name) - # elif field.sub_fields and any( - # [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]): - # related_names.add(name) + for name, field in cls.__fields__.items(): + if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel): + related_names.add(name) + # elif field.sub_fields and any( + # [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]): + # related_names.add(name) return related_names def extract_model_db_fields(self) -> Dict: self_fields = self.extract_own_model_fields() self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns} + for field in self.extract_related_names(): + if getattr(self, field) is not None: + self_fields[field] = getattr(getattr(self, field), self.__model_fields__[field].to.__pkname__) return self_fields async def save(self) -> int: @@ -157,6 +434,7 @@ class Model(metaclass=ModelMetaclass): expr = expr.values(**self_fields) item_id = await self.__database__.execute(expr) setattr(self, 'pk', item_id) + self.notify() return item_id async def update(self, **kwargs: Any) -> int: @@ -169,16 +447,19 @@ class Model(metaclass=ModelMetaclass): expr = self.__table__.update().values(**self_fields).where( self.pk_column == getattr(self, self.__pkname__)) result = await self.__database__.execute(expr) + self.notify() return result async def delete(self) -> int: expr = self.__table__.delete() expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) result = await self.__database__.execute(expr) + self.notify() return result async def load(self) -> 'Model': expr = self.__table__.select().where(self.pk_column == self.pk) row = await self.__database__.fetch_one(expr) self.from_dict(dict(row)) + self.notify() return self diff --git a/orm/relations.py b/orm/relations.py new file mode 100644 index 0000000..1bd70cf --- /dev/null +++ b/orm/relations.py @@ -0,0 +1,45 @@ +from typing import Dict, Union, List + +from sqlalchemy import text + + +class Relationship: + + def __init__(self, name: str, parent: 'Model', child: 'Model', fk_side: str = 'child'): + self.fk_side = fk_side + self.child = child + self.parent = parent + self.name = name + + +class RelationshipManager: + + def __init__(self, model: 'Model'): + self._orm_id: str = model._orm_id + self._relations: Dict[str, Union[Relationship, List[Relationship]]] = dict() + + def __contains__(self, item): + return item in self._relations + + def add_related(self, relation: Relationship): + if relation.fk_side == 'child' and relation.parent._orm_id == self._orm_id: + new_relation = Relationship(name=relation.parent.__class__.__name__.lower(), + child=relation.parent, + parent=relation.child, + fk_side='parent') + relation.child._orm_relationship_manager.add(new_relation) + + def add(self, relation: Relationship): + if relation.name in self._relations: + self._relations[relation.name].append(relation) + else: + self._relations[relation.name] = [relation] + self.add_related(relation) + + def get(self, name: str): + for rel, relations in self._relations.items(): + if rel == name: + if relations and relations[0].fk_side == 'parent': + return relations[0].child + else: + return [rela.child for rela in relations] diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py new file mode 100644 index 0000000..d3d95bf --- /dev/null +++ b/tests/test_foreign_keys.py @@ -0,0 +1,231 @@ +import databases +import pytest +import sqlalchemy + +import orm +from orm.exceptions import NoMatch, MultipleMatches +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Album(orm.Model): + __tablename__ = "album" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Track(orm.Model): + __tablename__ = "track" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + album = orm.ForeignKey(Album) + title = orm.String(length=100) + position = orm.Integer() + + +class Organisation(orm.Model): + __tablename__ = "org" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + ident = orm.String(length=100) + + +class Team(orm.Model): + __tablename__ = "team" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + org = orm.ForeignKey(Organisation) + name = orm.String(length=100) + + +class Member(orm.Model): + __tablename__ = "member" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + team = orm.ForeignKey(Team) + email = orm.String(length=100) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_model_crud(): + async with database: + album = Album(name="Malibu") + await album.save() + track1 = Track(album=album, title="The Bird", position=1) + track2 = Track(album=album, title="Heart don't stand a chance", position=2) + track3 = Track(album=album, title="The Waters", position=3) + await track1.save() + await track2.save() + await track3.save() + + track = await Track.objects.get(title="The Bird") + assert track.album.pk == album.pk + assert track.album.name is None + await track.album.load() + assert track.album.name == "Malibu" + + assert len(album.tracks) == 3 + assert album.tracks[1].title == "Heart don't stand a chance" + + album1 = await Album.objects.get(name='Malibu') + assert album1.pk == 1 + assert album1.tracks is None + + +@pytest.mark.asyncio +async def test_select_related(): + async with database: + album = Album(name="Malibu") + await album.save() + track1 = Track(album=album, title="The Bird", position=1) + track2 = Track(album=album, title="Heart don't stand a chance", position=2) + track3 = Track(album=album, title="The Waters", position=3) + await track1.save() + await track2.save() + await track3.save() + + fantasies = Album(name="Fantasies") + await fantasies.save() + track4 = Track(album=fantasies, title="Help I'm Alive", position=1) + track5 = Track(album=fantasies, title="Sick Muse", position=2) + track6 = Track(album=fantasies, title="Satellite Mind", position=3) + await track4.save() + await track5.save() + await track6.save() + + track = await Track.objects.select_related("album").get(title="The Bird") + assert track.album.name == "Malibu" + + tracks = await Track.objects.select_related("album").all() + assert len(tracks) == 6 + + +@pytest.mark.asyncio +async def test_fk_filter(): + async with database: + malibu = Album(name="Malibu%") + await malibu.save() + await Track.objects.create(album=malibu, title="The Bird", position=1) + await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2) + await Track.objects.create(album=malibu, title="The Waters", position=3) + + fantasies = await Album.objects.create(name="Fantasies") + await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) + + tracks = await Track.objects.select_related("album").filter(album__name="Fantasies").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.objects.select_related("album").filter(album__name__icontains="fan").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.objects.filter(album__name__contains="fan").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.objects.filter(album__name__contains="Malibu%").all() + assert len(tracks) == 3 + + tracks = await Track.objects.filter(album=malibu).select_related("album").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Malibu%" + + tracks = await Track.objects.select_related("album").all(album=malibu) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Malibu%" + + +@pytest.mark.asyncio +async def test_multiple_fk(): + async with database: + acme = await Organisation.objects.create(ident="ACME Ltd") + red_team = await Team.objects.create(org=acme, name="Red Team") + blue_team = await Team.objects.create(org=acme, name="Blue Team") + await Member.objects.create(team=red_team, email="a@example.org") + await Member.objects.create(team=red_team, email="b@example.org") + await Member.objects.create(team=blue_team, email="c@example.org") + await Member.objects.create(team=blue_team, email="d@example.org") + + other = await Organisation.objects.create(ident="Other ltd") + team = await Team.objects.create(org=other, name="Green Team") + await Member.objects.create(team=team, email="e@example.org") + + members = await Member.objects.select_related('team__org').filter(team__org__ident="ACME Ltd").all() + assert len(members) == 4 + for member in members: + assert member.team.org.ident == "ACME Ltd" + + +@pytest.mark.asyncio +async def test_pk_filter(): + async with database: + fantasies = await Album.objects.create(name="Test") + await Track.objects.create(album=fantasies, title="Test1", position=1) + await Track.objects.create(album=fantasies, title="Test2", position=2) + await Track.objects.create(album=fantasies, title="Test3", position=3) + tracks = await Track.objects.select_related("album").filter(pk=1).all() + assert len(tracks) == 1 + + tracks = await Track.objects.select_related("album").filter(position=2, album__name='Test').all() + assert len(tracks) == 1 + + +@pytest.mark.asyncio +async def test_limit_and_offset(): + async with database: + fantasies = await Album.objects.create(name="Limitless") + await Track.objects.create(id=None, album=fantasies, title="Sample", position=1) + await Track.objects.create(album=fantasies, title="Sample2", position=2) + await Track.objects.create(album=fantasies, title="Sample3", position=3) + + tracks = await Track.objects.limit(1).all() + assert len(tracks) == 1 + assert tracks[0].title == "Sample" + + tracks = await Track.objects.limit(1).offset(1).all() + assert len(tracks) == 1 + assert tracks[0].title == "Sample2" + + +@pytest.mark.asyncio +async def test_get_exceptions(): + async with database: + fantasies = await Album.objects.create(name="Test") + + with pytest.raises(NoMatch): + await Album.objects.get(name="Test2") + + await Track.objects.create(album=fantasies, title="Test1", position=1) + await Track.objects.create(album=fantasies, title="Test2", position=2) + await Track.objects.create(album=fantasies, title="Test3", position=3) + with pytest.raises(MultipleMatches): + await Track.objects.select_related("album").get(album=fantasies) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..e69de29