sloppy work on passing all of the test and reimplementing most of the features from encode
This commit is contained in:
@ -158,6 +158,7 @@ The following keyword arguments are supported on all field types.
|
||||
* `primary_key`
|
||||
* `nullable`
|
||||
* `default`
|
||||
* `server_default`
|
||||
* `index`
|
||||
* `unique`
|
||||
|
||||
@ -165,6 +166,9 @@ All fields are required unless one of the following is set:
|
||||
|
||||
* `nullable` - Creates a nullable column. Sets the default to `None`.
|
||||
* `default` - Set a default value for the field.
|
||||
* `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`).
|
||||
* `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column.
|
||||
Autoincrement is set by default on int primary keys.
|
||||
|
||||
Available Model Fields:
|
||||
* `orm.String(length)`
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from orm.fields import Integer, BigInteger, Boolean, Time, Text, String, JSON, DateTime, Date, Decimal, Float
|
||||
from orm.fields import Integer, BigInteger, Boolean, Time, Text, String, JSON, DateTime, Date, Decimal, Float, \
|
||||
ForeignKey
|
||||
from orm.models import Model
|
||||
|
||||
__all__ = [
|
||||
@ -13,5 +14,6 @@ __all__ = [
|
||||
"Date",
|
||||
"Decimal",
|
||||
"Float",
|
||||
"ForeignKey",
|
||||
"Model"
|
||||
]
|
||||
|
||||
@ -10,7 +10,11 @@ class ModelNotSet(AsyncOrmException):
|
||||
pass
|
||||
|
||||
|
||||
class MultipleResults(AsyncOrmException):
|
||||
class NoMatch(AsyncOrmException):
|
||||
pass
|
||||
|
||||
|
||||
class MultipleMatches(AsyncOrmException):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ import pydantic
|
||||
import sqlalchemy
|
||||
|
||||
from orm.exceptions import ModelDefinitionError
|
||||
from orm.relations import Relationship
|
||||
|
||||
|
||||
class BaseField:
|
||||
@ -24,7 +25,7 @@ class BaseField:
|
||||
|
||||
self.name = name
|
||||
self.primary_key = kwargs.pop('primary_key', False)
|
||||
self.autoincrement = kwargs.pop('autoincrement', self.primary_key)
|
||||
self.autoincrement = kwargs.pop('autoincrement', self.primary_key and self.__type__ == int)
|
||||
|
||||
self.nullable = kwargs.pop('nullable', not self.primary_key)
|
||||
self.default = kwargs.pop('default', None)
|
||||
@ -37,11 +38,30 @@ class BaseField:
|
||||
if self.pydantic_only and self.primary_key:
|
||||
raise ModelDefinitionError('Primary key column cannot be pydantic only.')
|
||||
|
||||
@property
|
||||
def is_required(self):
|
||||
return not self.nullable and not self.has_default and not self.is_auto_primary_key
|
||||
|
||||
@property
|
||||
def default_value(self):
|
||||
default = self.default if self.default is not None else self.server_default
|
||||
return default() if callable(default) else default
|
||||
|
||||
@property
|
||||
def has_default(self):
|
||||
return self.default is not None or self.server_default is not None
|
||||
|
||||
@property
|
||||
def is_auto_primary_key(self):
|
||||
if self.primary_key:
|
||||
return self.autoincrement
|
||||
return False
|
||||
|
||||
def get_column(self, name: str = None) -> sqlalchemy.Column:
|
||||
name = self.name or name
|
||||
self.name = self.name or name
|
||||
constraints = self.get_constraints()
|
||||
return sqlalchemy.Column(
|
||||
name,
|
||||
self.name,
|
||||
self.get_column_type(),
|
||||
*constraints,
|
||||
primary_key=self.primary_key,
|
||||
@ -59,6 +79,9 @@ class BaseField:
|
||||
def get_constraints(self) -> Optional[List]:
|
||||
return []
|
||||
|
||||
def expand_relationship(self, value, parent):
|
||||
return value
|
||||
|
||||
|
||||
class String(BaseField):
|
||||
__type__ = str
|
||||
@ -147,3 +170,37 @@ class Decimal(BaseField):
|
||||
|
||||
def get_column_type(self):
|
||||
return sqlalchemy.DECIMAL(self.length, self.precision)
|
||||
|
||||
|
||||
class ForeignKey(BaseField):
|
||||
def __init__(self, to, related_name: str = None, nullable: bool = False):
|
||||
super().__init__(nullable=nullable)
|
||||
self.related_name = related_name
|
||||
self.to = to
|
||||
|
||||
@property
|
||||
def __type__(self):
|
||||
return self.to.__pydantic_model__
|
||||
|
||||
def get_constraints(self):
|
||||
fk_string = self.to.__tablename__ + "." + self.to.__pkname__
|
||||
return [sqlalchemy.schema.ForeignKey(fk_string)]
|
||||
|
||||
def get_column_type(self):
|
||||
to_column = self.to.__model_fields__[self.to.__pkname__]
|
||||
return to_column.get_column_type()
|
||||
|
||||
def expand_relationship(self, value, child):
|
||||
if isinstance(value, self.to):
|
||||
model = value
|
||||
else:
|
||||
model = self.to(**{self.to.__pkname__: value})
|
||||
|
||||
child_model_name = self.related_name or child.__class__.__name__.lower() + 's'
|
||||
model._orm_relationship_manager.add(
|
||||
Relationship(name=child_model_name, child=child, parent=model, fk_side='child'))
|
||||
model.__fields__[child_model_name] = pydantic.fields.ModelField(name=child_model_name,
|
||||
type_=child.__pydantic_model__,
|
||||
model_config=child.__pydantic_model__.__config__,
|
||||
class_validators=child.__pydantic_model__.__validators__)
|
||||
return model
|
||||
|
||||
319
orm/models.py
319
orm/models.py
@ -1,26 +1,250 @@
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, Type
|
||||
import uuid
|
||||
from typing import Any, List, Type
|
||||
from typing import Set, Dict
|
||||
|
||||
import pydantic
|
||||
import sqlalchemy
|
||||
from pydantic import BaseConfig, create_model
|
||||
|
||||
from orm.exceptions import ModelDefinitionError
|
||||
from orm.exceptions import ModelDefinitionError, MultipleMatches, NoMatch
|
||||
from orm.fields import BaseField
|
||||
from orm.relations import RelationshipManager
|
||||
|
||||
|
||||
def parse_pydantic_field_from_model_fields(object_dict: dict):
|
||||
pydantic_fields = {field_name: (
|
||||
base_field.__type__,
|
||||
... if (not base_field.nullable and not base_field.default and not base_field.primary_key) else (
|
||||
base_field.default() if callable(base_field.default) else base_field.default)
|
||||
... if base_field.is_required else base_field.default_value
|
||||
)
|
||||
for field_name, base_field in object_dict.items()
|
||||
if isinstance(base_field, BaseField)}
|
||||
return pydantic_fields
|
||||
|
||||
|
||||
FILTER_OPERATORS = {
|
||||
"exact": "__eq__",
|
||||
"iexact": "ilike",
|
||||
"contains": "like",
|
||||
"icontains": "ilike",
|
||||
"in": "in_",
|
||||
"gt": "__gt__",
|
||||
"gte": "__ge__",
|
||||
"lt": "__lt__",
|
||||
"lte": "__le__",
|
||||
}
|
||||
|
||||
|
||||
class QuerySet:
|
||||
ESCAPE_CHARACTERS = ['%', '_']
|
||||
|
||||
def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None,
|
||||
limit_count: int = None, offset: int = None):
|
||||
self.model_cls = model_cls
|
||||
self.filter_clauses = [] if filter_clauses is None else filter_clauses
|
||||
self._select_related = [] if select_related is None else select_related
|
||||
self.limit_count = limit_count
|
||||
self.query_offset = offset
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
return self.__class__(model_cls=owner)
|
||||
|
||||
@property
|
||||
def database(self):
|
||||
return self.model_cls.__database__
|
||||
|
||||
@property
|
||||
def table(self):
|
||||
return self.model_cls.__table__
|
||||
|
||||
def build_select_expression(self):
|
||||
tables = [self.table]
|
||||
select_from = self.table
|
||||
|
||||
for item in self._select_related:
|
||||
model_cls = self.model_cls
|
||||
select_from = self.table
|
||||
for part in item.split("__"):
|
||||
model_cls = model_cls.__model_fields__[part].to
|
||||
select_from = sqlalchemy.sql.join(select_from, model_cls.__table__)
|
||||
tables.append(model_cls.__table__)
|
||||
|
||||
expr = sqlalchemy.sql.select(tables)
|
||||
expr = expr.select_from(select_from)
|
||||
|
||||
if self.filter_clauses:
|
||||
if len(self.filter_clauses) == 1:
|
||||
clause = self.filter_clauses[0]
|
||||
else:
|
||||
clause = sqlalchemy.sql.and_(*self.filter_clauses)
|
||||
expr = expr.where(clause)
|
||||
|
||||
if self.limit_count:
|
||||
expr = expr.limit(self.limit_count)
|
||||
|
||||
if self.query_offset:
|
||||
expr = expr.offset(self.query_offset)
|
||||
|
||||
# print(expr.compile(compile_kwargs={"literal_binds": True}))
|
||||
return expr
|
||||
|
||||
def filter(self, **kwargs):
|
||||
filter_clauses = self.filter_clauses
|
||||
select_related = list(self._select_related)
|
||||
|
||||
if kwargs.get("pk"):
|
||||
pk_name = self.model_cls.__pkname__
|
||||
kwargs[pk_name] = kwargs.pop("pk")
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if "__" in key:
|
||||
parts = key.split("__")
|
||||
|
||||
# Determine if we should treat the final part as a
|
||||
# filter operator or as a related field.
|
||||
if parts[-1] in FILTER_OPERATORS:
|
||||
op = parts[-1]
|
||||
field_name = parts[-2]
|
||||
related_parts = parts[:-2]
|
||||
else:
|
||||
op = "exact"
|
||||
field_name = parts[-1]
|
||||
related_parts = parts[:-1]
|
||||
|
||||
model_cls = self.model_cls
|
||||
if related_parts:
|
||||
# Add any implied select_related
|
||||
related_str = "__".join(related_parts)
|
||||
if related_str not in select_related:
|
||||
select_related.append(related_str)
|
||||
|
||||
# Walk the relationships to the actual model class
|
||||
# against which the comparison is being made.
|
||||
for part in related_parts:
|
||||
model_cls = model_cls.__model_fields__[part].to
|
||||
|
||||
column = model_cls.__table__.columns[field_name]
|
||||
|
||||
else:
|
||||
op = "exact"
|
||||
column = self.table.columns[key]
|
||||
|
||||
# Map the operation code onto SQLAlchemy's ColumnElement
|
||||
# https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement
|
||||
op_attr = FILTER_OPERATORS[op]
|
||||
has_escaped_character = False
|
||||
|
||||
if op in ["contains", "icontains"]:
|
||||
has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS
|
||||
if c in value)
|
||||
if has_escaped_character:
|
||||
# enable escape modifier
|
||||
for char in self.ESCAPE_CHARACTERS:
|
||||
value = value.replace(char, f'\\{char}')
|
||||
value = f"%{value}%"
|
||||
|
||||
if isinstance(value, Model):
|
||||
value = value.pk
|
||||
|
||||
clause = getattr(column, op_attr)(value)
|
||||
clause.modifiers['escape'] = '\\' if has_escaped_character else None
|
||||
filter_clauses.append(clause)
|
||||
|
||||
return self.__class__(
|
||||
model_cls=self.model_cls,
|
||||
filter_clauses=filter_clauses,
|
||||
select_related=select_related,
|
||||
limit_count=self.limit_count,
|
||||
offset=self.query_offset
|
||||
)
|
||||
|
||||
def select_related(self, related):
|
||||
if not isinstance(related, (list, tuple)):
|
||||
related = [related]
|
||||
|
||||
related = list(self._select_related) + related
|
||||
return self.__class__(
|
||||
model_cls=self.model_cls,
|
||||
filter_clauses=self.filter_clauses,
|
||||
select_related=related,
|
||||
limit_count=self.limit_count,
|
||||
offset=self.query_offset
|
||||
)
|
||||
|
||||
# async def exists(self) -> bool:
|
||||
# expr = self.build_select_expression()
|
||||
# expr = sqlalchemy.exists(expr).select()
|
||||
# return await self.database.fetch_val(expr)
|
||||
|
||||
def limit(self, limit_count: int):
|
||||
return self.__class__(
|
||||
model_cls=self.model_cls,
|
||||
filter_clauses=self.filter_clauses,
|
||||
select_related=self._select_related,
|
||||
limit_count=limit_count,
|
||||
offset=self.query_offset
|
||||
)
|
||||
|
||||
def offset(self, offset: int):
|
||||
return self.__class__(
|
||||
model_cls=self.model_cls,
|
||||
filter_clauses=self.filter_clauses,
|
||||
select_related=self._select_related,
|
||||
limit_count=self.limit_count,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
async def get(self, **kwargs):
|
||||
if kwargs:
|
||||
return await self.filter(**kwargs).get()
|
||||
|
||||
expr = self.build_select_expression().limit(2)
|
||||
rows = await self.database.fetch_all(expr)
|
||||
|
||||
if not rows:
|
||||
raise NoMatch()
|
||||
if len(rows) > 1:
|
||||
raise MultipleMatches()
|
||||
return self.model_cls.from_row(rows[0], select_related=self._select_related)
|
||||
|
||||
async def all(self, **kwargs):
|
||||
if kwargs:
|
||||
return await self.filter(**kwargs).all()
|
||||
|
||||
expr = self.build_select_expression()
|
||||
rows = await self.database.fetch_all(expr)
|
||||
return [
|
||||
self.model_cls.from_row(row, select_related=self._select_related)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def create(self, **kwargs):
|
||||
|
||||
new_kwargs = dict(**kwargs)
|
||||
|
||||
# Remove primary key when None to prevent not null constraint in postgresql.
|
||||
pkname = self.model_cls.__pkname__
|
||||
pk = self.model_cls.__model_fields__[pkname]
|
||||
if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement):
|
||||
del new_kwargs[pkname]
|
||||
|
||||
# substitute related models with their pk
|
||||
for field in self.model_cls.extract_related_names():
|
||||
if field in new_kwargs and new_kwargs.get(field) is not None:
|
||||
new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__)
|
||||
|
||||
# Build the insert expression.
|
||||
expr = self.table.insert()
|
||||
expr = expr.values(**new_kwargs)
|
||||
|
||||
# Execute the insert, and return a new model instance.
|
||||
instance = self.model_cls(**kwargs)
|
||||
instance.pk = await self.database.execute(expr)
|
||||
return instance
|
||||
|
||||
|
||||
class ModelMetaclass(type):
|
||||
def __new__(
|
||||
mcs: type, name: str, bases: Any, attrs: dict
|
||||
@ -52,9 +276,7 @@ class ModelMetaclass(type):
|
||||
attrs['__pkname__'] = pkname
|
||||
|
||||
if not pkname:
|
||||
raise ModelDefinitionError(
|
||||
'Table has to have a primary key.'
|
||||
)
|
||||
raise ModelDefinitionError('Table has to have a primary key.')
|
||||
|
||||
# pydantic model creation
|
||||
pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
|
||||
@ -62,9 +284,9 @@ class ModelMetaclass(type):
|
||||
pydantic_model = create_model(name, __config__=config, **pydantic_fields)
|
||||
attrs['__pydantic_fields__'] = pydantic_fields
|
||||
attrs['__pydantic_model__'] = pydantic_model
|
||||
attrs['__fields__'] = pydantic_model.__fields__
|
||||
attrs['__signature__'] = pydantic_model.__signature__
|
||||
attrs['__annotations__'] = pydantic_model.__annotations__
|
||||
attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__)
|
||||
attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__)
|
||||
attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__)
|
||||
|
||||
attrs['__model_fields__'] = model_fields
|
||||
|
||||
@ -78,9 +300,17 @@ class ModelMetaclass(type):
|
||||
class Model(metaclass=ModelMetaclass):
|
||||
__abstract__ = True
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
objects = QuerySet()
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self._orm_id = uuid.uuid4().hex
|
||||
self._orm_saved = False
|
||||
self._orm_relationship_manager = RelationshipManager(self)
|
||||
self._orm_observers = []
|
||||
|
||||
if "pk" in kwargs:
|
||||
kwargs[self.__pkname__] = kwargs.pop("pk")
|
||||
kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()}
|
||||
self.values = self.__pydantic_model__(**kwargs)
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
@ -90,14 +320,19 @@ class Model(metaclass=ModelMetaclass):
|
||||
value = json.dumps(value)
|
||||
except TypeError: # pragma no cover
|
||||
pass
|
||||
value = self.__model_fields__[key].expand_relationship(value, self)
|
||||
setattr(self.values, key, value)
|
||||
else:
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __getattribute__(self, key: str) -> Any:
|
||||
if key != '__fields__' and key in self.__fields__:
|
||||
item = getattr(self.values, key)
|
||||
if self.is_conversion_to_json_needed(key) and isinstance(item, str):
|
||||
if key in self._orm_relationship_manager:
|
||||
parent_item = self._orm_relationship_manager.get(key)
|
||||
return parent_item
|
||||
|
||||
item = getattr(self.values, key, None)
|
||||
if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str):
|
||||
try:
|
||||
item = json.loads(item)
|
||||
except TypeError: # pragma no cover
|
||||
@ -106,6 +341,45 @@ class Model(metaclass=ModelMetaclass):
|
||||
|
||||
return super().__getattribute__(key)
|
||||
|
||||
def __repr__(self): # pragma no cover
|
||||
return self.values.__repr__()
|
||||
|
||||
# def attach(self, observer: 'Model'):
|
||||
# if all([obs._orm_id != observer._orm_id for obs in self._orm_observers]):
|
||||
# self._orm_observers.append(observer)
|
||||
#
|
||||
# def detach(self, observer: 'Model'):
|
||||
# for ind, obs in enumerate(self._orm_observers):
|
||||
# if obs._orm_id == observer._orm_id:
|
||||
# del self._orm_observers[ind]
|
||||
# break
|
||||
#
|
||||
def notify(self):
|
||||
for obs in self._orm_observers: # pragma no cover
|
||||
obs.orm_update(self)
|
||||
|
||||
def orm_update(self, subject: 'Model') -> None: # pragma no cover
|
||||
print('should be updated here')
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row, select_related: List = None) -> 'Model':
|
||||
item = {}
|
||||
select_related = select_related or []
|
||||
for related in select_related:
|
||||
if "__" in related:
|
||||
first_part, remainder = related.split("__", 1)
|
||||
model_cls = cls.__model_fields__[first_part].to
|
||||
item[first_part] = model_cls.from_row(row, select_related=[remainder])
|
||||
else:
|
||||
model_cls = cls.__model_fields__[related].to
|
||||
item[related] = model_cls.from_row(row)
|
||||
|
||||
for column in cls.__table__.columns:
|
||||
if column.name not in item:
|
||||
item[column.name] = row[column]
|
||||
|
||||
return cls(**item)
|
||||
|
||||
def is_conversion_to_json_needed(self, column_name: str) -> bool:
|
||||
return self.__model_fields__.get(column_name).__type__ == pydantic.Json
|
||||
|
||||
@ -136,17 +410,20 @@ class Model(metaclass=ModelMetaclass):
|
||||
@classmethod
|
||||
def extract_related_names(cls) -> Set:
|
||||
related_names = set()
|
||||
# for name, field in cls.__fields__.items():
|
||||
# if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel):
|
||||
# related_names.add(name)
|
||||
# elif field.sub_fields and any(
|
||||
# [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]):
|
||||
# related_names.add(name)
|
||||
for name, field in cls.__fields__.items():
|
||||
if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel):
|
||||
related_names.add(name)
|
||||
# elif field.sub_fields and any(
|
||||
# [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]):
|
||||
# related_names.add(name)
|
||||
return related_names
|
||||
|
||||
def extract_model_db_fields(self) -> Dict:
|
||||
self_fields = self.extract_own_model_fields()
|
||||
self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns}
|
||||
for field in self.extract_related_names():
|
||||
if getattr(self, field) is not None:
|
||||
self_fields[field] = getattr(getattr(self, field), self.__model_fields__[field].to.__pkname__)
|
||||
return self_fields
|
||||
|
||||
async def save(self) -> int:
|
||||
@ -157,6 +434,7 @@ class Model(metaclass=ModelMetaclass):
|
||||
expr = expr.values(**self_fields)
|
||||
item_id = await self.__database__.execute(expr)
|
||||
setattr(self, 'pk', item_id)
|
||||
self.notify()
|
||||
return item_id
|
||||
|
||||
async def update(self, **kwargs: Any) -> int:
|
||||
@ -169,16 +447,19 @@ class Model(metaclass=ModelMetaclass):
|
||||
expr = self.__table__.update().values(**self_fields).where(
|
||||
self.pk_column == getattr(self, self.__pkname__))
|
||||
result = await self.__database__.execute(expr)
|
||||
self.notify()
|
||||
return result
|
||||
|
||||
async def delete(self) -> int:
|
||||
expr = self.__table__.delete()
|
||||
expr = expr.where(self.pk_column == (getattr(self, self.__pkname__)))
|
||||
result = await self.__database__.execute(expr)
|
||||
self.notify()
|
||||
return result
|
||||
|
||||
async def load(self) -> 'Model':
|
||||
expr = self.__table__.select().where(self.pk_column == self.pk)
|
||||
row = await self.__database__.fetch_one(expr)
|
||||
self.from_dict(dict(row))
|
||||
self.notify()
|
||||
return self
|
||||
|
||||
45
orm/relations.py
Normal file
45
orm/relations.py
Normal file
@ -0,0 +1,45 @@
|
||||
from typing import Dict, Union, List
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
class Relationship:
|
||||
|
||||
def __init__(self, name: str, parent: 'Model', child: 'Model', fk_side: str = 'child'):
|
||||
self.fk_side = fk_side
|
||||
self.child = child
|
||||
self.parent = parent
|
||||
self.name = name
|
||||
|
||||
|
||||
class RelationshipManager:
|
||||
|
||||
def __init__(self, model: 'Model'):
|
||||
self._orm_id: str = model._orm_id
|
||||
self._relations: Dict[str, Union[Relationship, List[Relationship]]] = dict()
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self._relations
|
||||
|
||||
def add_related(self, relation: Relationship):
|
||||
if relation.fk_side == 'child' and relation.parent._orm_id == self._orm_id:
|
||||
new_relation = Relationship(name=relation.parent.__class__.__name__.lower(),
|
||||
child=relation.parent,
|
||||
parent=relation.child,
|
||||
fk_side='parent')
|
||||
relation.child._orm_relationship_manager.add(new_relation)
|
||||
|
||||
def add(self, relation: Relationship):
|
||||
if relation.name in self._relations:
|
||||
self._relations[relation.name].append(relation)
|
||||
else:
|
||||
self._relations[relation.name] = [relation]
|
||||
self.add_related(relation)
|
||||
|
||||
def get(self, name: str):
|
||||
for rel, relations in self._relations.items():
|
||||
if rel == name:
|
||||
if relations and relations[0].fk_side == 'parent':
|
||||
return relations[0].child
|
||||
else:
|
||||
return [rela.child for rela in relations]
|
||||
231
tests/test_foreign_keys.py
Normal file
231
tests/test_foreign_keys.py
Normal file
@ -0,0 +1,231 @@
|
||||
import databases
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import orm
|
||||
from orm.exceptions import NoMatch, MultipleMatches
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
class Album(orm.Model):
|
||||
__tablename__ = "album"
|
||||
__metadata__ = metadata
|
||||
__database__ = database
|
||||
|
||||
id = orm.Integer(primary_key=True)
|
||||
name = orm.String(length=100)
|
||||
|
||||
|
||||
class Track(orm.Model):
|
||||
__tablename__ = "track"
|
||||
__metadata__ = metadata
|
||||
__database__ = database
|
||||
|
||||
id = orm.Integer(primary_key=True)
|
||||
album = orm.ForeignKey(Album)
|
||||
title = orm.String(length=100)
|
||||
position = orm.Integer()
|
||||
|
||||
|
||||
class Organisation(orm.Model):
|
||||
__tablename__ = "org"
|
||||
__metadata__ = metadata
|
||||
__database__ = database
|
||||
|
||||
id = orm.Integer(primary_key=True)
|
||||
ident = orm.String(length=100)
|
||||
|
||||
|
||||
class Team(orm.Model):
|
||||
__tablename__ = "team"
|
||||
__metadata__ = metadata
|
||||
__database__ = database
|
||||
|
||||
id = orm.Integer(primary_key=True)
|
||||
org = orm.ForeignKey(Organisation)
|
||||
name = orm.String(length=100)
|
||||
|
||||
|
||||
class Member(orm.Model):
|
||||
__tablename__ = "member"
|
||||
__metadata__ = metadata
|
||||
__database__ = database
|
||||
|
||||
id = orm.Integer(primary_key=True)
|
||||
team = orm.ForeignKey(Team)
|
||||
email = orm.String(length=100)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def create_test_database():
|
||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||
metadata.create_all(engine)
|
||||
yield
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_crud():
|
||||
async with database:
|
||||
album = Album(name="Malibu")
|
||||
await album.save()
|
||||
track1 = Track(album=album, title="The Bird", position=1)
|
||||
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
|
||||
track3 = Track(album=album, title="The Waters", position=3)
|
||||
await track1.save()
|
||||
await track2.save()
|
||||
await track3.save()
|
||||
|
||||
track = await Track.objects.get(title="The Bird")
|
||||
assert track.album.pk == album.pk
|
||||
assert track.album.name is None
|
||||
await track.album.load()
|
||||
assert track.album.name == "Malibu"
|
||||
|
||||
assert len(album.tracks) == 3
|
||||
assert album.tracks[1].title == "Heart don't stand a chance"
|
||||
|
||||
album1 = await Album.objects.get(name='Malibu')
|
||||
assert album1.pk == 1
|
||||
assert album1.tracks is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_select_related():
|
||||
async with database:
|
||||
album = Album(name="Malibu")
|
||||
await album.save()
|
||||
track1 = Track(album=album, title="The Bird", position=1)
|
||||
track2 = Track(album=album, title="Heart don't stand a chance", position=2)
|
||||
track3 = Track(album=album, title="The Waters", position=3)
|
||||
await track1.save()
|
||||
await track2.save()
|
||||
await track3.save()
|
||||
|
||||
fantasies = Album(name="Fantasies")
|
||||
await fantasies.save()
|
||||
track4 = Track(album=fantasies, title="Help I'm Alive", position=1)
|
||||
track5 = Track(album=fantasies, title="Sick Muse", position=2)
|
||||
track6 = Track(album=fantasies, title="Satellite Mind", position=3)
|
||||
await track4.save()
|
||||
await track5.save()
|
||||
await track6.save()
|
||||
|
||||
track = await Track.objects.select_related("album").get(title="The Bird")
|
||||
assert track.album.name == "Malibu"
|
||||
|
||||
tracks = await Track.objects.select_related("album").all()
|
||||
assert len(tracks) == 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fk_filter():
|
||||
async with database:
|
||||
malibu = Album(name="Malibu%")
|
||||
await malibu.save()
|
||||
await Track.objects.create(album=malibu, title="The Bird", position=1)
|
||||
await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2)
|
||||
await Track.objects.create(album=malibu, title="The Waters", position=3)
|
||||
|
||||
fantasies = await Album.objects.create(name="Fantasies")
|
||||
await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1)
|
||||
await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
|
||||
await Track.objects.create(album=fantasies, title="Satellite Mind", position=3)
|
||||
|
||||
tracks = await Track.objects.select_related("album").filter(album__name="Fantasies").all()
|
||||
assert len(tracks) == 3
|
||||
for track in tracks:
|
||||
assert track.album.name == "Fantasies"
|
||||
|
||||
tracks = await Track.objects.select_related("album").filter(album__name__icontains="fan").all()
|
||||
assert len(tracks) == 3
|
||||
for track in tracks:
|
||||
assert track.album.name == "Fantasies"
|
||||
|
||||
tracks = await Track.objects.filter(album__name__contains="fan").all()
|
||||
assert len(tracks) == 3
|
||||
for track in tracks:
|
||||
assert track.album.name == "Fantasies"
|
||||
|
||||
tracks = await Track.objects.filter(album__name__contains="Malibu%").all()
|
||||
assert len(tracks) == 3
|
||||
|
||||
tracks = await Track.objects.filter(album=malibu).select_related("album").all()
|
||||
assert len(tracks) == 3
|
||||
for track in tracks:
|
||||
assert track.album.name == "Malibu%"
|
||||
|
||||
tracks = await Track.objects.select_related("album").all(album=malibu)
|
||||
assert len(tracks) == 3
|
||||
for track in tracks:
|
||||
assert track.album.name == "Malibu%"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_fk():
|
||||
async with database:
|
||||
acme = await Organisation.objects.create(ident="ACME Ltd")
|
||||
red_team = await Team.objects.create(org=acme, name="Red Team")
|
||||
blue_team = await Team.objects.create(org=acme, name="Blue Team")
|
||||
await Member.objects.create(team=red_team, email="a@example.org")
|
||||
await Member.objects.create(team=red_team, email="b@example.org")
|
||||
await Member.objects.create(team=blue_team, email="c@example.org")
|
||||
await Member.objects.create(team=blue_team, email="d@example.org")
|
||||
|
||||
other = await Organisation.objects.create(ident="Other ltd")
|
||||
team = await Team.objects.create(org=other, name="Green Team")
|
||||
await Member.objects.create(team=team, email="e@example.org")
|
||||
|
||||
members = await Member.objects.select_related('team__org').filter(team__org__ident="ACME Ltd").all()
|
||||
assert len(members) == 4
|
||||
for member in members:
|
||||
assert member.team.org.ident == "ACME Ltd"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pk_filter():
|
||||
async with database:
|
||||
fantasies = await Album.objects.create(name="Test")
|
||||
await Track.objects.create(album=fantasies, title="Test1", position=1)
|
||||
await Track.objects.create(album=fantasies, title="Test2", position=2)
|
||||
await Track.objects.create(album=fantasies, title="Test3", position=3)
|
||||
tracks = await Track.objects.select_related("album").filter(pk=1).all()
|
||||
assert len(tracks) == 1
|
||||
|
||||
tracks = await Track.objects.select_related("album").filter(position=2, album__name='Test').all()
|
||||
assert len(tracks) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_limit_and_offset():
|
||||
async with database:
|
||||
fantasies = await Album.objects.create(name="Limitless")
|
||||
await Track.objects.create(id=None, album=fantasies, title="Sample", position=1)
|
||||
await Track.objects.create(album=fantasies, title="Sample2", position=2)
|
||||
await Track.objects.create(album=fantasies, title="Sample3", position=3)
|
||||
|
||||
tracks = await Track.objects.limit(1).all()
|
||||
assert len(tracks) == 1
|
||||
assert tracks[0].title == "Sample"
|
||||
|
||||
tracks = await Track.objects.limit(1).offset(1).all()
|
||||
assert len(tracks) == 1
|
||||
assert tracks[0].title == "Sample2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_exceptions():
|
||||
async with database:
|
||||
fantasies = await Album.objects.create(name="Test")
|
||||
|
||||
with pytest.raises(NoMatch):
|
||||
await Album.objects.get(name="Test2")
|
||||
|
||||
await Track.objects.create(album=fantasies, title="Test1", position=1)
|
||||
await Track.objects.create(album=fantasies, title="Test2", position=2)
|
||||
await Track.objects.create(album=fantasies, title="Test3", position=3)
|
||||
with pytest.raises(MultipleMatches):
|
||||
await Track.objects.select_related("album").get(album=fantasies)
|
||||
0
tests/test_models.py
Normal file
0
tests/test_models.py
Normal file
Reference in New Issue
Block a user