changed relationshipt to wekrefs

This commit is contained in:
collerek
2020-08-07 05:37:10 +02:00
parent 475dafb6c9
commit 6efb56a2a0
7 changed files with 324 additions and 296 deletions

BIN
.coverage

Binary file not shown.

View File

@ -215,15 +215,3 @@ class ForeignKey(BaseField):
model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True) model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True)
return model return model
# def register_relationship(self):
# child_model_name = self.related_name or child.__class__.__name__.lower() + 's'
# if not child_model_name in model._orm_relationship_manager:
# model._orm_relationship_manager.add(
# Relationship(name=child_model_name, child=child, parent=model, fk_side='child'))
# model.__fields__[child_model_name] = ModelField(name=child_model_name,
# type_=Optional[child.__pydantic_model__],
# model_config=child.__pydantic_model__.__config__,
# class_validators=child.__pydantic_model__.__validators__)
# model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True)
# breakpoint()

View File

@ -2,21 +2,22 @@ import copy
import inspect import inspect
import json import json
import uuid import uuid
from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar, Tuple
from typing import Set, Dict from typing import Set, Dict
import pydantic import pydantic
import sqlalchemy import sqlalchemy
from pydantic import BaseModel, BaseConfig, create_model from pydantic import BaseModel, BaseConfig, create_model
from orm.exceptions import ModelDefinitionError, NoMatch, MultipleMatches import orm.queryset as qry
from orm.exceptions import ModelDefinitionError
from orm.fields import BaseField, ForeignKey from orm.fields import BaseField, ForeignKey
from orm.relations import RelationshipManager from orm.relations import RelationshipManager
relationship_manager = RelationshipManager() relationship_manager = RelationshipManager()
def parse_pydantic_field_from_model_fields(object_dict: dict): def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]:
pydantic_fields = {field_name: ( pydantic_fields = {field_name: (
base_field.__type__, base_field.__type__,
... if base_field.is_required else base_field.default_value ... if base_field.is_required else base_field.default_value
@ -26,8 +27,10 @@ def parse_pydantic_field_from_model_fields(object_dict: dict):
return pydantic_fields return pydantic_fields
def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict): def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename: str) -> Tuple[Optional[str],
pkname = None List[sqlalchemy.Column],
Dict[str, BaseField]]:
pkname: Optional[str] = None
columns: List[sqlalchemy.Column] = [] columns: List[sqlalchemy.Column] = []
model_fields: Dict[str, BaseField] = {} model_fields: Dict[str, BaseField] = {}
@ -39,243 +42,17 @@ def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict):
pkname = field_name pkname = field_name
if isinstance(field, ForeignKey): if isinstance(field, ForeignKey):
reverse_name = field.related_name or field.to.__name__.title() + '_' + name.lower() + 's' reverse_name = field.related_name or field.to.__name__.title() + '_' + name.lower() + 's'
relationship_manager.add_relation_type(name + '_' + field.to.__name__.lower(), reverse_name) relation_name = name + '_' + field.to.__name__.lower()
relationship_manager.add_relation_type(relation_name, reverse_name, field, tablename)
columns.append(field.get_column(field_name)) columns.append(field.get_column(field_name))
return pkname, columns, model_fields return pkname, columns, model_fields
FILTER_OPERATORS = { def get_pydantic_base_orm_config() -> Type[BaseConfig]:
"exact": "__eq__", class Config(BaseConfig):
"iexact": "ilike", orm_mode = True
"contains": "like",
"icontains": "ilike",
"in": "in_",
"gt": "__gt__",
"gte": "__ge__",
"lt": "__lt__",
"lte": "__le__",
}
return Config
class QuerySet:
ESCAPE_CHARACTERS = ['%', '_']
def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None,
limit_count: int = None, offset: int = None):
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
self._select_related = [] if select_related is None else select_related
self.limit_count = limit_count
self.query_offset = offset
def __get__(self, instance, owner):
return self.__class__(model_cls=owner)
@property
def database(self):
return self.model_cls.__database__
@property
def table(self):
return self.model_cls.__table__
def build_select_expression(self):
tables = [self.table]
select_from = self.table
for item in self._select_related:
model_cls = self.model_cls
select_from = self.table
for part in item.split("__"):
model_cls = model_cls.__model_fields__[part].to
select_from = sqlalchemy.sql.join(select_from, model_cls.__table__)
tables.append(model_cls.__table__)
expr = sqlalchemy.sql.select(tables)
expr = expr.select_from(select_from)
if self.filter_clauses:
if len(self.filter_clauses) == 1:
clause = self.filter_clauses[0]
else:
clause = sqlalchemy.sql.and_(*self.filter_clauses)
expr = expr.where(clause)
if self.limit_count:
expr = expr.limit(self.limit_count)
if self.query_offset:
expr = expr.offset(self.query_offset)
# print(expr.compile(compile_kwargs={"literal_binds": True}))
return expr
def filter(self, **kwargs):
filter_clauses = self.filter_clauses
select_related = list(self._select_related)
if kwargs.get("pk"):
pk_name = self.model_cls.__pkname__
kwargs[pk_name] = kwargs.pop("pk")
for key, value in kwargs.items():
if "__" in key:
parts = key.split("__")
# Determine if we should treat the final part as a
# filter operator or as a related field.
if parts[-1] in FILTER_OPERATORS:
op = parts[-1]
field_name = parts[-2]
related_parts = parts[:-2]
else:
op = "exact"
field_name = parts[-1]
related_parts = parts[:-1]
model_cls = self.model_cls
if related_parts:
# Add any implied select_related
related_str = "__".join(related_parts)
if related_str not in select_related:
select_related.append(related_str)
# Walk the relationships to the actual model class
# against which the comparison is being made.
for part in related_parts:
model_cls = model_cls.__model_fields__[part].to
column = model_cls.__table__.columns[field_name]
else:
op = "exact"
column = self.table.columns[key]
# Map the operation code onto SQLAlchemy's ColumnElement
# https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement
op_attr = FILTER_OPERATORS[op]
has_escaped_character = False
if op in ["contains", "icontains"]:
has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS
if c in value)
if has_escaped_character:
# enable escape modifier
for char in self.ESCAPE_CHARACTERS:
value = value.replace(char, f'\\{char}')
value = f"%{value}%"
if isinstance(value, Model):
value = value.pk
clause = getattr(column, op_attr)(value)
clause.modifiers['escape'] = '\\' if has_escaped_character else None
filter_clauses.append(clause)
return self.__class__(
model_cls=self.model_cls,
filter_clauses=filter_clauses,
select_related=select_related,
limit_count=self.limit_count,
offset=self.query_offset
)
def select_related(self, related):
if not isinstance(related, (list, tuple)):
related = [related]
related = list(self._select_related) + related
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=related,
limit_count=self.limit_count,
offset=self.query_offset
)
async def exists(self) -> bool:
expr = self.build_select_expression()
expr = sqlalchemy.exists(expr).select()
return await self.database.fetch_val(expr)
async def count(self) -> int:
expr = self.build_select_expression().alias("subquery_for_count")
expr = sqlalchemy.func.count().select().select_from(expr)
return await self.database.fetch_val(expr)
def limit(self, limit_count: int):
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=self._select_related,
limit_count=limit_count,
offset=self.query_offset
)
def offset(self, offset: int):
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=self._select_related,
limit_count=self.limit_count,
offset=offset
)
async def first(self, **kwargs):
if kwargs:
return await self.filter(**kwargs).first()
rows = await self.limit(1).all()
if rows:
return rows[0]
async def get(self, **kwargs):
if kwargs:
return await self.filter(**kwargs).get()
expr = self.build_select_expression().limit(2)
rows = await self.database.fetch_all(expr)
if not rows:
raise NoMatch()
if len(rows) > 1:
raise MultipleMatches()
return self.model_cls.from_row(rows[0], select_related=self._select_related)
async def all(self, **kwargs):
if kwargs:
return await self.filter(**kwargs).all()
expr = self.build_select_expression()
rows = await self.database.fetch_all(expr)
return [
self.model_cls.from_row(row, select_related=self._select_related)
for row in rows
]
async def create(self, **kwargs):
new_kwargs = dict(**kwargs)
# Remove primary key when None to prevent not null constraint in postgresql.
pkname = self.model_cls.__pkname__
pk = self.model_cls.__model_fields__[pkname]
if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement):
del new_kwargs[pkname]
# substitute related models with their pk
for field in self.model_cls.extract_related_names():
if field in new_kwargs and new_kwargs.get(field) is not None:
new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__)
# Build the insert expression.
expr = self.table.insert()
expr = expr.values(**new_kwargs)
# Execute the insert, and return a new model instance.
instance = self.model_cls(**kwargs)
instance.pk = await self.database.execute(expr)
return instance
class ModelMetaclass(type): class ModelMetaclass(type):
@ -293,7 +70,7 @@ class ModelMetaclass(type):
metadata = attrs["__metadata__"] metadata = attrs["__metadata__"]
# sqlalchemy table creation # sqlalchemy table creation
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs) pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs, tablename)
attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns) attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns)
attrs['__columns__'] = columns attrs['__columns__'] = columns
attrs['__pkname__'] = pkname attrs['__pkname__'] = pkname
@ -303,8 +80,7 @@ class ModelMetaclass(type):
# pydantic model creation # pydantic model creation
pydantic_fields = parse_pydantic_field_from_model_fields(attrs) pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
config = type('Config', (BaseConfig,), {'orm_mode': True}) pydantic_model = create_model(name, __config__=get_pydantic_base_orm_config(), **pydantic_fields)
pydantic_model = create_model(name, __config__=config, **pydantic_fields)
attrs['__pydantic_fields__'] = pydantic_fields attrs['__pydantic_fields__'] = pydantic_fields
attrs['__pydantic_model__'] = pydantic_model attrs['__pydantic_model__'] = pydantic_model
attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__) attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__)
@ -330,21 +106,22 @@ class Model(list, metaclass=ModelMetaclass):
__pydantic_model__: Type[BaseModel] __pydantic_model__: Type[BaseModel]
__pkname__: str __pkname__: str
objects = QuerySet() objects = qry.QuerySet()
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
self._orm_id: str = uuid.uuid4().hex self._orm_id: str = uuid.uuid4().hex
self._orm_saved: bool = False self._orm_saved: bool = False
self._orm_relationship_manager: RelationshipManager = relationship_manager self._orm_relationship_manager: RelationshipManager = relationship_manager
self._orm_observers: List['Model'] = []
self.values: Optional[BaseModel] = None self.values: Optional[BaseModel] = None
if "pk" in kwargs: if "pk" in kwargs:
kwargs[self.__pkname__] = kwargs.pop("pk") kwargs[self.__pkname__] = kwargs.pop("pk")
# breakpoint()
kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()} kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()}
self.values = self.__pydantic_model__(**kwargs) self.values = self.__pydantic_model__(**kwargs)
def __del__(self):
self._orm_relationship_manager.deregister(self)
def __setattr__(self, key: str, value: Any) -> None: def __setattr__(self, key: str, value: Any) -> None:
if key in self.__fields__: if key in self.__fields__:
if self.is_conversion_to_json_needed(key) and not isinstance(value, str): if self.is_conversion_to_json_needed(key) and not isinstance(value, str):
@ -378,23 +155,6 @@ class Model(list, metaclass=ModelMetaclass):
def __repr__(self): # pragma no cover def __repr__(self): # pragma no cover
return self.values.__repr__() return self.values.__repr__()
# def attach(self, observer: 'Model'):
# if all([obs._orm_id != observer._orm_id for obs in self._orm_observers]):
# self._orm_observers.append(observer)
#
# def detach(self, observer: 'Model'):
# for ind, obs in enumerate(self._orm_observers):
# if obs._orm_id == observer._orm_id:
# del self._orm_observers[ind]
# break
#
def notify(self):
for obs in self._orm_observers: # pragma no cover
obs.orm_update(self)
def orm_update(self, subject: 'Model') -> None: # pragma no cover
print('should be updated here')
@classmethod @classmethod
def from_row(cls, row, select_related: List = None) -> 'Model': def from_row(cls, row, select_related: List = None) -> 'Model':
item = {} item = {}
@ -412,20 +172,19 @@ class Model(list, metaclass=ModelMetaclass):
if column.name not in item: if column.name not in item:
item[column.name] = row[column] item[column.name] = row[column]
# breakpoint()
return cls(**item) return cls(**item)
@classmethod # @classmethod
def validate(cls, value: Any) -> 'BaseModel': # pragma no cover # def validate(cls, value: Any) -> 'BaseModel': # pragma no cover
return cls.__pydantic_model__.validate(value=value) # return cls.__pydantic_model__.validate(value=value)
@classmethod @classmethod
def __get_validators__(cls): # pragma no cover def __get_validators__(cls): # pragma no cover
yield cls.__pydantic_model__.validate yield cls.__pydantic_model__.validate
@classmethod # @classmethod
def schema(cls, by_alias: bool = True): # pragma no cover # def schema(cls, by_alias: bool = True): # pragma no cover
return cls.__pydantic_model__.schema(by_alias=by_alias) # return cls.__pydantic_model__.schema(by_alias=by_alias)
def is_conversion_to_json_needed(self, column_name: str) -> bool: def is_conversion_to_json_needed(self, column_name: str) -> bool:
return self.__model_fields__.get(column_name).__type__ == pydantic.Json return self.__model_fields__.get(column_name).__type__ == pydantic.Json
@ -460,9 +219,6 @@ class Model(list, metaclass=ModelMetaclass):
for name, field in cls.__fields__.items(): for name, field in cls.__fields__.items():
if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel): if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel):
related_names.add(name) related_names.add(name)
# elif field.sub_fields and any(
# [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]):
# related_names.add(name)
return related_names return related_names
def extract_model_db_fields(self) -> Dict: def extract_model_db_fields(self) -> Dict:
@ -481,7 +237,6 @@ class Model(list, metaclass=ModelMetaclass):
expr = expr.values(**self_fields) expr = expr.values(**self_fields)
item_id = await self.__database__.execute(expr) item_id = await self.__database__.execute(expr)
setattr(self, 'pk', item_id) setattr(self, 'pk', item_id)
self.notify()
return item_id return item_id
async def update(self, **kwargs: Any) -> int: async def update(self, **kwargs: Any) -> int:
@ -494,19 +249,16 @@ class Model(list, metaclass=ModelMetaclass):
expr = self.__table__.update().values(**self_fields).where( expr = self.__table__.update().values(**self_fields).where(
self.pk_column == getattr(self, self.__pkname__)) self.pk_column == getattr(self, self.__pkname__))
result = await self.__database__.execute(expr) result = await self.__database__.execute(expr)
self.notify()
return result return result
async def delete(self) -> int: async def delete(self) -> int:
expr = self.__table__.delete() expr = self.__table__.delete()
expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) expr = expr.where(self.pk_column == (getattr(self, self.__pkname__)))
result = await self.__database__.execute(expr) result = await self.__database__.execute(expr)
self.notify()
return result return result
async def load(self) -> 'Model': async def load(self) -> 'Model':
expr = self.__table__.select().where(self.pk_column == self.pk) expr = self.__table__.select().where(self.pk_column == self.pk)
row = await self.__database__.fetch_one(expr) row = await self.__database__.fetch_one(expr)
self.from_dict(dict(row)) self.from_dict(dict(row))
self.notify()
return self return self

