From 6efb56a2a0c63afcc2e610bf3153e5874f133320 Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 7 Aug 2020 05:37:10 +0200 Subject: [PATCH] changed relationshipt to wekrefs --- .coverage | Bin 53248 -> 53248 bytes orm/fields.py | 12 -- orm/models.py | 300 +++------------------------------ orm/queryset.py | 242 ++++++++++++++++++++++++++ orm/relations.py | 44 ++++- tests/test_fastapi_usage.py | 4 +- tests/test_same_table_joins.py | 18 +- 7 files changed, 324 insertions(+), 296 deletions(-) create mode 100644 orm/queryset.py diff --git a/.coverage b/.coverage index 8cfa156b16a6247a46fef9e8643a14aac1b27338..ad89c9e7f4286dd936db7dc6b3d5e305f71bfba6 100644 GIT binary patch delta 702 zcmZ9KUr19?9LMkN-0R)F+xh+GpCrL@{vrPaDmc-ET)o&#yjT(dwa z2p{f4k)nGjB!viuiJ7D5B@7#ZLM5pL?IGAkEs!;Boom!XPv?BTzwbGG4xBL`8S{}R z^bG+Bv5 zFO-YMO5)Uax+sduG18_}wwYNBUg_)6kIE&aRS*9d8;VKGKK*BWQO>mh1UZc3SKP*B z^y6*p#Aa;3TFgQMpCQP6=iqEA2)tYXeX-E0do%VeTAe(SLs@i^vgm}(I~Cm+4lKns zB96dn(CLYPnTYZkCRQjhUB8>0j!(w6LZRa}n@!3$vea(WO2%`%3fmr1Ht$zOCwEh1 z1|TbFV5`O8IT&k5jCyxGa)HExGt$*~PGaNJUifJ^a%16gmU|(wXp8Q}UBih;^HhaY zn8I33DW0A0Z==b@-QLaJiB4ZM><{$3@h)v&8eVXFy~*XR9o`_aVw%|9-JPoi#_KGA zJ`NN31H%}?Eqssfa0OrCbDYP=ID5L7dm57l#kvJR6RqNZ6O7 I+OuHYZ`F+ILI3~& delta 614 zcmY+9Ur1A77{<@x`_BHH?fc$oMpCI0Nkc&mTh2eVE~4z>9H9`}g+`!5BRAW!VZx>t z-E^VQ2MRejyC~L$RCb8S41>BbMIxFIh@Ng*6r~pvhpc9wi!Qo*p5OcM@Vp~&G!jQo z>g!~e^O(E2p}yJW(wplR9ZpJu0wm!++=2^GNptigP0;JKl_HWPQzWPtI8-84Q6zM* zX!}9v?YTD46AaZ33^fuY46?^cIh$09L_3>SYO(Dvs@!+EyT{k(@4L|#@bhScjD-Gt z)U7(O6$yd7=~oLx3yZ2X*o*}4UOih-?Rri%iF{W`fMKYDizGyMX%bfGLwG=g^d!Y( zg~a%-cEZkg9>vs+2H9v=3A?KlV^WK*v*2$Kx?q$egyx%sE~bwQ^Z( zUBB138T)0ja6>J!&AbQAceb-e_Q*MtBw9zQDH0y~`dizbnO|Eo{8Jla;g*S+@JlV5 zUe`R~=bNX#G=(B30=cQ~xUuaubZO*Wv}JjD!=H(4W#%)PtZ0#VYoX-x`N}a@XYj_; z=o(uXZ!IN|5FiU%@DoyCz$f?si|_{K;2AuE1po3ju(&4a)7m8`qA<(MuS+>o+&)KdTe diff --git a/orm/fields.py b/orm/fields.py index 58fb4da..6f9c221 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -215,15 +215,3 @@ class ForeignKey(BaseField): model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True) return model - - # def register_relationship(self): - # child_model_name = self.related_name or child.__class__.__name__.lower() + 's' - # if not child_model_name in model._orm_relationship_manager: - # model._orm_relationship_manager.add( - # Relationship(name=child_model_name, child=child, parent=model, fk_side='child')) - # model.__fields__[child_model_name] = ModelField(name=child_model_name, - # type_=Optional[child.__pydantic_model__], - # model_config=child.__pydantic_model__.__config__, - # class_validators=child.__pydantic_model__.__validators__) - # model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True) - # breakpoint() diff --git a/orm/models.py b/orm/models.py index 0d93ee1..508ba00 100644 --- a/orm/models.py +++ b/orm/models.py @@ -2,21 +2,22 @@ import copy import inspect import json import uuid -from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar +from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar, Tuple from typing import Set, Dict import pydantic import sqlalchemy from pydantic import BaseModel, BaseConfig, create_model -from orm.exceptions import ModelDefinitionError, NoMatch, MultipleMatches +import orm.queryset as qry +from orm.exceptions import ModelDefinitionError from orm.fields import BaseField, ForeignKey from orm.relations import RelationshipManager relationship_manager = RelationshipManager() -def parse_pydantic_field_from_model_fields(object_dict: dict): +def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: pydantic_fields = {field_name: ( base_field.__type__, ... if base_field.is_required else base_field.default_value @@ -26,8 +27,10 @@ def parse_pydantic_field_from_model_fields(object_dict: dict): return pydantic_fields -def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict): - pkname = None +def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename: str) -> Tuple[Optional[str], + List[sqlalchemy.Column], + Dict[str, BaseField]]: + pkname: Optional[str] = None columns: List[sqlalchemy.Column] = [] model_fields: Dict[str, BaseField] = {} @@ -39,243 +42,17 @@ def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict): pkname = field_name if isinstance(field, ForeignKey): reverse_name = field.related_name or field.to.__name__.title() + '_' + name.lower() + 's' - relationship_manager.add_relation_type(name + '_' + field.to.__name__.lower(), reverse_name) + relation_name = name + '_' + field.to.__name__.lower() + relationship_manager.add_relation_type(relation_name, reverse_name, field, tablename) columns.append(field.get_column(field_name)) return pkname, columns, model_fields -FILTER_OPERATORS = { - "exact": "__eq__", - "iexact": "ilike", - "contains": "like", - "icontains": "ilike", - "in": "in_", - "gt": "__gt__", - "gte": "__ge__", - "lt": "__lt__", - "lte": "__le__", -} +def get_pydantic_base_orm_config() -> Type[BaseConfig]: + class Config(BaseConfig): + orm_mode = True - -class QuerySet: - ESCAPE_CHARACTERS = ['%', '_'] - - def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None, - limit_count: int = None, offset: int = None): - self.model_cls = model_cls - self.filter_clauses = [] if filter_clauses is None else filter_clauses - self._select_related = [] if select_related is None else select_related - self.limit_count = limit_count - self.query_offset = offset - - def __get__(self, instance, owner): - return self.__class__(model_cls=owner) - - @property - def database(self): - return self.model_cls.__database__ - - @property - def table(self): - return self.model_cls.__table__ - - def build_select_expression(self): - tables = [self.table] - select_from = self.table - - for item in self._select_related: - model_cls = self.model_cls - select_from = self.table - for part in item.split("__"): - model_cls = model_cls.__model_fields__[part].to - select_from = sqlalchemy.sql.join(select_from, model_cls.__table__) - tables.append(model_cls.__table__) - - expr = sqlalchemy.sql.select(tables) - expr = expr.select_from(select_from) - - if self.filter_clauses: - if len(self.filter_clauses) == 1: - clause = self.filter_clauses[0] - else: - clause = sqlalchemy.sql.and_(*self.filter_clauses) - expr = expr.where(clause) - - if self.limit_count: - expr = expr.limit(self.limit_count) - - if self.query_offset: - expr = expr.offset(self.query_offset) - - # print(expr.compile(compile_kwargs={"literal_binds": True})) - return expr - - def filter(self, **kwargs): - filter_clauses = self.filter_clauses - select_related = list(self._select_related) - - if kwargs.get("pk"): - pk_name = self.model_cls.__pkname__ - kwargs[pk_name] = kwargs.pop("pk") - - for key, value in kwargs.items(): - if "__" in key: - parts = key.split("__") - - # Determine if we should treat the final part as a - # filter operator or as a related field. - if parts[-1] in FILTER_OPERATORS: - op = parts[-1] - field_name = parts[-2] - related_parts = parts[:-2] - else: - op = "exact" - field_name = parts[-1] - related_parts = parts[:-1] - - model_cls = self.model_cls - if related_parts: - # Add any implied select_related - related_str = "__".join(related_parts) - if related_str not in select_related: - select_related.append(related_str) - - # Walk the relationships to the actual model class - # against which the comparison is being made. - for part in related_parts: - model_cls = model_cls.__model_fields__[part].to - - column = model_cls.__table__.columns[field_name] - - else: - op = "exact" - column = self.table.columns[key] - - # Map the operation code onto SQLAlchemy's ColumnElement - # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement - op_attr = FILTER_OPERATORS[op] - has_escaped_character = False - - if op in ["contains", "icontains"]: - has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS - if c in value) - if has_escaped_character: - # enable escape modifier - for char in self.ESCAPE_CHARACTERS: - value = value.replace(char, f'\\{char}') - value = f"%{value}%" - - if isinstance(value, Model): - value = value.pk - - clause = getattr(column, op_attr)(value) - clause.modifiers['escape'] = '\\' if has_escaped_character else None - filter_clauses.append(clause) - - return self.__class__( - model_cls=self.model_cls, - filter_clauses=filter_clauses, - select_related=select_related, - limit_count=self.limit_count, - offset=self.query_offset - ) - - def select_related(self, related): - if not isinstance(related, (list, tuple)): - related = [related] - - related = list(self._select_related) + related - return self.__class__( - model_cls=self.model_cls, - filter_clauses=self.filter_clauses, - select_related=related, - limit_count=self.limit_count, - offset=self.query_offset - ) - - async def exists(self) -> bool: - expr = self.build_select_expression() - expr = sqlalchemy.exists(expr).select() - return await self.database.fetch_val(expr) - - async def count(self) -> int: - expr = self.build_select_expression().alias("subquery_for_count") - expr = sqlalchemy.func.count().select().select_from(expr) - return await self.database.fetch_val(expr) - - def limit(self, limit_count: int): - return self.__class__( - model_cls=self.model_cls, - filter_clauses=self.filter_clauses, - select_related=self._select_related, - limit_count=limit_count, - offset=self.query_offset - ) - - def offset(self, offset: int): - return self.__class__( - model_cls=self.model_cls, - filter_clauses=self.filter_clauses, - select_related=self._select_related, - limit_count=self.limit_count, - offset=offset - ) - - async def first(self, **kwargs): - if kwargs: - return await self.filter(**kwargs).first() - - rows = await self.limit(1).all() - if rows: - return rows[0] - - async def get(self, **kwargs): - if kwargs: - return await self.filter(**kwargs).get() - - expr = self.build_select_expression().limit(2) - rows = await self.database.fetch_all(expr) - - if not rows: - raise NoMatch() - if len(rows) > 1: - raise MultipleMatches() - return self.model_cls.from_row(rows[0], select_related=self._select_related) - - async def all(self, **kwargs): - if kwargs: - return await self.filter(**kwargs).all() - - expr = self.build_select_expression() - rows = await self.database.fetch_all(expr) - return [ - self.model_cls.from_row(row, select_related=self._select_related) - for row in rows - ] - - async def create(self, **kwargs): - - new_kwargs = dict(**kwargs) - - # Remove primary key when None to prevent not null constraint in postgresql. - pkname = self.model_cls.__pkname__ - pk = self.model_cls.__model_fields__[pkname] - if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement): - del new_kwargs[pkname] - - # substitute related models with their pk - for field in self.model_cls.extract_related_names(): - if field in new_kwargs and new_kwargs.get(field) is not None: - new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__) - - # Build the insert expression. - expr = self.table.insert() - expr = expr.values(**new_kwargs) - - # Execute the insert, and return a new model instance. - instance = self.model_cls(**kwargs) - instance.pk = await self.database.execute(expr) - return instance + return Config class ModelMetaclass(type): @@ -293,7 +70,7 @@ class ModelMetaclass(type): metadata = attrs["__metadata__"] # sqlalchemy table creation - pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs) + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs, tablename) attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns) attrs['__columns__'] = columns attrs['__pkname__'] = pkname @@ -303,8 +80,7 @@ class ModelMetaclass(type): # pydantic model creation pydantic_fields = parse_pydantic_field_from_model_fields(attrs) - config = type('Config', (BaseConfig,), {'orm_mode': True}) - pydantic_model = create_model(name, __config__=config, **pydantic_fields) + pydantic_model = create_model(name, __config__=get_pydantic_base_orm_config(), **pydantic_fields) attrs['__pydantic_fields__'] = pydantic_fields attrs['__pydantic_model__'] = pydantic_model attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__) @@ -330,21 +106,22 @@ class Model(list, metaclass=ModelMetaclass): __pydantic_model__: Type[BaseModel] __pkname__: str - objects = QuerySet() + objects = qry.QuerySet() def __init__(self, *args, **kwargs) -> None: self._orm_id: str = uuid.uuid4().hex self._orm_saved: bool = False self._orm_relationship_manager: RelationshipManager = relationship_manager - self._orm_observers: List['Model'] = [] self.values: Optional[BaseModel] = None if "pk" in kwargs: kwargs[self.__pkname__] = kwargs.pop("pk") - # breakpoint() kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()} self.values = self.__pydantic_model__(**kwargs) + def __del__(self): + self._orm_relationship_manager.deregister(self) + def __setattr__(self, key: str, value: Any) -> None: if key in self.__fields__: if self.is_conversion_to_json_needed(key) and not isinstance(value, str): @@ -378,23 +155,6 @@ class Model(list, metaclass=ModelMetaclass): def __repr__(self): # pragma no cover return self.values.__repr__() - # def attach(self, observer: 'Model'): - # if all([obs._orm_id != observer._orm_id for obs in self._orm_observers]): - # self._orm_observers.append(observer) - # - # def detach(self, observer: 'Model'): - # for ind, obs in enumerate(self._orm_observers): - # if obs._orm_id == observer._orm_id: - # del self._orm_observers[ind] - # break - # - def notify(self): - for obs in self._orm_observers: # pragma no cover - obs.orm_update(self) - - def orm_update(self, subject: 'Model') -> None: # pragma no cover - print('should be updated here') - @classmethod def from_row(cls, row, select_related: List = None) -> 'Model': item = {} @@ -412,20 +172,19 @@ class Model(list, metaclass=ModelMetaclass): if column.name not in item: item[column.name] = row[column] - # breakpoint() return cls(**item) - @classmethod - def validate(cls, value: Any) -> 'BaseModel': # pragma no cover - return cls.__pydantic_model__.validate(value=value) + # @classmethod + # def validate(cls, value: Any) -> 'BaseModel': # pragma no cover + # return cls.__pydantic_model__.validate(value=value) @classmethod def __get_validators__(cls): # pragma no cover yield cls.__pydantic_model__.validate - @classmethod - def schema(cls, by_alias: bool = True): # pragma no cover - return cls.__pydantic_model__.schema(by_alias=by_alias) + # @classmethod + # def schema(cls, by_alias: bool = True): # pragma no cover + # return cls.__pydantic_model__.schema(by_alias=by_alias) def is_conversion_to_json_needed(self, column_name: str) -> bool: return self.__model_fields__.get(column_name).__type__ == pydantic.Json @@ -460,9 +219,6 @@ class Model(list, metaclass=ModelMetaclass): for name, field in cls.__fields__.items(): if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel): related_names.add(name) - # elif field.sub_fields and any( - # [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]): - # related_names.add(name) return related_names def extract_model_db_fields(self) -> Dict: @@ -481,7 +237,6 @@ class Model(list, metaclass=ModelMetaclass): expr = expr.values(**self_fields) item_id = await self.__database__.execute(expr) setattr(self, 'pk', item_id) - self.notify() return item_id async def update(self, **kwargs: Any) -> int: @@ -494,19 +249,16 @@ class Model(list, metaclass=ModelMetaclass): expr = self.__table__.update().values(**self_fields).where( self.pk_column == getattr(self, self.__pkname__)) result = await self.__database__.execute(expr) - self.notify() return result async def delete(self) -> int: expr = self.__table__.delete() expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) result = await self.__database__.execute(expr) - self.notify() return result async def load(self) -> 'Model': expr = self.__table__.select().where(self.pk_column == self.pk) row = await self.__database__.fetch_one(expr) self.from_dict(dict(row)) - self.notify() return self diff --git a/orm/queryset.py b/orm/queryset.py new file mode 100644 index 0000000..d0d6b2c --- /dev/null +++ b/orm/queryset.py @@ -0,0 +1,242 @@ +from typing import List, TYPE_CHECKING + +import sqlalchemy + +import orm +from orm.exceptions import NoMatch, MultipleMatches + +if TYPE_CHECKING: # pragma no cover + from orm.models import Model + +FILTER_OPERATORS = { + "exact": "__eq__", + "iexact": "ilike", + "contains": "like", + "icontains": "ilike", + "in": "in_", + "gt": "__gt__", + "gte": "__ge__", + "lt": "__lt__", + "lte": "__le__", +} + + +class QuerySet: + ESCAPE_CHARACTERS = ['%', '_'] + + def __init__(self, model_cls: 'Model' = None, filter_clauses: List = None, select_related: List = None, + limit_count: int = None, offset: int = None): + self.model_cls = model_cls + self.filter_clauses = [] if filter_clauses is None else filter_clauses + self._select_related = [] if select_related is None else select_related + self.limit_count = limit_count + self.query_offset = offset + + def __get__(self, instance, owner): + return self.__class__(model_cls=owner) + + @property + def database(self): + return self.model_cls.__database__ + + @property + def table(self): + return self.model_cls.__table__ + + def build_select_expression(self): + tables = [self.table] + select_from = self.table + + for item in self._select_related: + model_cls = self.model_cls + select_from = self.table + for part in item.split("__"): + model_cls = model_cls.__model_fields__[part].to + select_from = sqlalchemy.sql.join(select_from, model_cls.__table__) + tables.append(model_cls.__table__) + + expr = sqlalchemy.sql.select(tables) + expr = expr.select_from(select_from) + + if self.filter_clauses: + if len(self.filter_clauses) == 1: + clause = self.filter_clauses[0] + else: + clause = sqlalchemy.sql.and_(*self.filter_clauses) + expr = expr.where(clause) + + if self.limit_count: + expr = expr.limit(self.limit_count) + + if self.query_offset: + expr = expr.offset(self.query_offset) + + print(expr.compile(compile_kwargs={"literal_binds": True})) + return expr + + def filter(self, **kwargs): + filter_clauses = self.filter_clauses + select_related = list(self._select_related) + + if kwargs.get("pk"): + pk_name = self.model_cls.__pkname__ + kwargs[pk_name] = kwargs.pop("pk") + + for key, value in kwargs.items(): + if "__" in key: + parts = key.split("__") + + # Determine if we should treat the final part as a + # filter operator or as a related field. + if parts[-1] in FILTER_OPERATORS: + op = parts[-1] + field_name = parts[-2] + related_parts = parts[:-2] + else: + op = "exact" + field_name = parts[-1] + related_parts = parts[:-1] + + model_cls = self.model_cls + if related_parts: + # Add any implied select_related + related_str = "__".join(related_parts) + if related_str not in select_related: + select_related.append(related_str) + + # Walk the relationships to the actual model class + # against which the comparison is being made. + for part in related_parts: + model_cls = model_cls.__model_fields__[part].to + + column = model_cls.__table__.columns[field_name] + + else: + op = "exact" + column = self.table.columns[key] + + # Map the operation code onto SQLAlchemy's ColumnElement + # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement + op_attr = FILTER_OPERATORS[op] + has_escaped_character = False + + if op in ["contains", "icontains"]: + has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS + if c in value) + if has_escaped_character: + # enable escape modifier + for char in self.ESCAPE_CHARACTERS: + value = value.replace(char, f'\\{char}') + value = f"%{value}%" + + if isinstance(value, orm.Model): + value = value.pk + + clause = getattr(column, op_attr)(value) + clause.modifiers['escape'] = '\\' if has_escaped_character else None + filter_clauses.append(clause) + + return self.__class__( + model_cls=self.model_cls, + filter_clauses=filter_clauses, + select_related=select_related, + limit_count=self.limit_count, + offset=self.query_offset + ) + + def select_related(self, related): + if not isinstance(related, (list, tuple)): + related = [related] + + related = list(self._select_related) + related + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=related, + limit_count=self.limit_count, + offset=self.query_offset + ) + + async def exists(self) -> bool: + expr = self.build_select_expression() + expr = sqlalchemy.exists(expr).select() + return await self.database.fetch_val(expr) + + async def count(self) -> int: + expr = self.build_select_expression().alias("subquery_for_count") + expr = sqlalchemy.func.count().select().select_from(expr) + return await self.database.fetch_val(expr) + + def limit(self, limit_count: int): + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + limit_count=limit_count, + offset=self.query_offset + ) + + def offset(self, offset: int): + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + limit_count=self.limit_count, + offset=offset + ) + + async def first(self, **kwargs): + if kwargs: + return await self.filter(**kwargs).first() + + rows = await self.limit(1).all() + if rows: + return rows[0] + + async def get(self, **kwargs): + if kwargs: + return await self.filter(**kwargs).get() + + expr = self.build_select_expression().limit(2) + rows = await self.database.fetch_all(expr) + + if not rows: + raise NoMatch() + if len(rows) > 1: + raise MultipleMatches() + return self.model_cls.from_row(rows[0], select_related=self._select_related) + + async def all(self, **kwargs): + if kwargs: + return await self.filter(**kwargs).all() + + expr = self.build_select_expression() + rows = await self.database.fetch_all(expr) + return [ + self.model_cls.from_row(row, select_related=self._select_related) + for row in rows + ] + + async def create(self, **kwargs): + + new_kwargs = dict(**kwargs) + + # Remove primary key when None to prevent not null constraint in postgresql. + pkname = self.model_cls.__pkname__ + pk = self.model_cls.__model_fields__[pkname] + if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement): + del new_kwargs[pkname] + + # substitute related models with their pk + for field in self.model_cls.extract_related_names(): + if field in new_kwargs and new_kwargs.get(field) is not None: + new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__) + + # Build the insert expression. + expr = self.table.insert() + expr = expr.values(**new_kwargs) + + # Execute the insert, and return a new model instance. + instance = self.model_cls(**kwargs) + instance.pk = await self.database.execute(expr) + return instance diff --git a/orm/relations.py b/orm/relations.py index 583f158..31bb520 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -1,20 +1,49 @@ +import pprint +import string +import uuid +from random import choices from typing import TYPE_CHECKING +from weakref import proxy + +from sqlalchemy import text + +from orm.fields import ForeignKey if TYPE_CHECKING: # pragma no cover from orm.models import Model +def get_table_alias(): + return ''.join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] + + +def get_relation_config(relation_type: str, table_name: str, field: ForeignKey): + alias = get_table_alias() + config = {'type': relation_type, + 'table_alias': alias, + 'source_table': table_name if relation_type == 'primary' else field.to.__tablename__, + 'target_table': field.to.__tablename__ if relation_type == 'primary' else table_name + } + return config + + class RelationshipManager: def __init__(self): self._relations = dict() - def add_relation_type(self, relations_key, reverse_key): + def add_relation_type(self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str): print(relations_key, reverse_key) if relations_key not in self._relations: - self._relations[relations_key] = {'type': 'primary'} + self._relations[relations_key] = get_relation_config('primary', table_name, field) if reverse_key not in self._relations: - self._relations[reverse_key] = {'type': 'reverse'} + self._relations[reverse_key] = get_relation_config('reverse', table_name, field) + + def deregister(self, model: 'Model'): + for rel_type in self._relations.keys(): + if model.__class__.__name__.lower() in rel_type.lower(): + if model._orm_id in self._relations[rel_type]: + del self._relations[rel_type][model._orm_id] def add_relation(self, parent_name: str, child_name: str, parent: 'Model', child: 'Model', virtual: bool = False): parent_id = parent._orm_id @@ -22,9 +51,10 @@ class RelationshipManager: if virtual: child_name, parent_name = parent_name, child_name child_id, parent_id = parent_id, child_id - child, parent = parent, child - self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append( - child) + child, parent = parent, proxy(child) + else: + child = proxy(child) + self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append(child) self._relations[child_name.title() + '_' + parent_name].setdefault(child_id, []).append(parent) def contains(self, relations_key: str, object: 'Model'): @@ -40,7 +70,7 @@ class RelationshipManager: return self._relations[relations_key][object._orm_id] def __str__(self): # pragma no cover - return ''.join(self._relations[rel].__str__() for rel in self._relations) + return pprint.pformat(self._relations, indent=4, width=1) def __repr__(self): # pragma no cover return self.__str__() diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index 00c0672..8889064 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -13,7 +13,7 @@ metadata = sqlalchemy.MetaData() class Category(orm.Model): - __tablename__ = "cateries" + __tablename__ = "categories" __metadata__ = metadata __database__ = database @@ -22,7 +22,7 @@ class Category(orm.Model): class Item(orm.Model): - __tablename__ = "users" + __tablename__ = "items" __metadata__ = metadata __database__ = database diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index cadbfc8..155e6a4 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -19,7 +19,7 @@ class SchoolClass(orm.Model): class Category(orm.Model): - __tablename__ = "cateogories" + __tablename__ = "categories" __metadata__ = metadata __database__ = database @@ -75,3 +75,19 @@ async def test_model_multiple_instances_of_same_table_in_schema(): assert classes[0].students[0].schoolclass.name is None await classes[0].students[0].schoolclass.load() assert classes[0].students[0].schoolclass.name == 'Math' + + +@pytest.mark.asyncio +async def test_right_tables_join(): + async with database: + class1 = await SchoolClass.objects.create(name="Math") + category = await Category.objects.create(name="Foreign") + category2 = await Category.objects.create(name="Domestic") + await Student.objects.create(name="Jane", category=category, schoolclass=class1) + await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) + + classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() + assert classes[0].name == 'Math' + assert classes[0].students[0].name == 'Jane' + breakpoint() + assert classes[0].teachers[0].category.name == 'Domestic'