242
orm/queryset.py Normal file
View File

@ -0,0 +1,242 @@
from typing import List, TYPE_CHECKING
import sqlalchemy
import orm
from orm.exceptions import NoMatch, MultipleMatches
if TYPE_CHECKING: # pragma no cover
from orm.models import Model
FILTER_OPERATORS = {
"exact": "__eq__",
"iexact": "ilike",
"contains": "like",
"icontains": "ilike",
"in": "in_",
"gt": "__gt__",
"gte": "__ge__",
"lt": "__lt__",
"lte": "__le__",
}
class QuerySet:
ESCAPE_CHARACTERS = ['%', '_']
def __init__(self, model_cls: 'Model' = None, filter_clauses: List = None, select_related: List = None,
limit_count: int = None, offset: int = None):
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
self._select_related = [] if select_related is None else select_related
self.limit_count = limit_count
self.query_offset = offset
def __get__(self, instance, owner):
return self.__class__(model_cls=owner)
@property
def database(self):
return self.model_cls.__database__
@property
def table(self):
return self.model_cls.__table__
def build_select_expression(self):
tables = [self.table]
select_from = self.table
for item in self._select_related:
model_cls = self.model_cls
select_from = self.table
for part in item.split("__"):
model_cls = model_cls.__model_fields__[part].to
select_from = sqlalchemy.sql.join(select_from, model_cls.__table__)
tables.append(model_cls.__table__)
expr = sqlalchemy.sql.select(tables)
expr = expr.select_from(select_from)
if self.filter_clauses:
if len(self.filter_clauses) == 1:
clause = self.filter_clauses[0]
else:
clause = sqlalchemy.sql.and_(*self.filter_clauses)
expr = expr.where(clause)
if self.limit_count:
expr = expr.limit(self.limit_count)
if self.query_offset:
expr = expr.offset(self.query_offset)
print(expr.compile(compile_kwargs={"literal_binds": True}))
return expr
def filter(self, **kwargs):
filter_clauses = self.filter_clauses
select_related = list(self._select_related)
if kwargs.get("pk"):
pk_name = self.model_cls.__pkname__
kwargs[pk_name] = kwargs.pop("pk")
for key, value in kwargs.items():
if "__" in key:
parts = key.split("__")
# Determine if we should treat the final part as a
# filter operator or as a related field.
if parts[-1] in FILTER_OPERATORS:
op = parts[-1]
field_name = parts[-2]
related_parts = parts[:-2]
else:
op = "exact"
field_name = parts[-1]
related_parts = parts[:-1]
model_cls = self.model_cls
if related_parts:
# Add any implied select_related
related_str = "__".join(related_parts)
if related_str not in select_related:
select_related.append(related_str)
# Walk the relationships to the actual model class
# against which the comparison is being made.
for part in related_parts:
model_cls = model_cls.__model_fields__[part].to
column = model_cls.__table__.columns[field_name]
else:
op = "exact"
column = self.table.columns[key]
# Map the operation code onto SQLAlchemy's ColumnElement
# https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement
op_attr = FILTER_OPERATORS[op]
has_escaped_character = False
if op in ["contains", "icontains"]:
has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS
if c in value)
if has_escaped_character:
# enable escape modifier
for char in self.ESCAPE_CHARACTERS:
value = value.replace(char, f'\\{char}')
value = f"%{value}%"
if isinstance(value, orm.Model):
value = value.pk
clause = getattr(column, op_attr)(value)
clause.modifiers['escape'] = '\\' if has_escaped_character else None
filter_clauses.append(clause)
return self.__class__(
model_cls=self.model_cls,
filter_clauses=filter_clauses,
select_related=select_related,
limit_count=self.limit_count,
offset=self.query_offset
)
def select_related(self, related):
if not isinstance(related, (list, tuple)):
related = [related]
related = list(self._select_related) + related
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=related,
limit_count=self.limit_count,
offset=self.query_offset
)
async def exists(self) -> bool:
expr = self.build_select_expression()
expr = sqlalchemy.exists(expr).select()
return await self.database.fetch_val(expr)
async def count(self) -> int:
expr = self.build_select_expression().alias("subquery_for_count")
expr = sqlalchemy.func.count().select().select_from(expr)
return await self.database.fetch_val(expr)
def limit(self, limit_count: int):
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=self._select_related,
limit_count=limit_count,
offset=self.query_offset
)
def offset(self, offset: int):
return self.__class__(
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=self._select_related,
limit_count=self.limit_count,
offset=offset
)
async def first(self, **kwargs):
if kwargs:
return await self.filter(**kwargs).first()
rows = await self.limit(1).all()
if rows:
return rows[0]
async def get(self, **kwargs):
if kwargs:
return await self.filter(**kwargs).get()
expr = self.build_select_expression().limit(2)
rows = await self.database.fetch_all(expr)
if not rows:
raise NoMatch()
if len(rows) > 1:
raise MultipleMatches()
return self.model_cls.from_row(rows[0], select_related=self._select_related)
async def all(self, **kwargs):
if kwargs:
return await self.filter(**kwargs).all()
expr = self.build_select_expression()
rows = await self.database.fetch_all(expr)
return [
self.model_cls.from_row(row, select_related=self._select_related)
for row in rows
]
async def create(self, **kwargs):
new_kwargs = dict(**kwargs)
# Remove primary key when None to prevent not null constraint in postgresql.
pkname = self.model_cls.__pkname__
pk = self.model_cls.__model_fields__[pkname]
if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement):
del new_kwargs[pkname]
# substitute related models with their pk
for field in self.model_cls.extract_related_names():
if field in new_kwargs and new_kwargs.get(field) is not None:
new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__)
# Build the insert expression.
expr = self.table.insert()
expr = expr.values(**new_kwargs)
# Execute the insert, and return a new model instance.
instance = self.model_cls(**kwargs)
instance.pk = await self.database.execute(expr)
return instance

View File

@ -1,20 +1,49 @@
import pprint
import string
import uuid
from random import choices
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from weakref import proxy
from sqlalchemy import text
from orm.fields import ForeignKey
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from orm.models import Model from orm.models import Model
def get_table_alias():
return ''.join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
def get_relation_config(relation_type: str, table_name: str, field: ForeignKey):
alias = get_table_alias()
config = {'type': relation_type,
'table_alias': alias,
'source_table': table_name if relation_type == 'primary' else field.to.__tablename__,
'target_table': field.to.__tablename__ if relation_type == 'primary' else table_name
}
return config
class RelationshipManager: class RelationshipManager:
def __init__(self): def __init__(self):
self._relations = dict() self._relations = dict()
def add_relation_type(self, relations_key, reverse_key): def add_relation_type(self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str):
print(relations_key, reverse_key) print(relations_key, reverse_key)
if relations_key not in self._relations: if relations_key not in self._relations:
self._relations[relations_key] = {'type': 'primary'} self._relations[relations_key] = get_relation_config('primary', table_name, field)
if reverse_key not in self._relations: if reverse_key not in self._relations:
self._relations[reverse_key] = {'type': 'reverse'} self._relations[reverse_key] = get_relation_config('reverse', table_name, field)
def deregister(self, model: 'Model'):
for rel_type in self._relations.keys():
if model.__class__.__name__.lower() in rel_type.lower():
if model._orm_id in self._relations[rel_type]:
del self._relations[rel_type][model._orm_id]
def add_relation(self, parent_name: str, child_name: str, parent: 'Model', child: 'Model', virtual: bool = False): def add_relation(self, parent_name: str, child_name: str, parent: 'Model', child: 'Model', virtual: bool = False):
parent_id = parent._orm_id parent_id = parent._orm_id
@ -22,9 +51,10 @@ class RelationshipManager:
if virtual: if virtual:
child_name, parent_name = parent_name, child_name child_name, parent_name = parent_name, child_name
child_id, parent_id = parent_id, child_id child_id, parent_id = parent_id, child_id
child, parent = parent, child child, parent = parent, proxy(child)
self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append( else:
child) child = proxy(child)
self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append(child)
self._relations[child_name.title() + '_' + parent_name].setdefault(child_id, []).append(parent) self._relations[child_name.title() + '_' + parent_name].setdefault(child_id, []).append(parent)
def contains(self, relations_key: str, object: 'Model'): def contains(self, relations_key: str, object: 'Model'):
@ -40,7 +70,7 @@ class RelationshipManager:
return self._relations[relations_key][object._orm_id] return self._relations[relations_key][object._orm_id]
def __str__(self): # pragma no cover def __str__(self): # pragma no cover
return ''.join(self._relations[rel].__str__() for rel in self._relations) return pprint.pformat(self._relations, indent=4, width=1)
def __repr__(self): # pragma no cover def __repr__(self): # pragma no cover
return self.__str__() return self.__str__()

View File

@ -13,7 +13,7 @@ metadata = sqlalchemy.MetaData()
class Category(orm.Model): class Category(orm.Model):
__tablename__ = "cateries" __tablename__ = "categories"
__metadata__ = metadata __metadata__ = metadata
__database__ = database __database__ = database
@ -22,7 +22,7 @@ class Category(orm.Model):
class Item(orm.Model): class Item(orm.Model):
__tablename__ = "users" __tablename__ = "items"
__metadata__ = metadata __metadata__ = metadata
__database__ = database __database__ = database

View File

@ -19,7 +19,7 @@ class SchoolClass(orm.Model):
class Category(orm.Model): class Category(orm.Model):
__tablename__ = "cateogories" __tablename__ = "categories"
__metadata__ = metadata __metadata__ = metadata
__database__ = database __database__ = database
@ -75,3 +75,19 @@ async def test_model_multiple_instances_of_same_table_in_schema():
assert classes[0].students[0].schoolclass.name is None assert classes[0].students[0].schoolclass.name is None
await classes[0].students[0].schoolclass.load() await classes[0].students[0].schoolclass.load()
assert classes[0].students[0].schoolclass.name == 'Math' assert classes[0].students[0].schoolclass.name == 'Math'
@pytest.mark.asyncio
async def test_right_tables_join():
async with database:
class1 = await SchoolClass.objects.create(name="Math")
category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic")
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all()
assert classes[0].name == 'Math'
assert classes[0].students[0].name == 'Jane'
breakpoint()
assert classes[0].teachers[0].category.name == 'Domestic'