From 475dafb6c9776c8e01cd8d2b42a31c09bc57a602 Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 5 Aug 2020 18:32:13 +0200 Subject: [PATCH] rebuild the registry of relationships --- .coverage | Bin 53248 -> 53248 bytes orm/exceptions.py | 2 +- orm/fields.py | 45 ++++++++++++++----- orm/models.py | 72 +++++++++++++++++++----------- orm/relations.py | 72 +++++++++++++++--------------- tests/test_fastapi_usage.py | 8 +--- tests/test_foreign_keys.py | 9 +++- tests/test_same_table_joins.py | 77 +++++++++++++++++++++++++++++++++ 8 files changed, 204 insertions(+), 81 deletions(-) create mode 100644 tests/test_same_table_joins.py diff --git a/.coverage b/.coverage index e2e4542df14cf7fb1474b518ee080ef920a875a1..8cfa156b16a6247a46fef9e8643a14aac1b27338 100644 GIT binary patch delta 590 zcmZvaUr19?9LMkN?A~q4Ip5jCh#usM<;n<+m|L?y41#=7p?TDCb?vkq3fZS z9`xaY4^5a4#d-+IE)iL=hee>UN}N#0b|pxmbSowsvF18kuf6>~zjMCd4+l6uI^FKF+cj6)ibY{`Oh5{ca0i#s!uHu3i?N&REF&~UBXmq_u=G&L#)#0T zx_hkZ`)(#VC=tdvH6b{7V`$*k;MI|#f#K2X0qs*a5vAisB6w?;t-ThhMJ9r$Mp<6e zDiNW-M(*AOmj?w*zy=QuvIs{G+N-K0*3xq`3zYN+V?(0)FE=Ht`l3hT%gWy3h>DcGw0NzG4la2GOk0 zKz=JZk;&(CPW9)sLvF6)+b;FwOyktudNZF7nKKU^xw${Gi32m$)f=lzsB?BP^g3F~XQCHEFLUR=oe55MjP6IS&FclPu1VqL zsm^3F8z@fh730NXNi<2k^`PYU`@kO=<-RE>Xh{{=Kb(FFhi delta 473 zcmX|+Ur1A77>Cd4?412+-}jqcNQBH0v87%(MY4t^izq0jP!tR?@n+rF!h|4qaqO}n zg}zG(g^C5r2u|ij;w~11g0tAM5<4bcDAYVEi!)^5>1!9>{hpVH2i}oMX=GA*6K?jj z1zKAoZD%9lb6TW%(dQ>*N`Q`+7{+x3X@wSOnhw$qDiNJbX*IsbL=F*2xXGejwd`q( zGCgJ=2|Ay32jvDu65`dn+2xZPWl89%QnWkAyVC+v2;l~qraJwL75W;l=zV&gI>-v| zJ|S$V_w)knC)f;^JXTX_9)Xq(%F3fjTeU6|j$A*uXsAB8`5;agE1o=q^uv zqO3UDuAKBZME7OZFy>diZaz=At(k|V`~nmgC+hNCa>}Cnk;{2t&Xqc?+4q&Ty9Hw} z`LWeXmWR~!N@>fqjC9AH{Ic2EZpXfqFZ@ou6TzO zsm|8BWDdvVL?RJz@O9CV)Kp`;)-=jCTUsADGrpA2RtK!{9drIbWE*3r>b%GZsNfH_ zv56n}hR;|;4s)16hW|1OEi-W|;3vrb*DUOzh%Nkt!PNrtSilElx%eC`yBz%o>oc(- diff --git a/orm/exceptions.py b/orm/exceptions.py index 1a8c6d0..cb2100e 100644 --- a/orm/exceptions.py +++ b/orm/exceptions.py @@ -18,5 +18,5 @@ class MultipleMatches(AsyncOrmException): pass -class RelationshipNotFound(AsyncOrmException): +class RelationshipInstanceError(AsyncOrmException): pass diff --git a/orm/fields.py b/orm/fields.py index 5a28ecf..58fb4da 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -2,11 +2,11 @@ import datetime import decimal from typing import Optional, List -import pydantic import sqlalchemy +from pydantic import Json +from pydantic.fields import ModelField -from orm.exceptions import ModelDefinitionError -from orm.relations import Relationship +from orm.exceptions import ModelDefinitionError, RelationshipInstanceError class BaseField: @@ -79,7 +79,7 @@ class BaseField: def get_constraints(self) -> Optional[List]: return [] - def expand_relationship(self, value, parent): + def expand_relationship(self, value, child): return value @@ -145,7 +145,7 @@ class Time(BaseField): class JSON(BaseField): - __type__ = pydantic.Json + __type__ = Json def get_column_type(self): return sqlalchemy.JSON() @@ -173,8 +173,9 @@ class Decimal(BaseField): class ForeignKey(BaseField): - def __init__(self, to, related_name: str = None, nullable: bool = False): + def __init__(self, to, related_name: str = None, nullable: bool = False, virtual: bool = False): super().__init__(nullable=nullable) + self.virtual = virtual self.related_name = related_name self.to = to @@ -191,6 +192,9 @@ class ForeignKey(BaseField): return to_column.get_column_type() def expand_relationship(self, value, child): + if not isinstance(value, (self.to, dict, int, str)): + raise RelationshipInstanceError( + 'Relationship model can be build only from orm.Model, dict and integer or string (pk).') if isinstance(value, self.to): model = value elif isinstance(value, dict): @@ -199,10 +203,27 @@ class ForeignKey(BaseField): model = self.to(**{self.to.__pkname__: value}) child_model_name = self.related_name or child.__class__.__name__.lower() + 's' - model._orm_relationship_manager.add( - Relationship(name=child_model_name, child=child, parent=model, fk_side='child')) - model.__fields__[child_model_name] = pydantic.fields.ModelField(name=child_model_name, - type_=child.__pydantic_model__, - model_config=child.__pydantic_model__.__config__, - class_validators=child.__pydantic_model__.__validators__) + model._orm_relationship_manager.add_relation(model.__class__.__name__.lower(), + child.__class__.__name__.lower(), + model, child, virtual=self.virtual) + + if child_model_name not in model.__fields__: + model.__fields__[child_model_name] = ModelField(name=child_model_name, + type_=Optional[child.__pydantic_model__], + model_config=child.__pydantic_model__.__config__, + class_validators=child.__pydantic_model__.__validators__) + model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True) + return model + + # def register_relationship(self): + # child_model_name = self.related_name or child.__class__.__name__.lower() + 's' + # if not child_model_name in model._orm_relationship_manager: + # model._orm_relationship_manager.add( + # Relationship(name=child_model_name, child=child, parent=model, fk_side='child')) + # model.__fields__[child_model_name] = ModelField(name=child_model_name, + # type_=Optional[child.__pydantic_model__], + # model_config=child.__pydantic_model__.__config__, + # class_validators=child.__pydantic_model__.__validators__) + # model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True) + # breakpoint() diff --git a/orm/models.py b/orm/models.py index ecc0937..0d93ee1 100644 --- a/orm/models.py +++ b/orm/models.py @@ -2,18 +2,19 @@ import copy import inspect import json import uuid -from abc import ABCMeta -from typing import Any, List, Type +from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar from typing import Set, Dict import pydantic import sqlalchemy -from pydantic import BaseConfig, create_model +from pydantic import BaseModel, BaseConfig, create_model -from orm.exceptions import ModelDefinitionError, MultipleMatches, NoMatch -from orm.fields import BaseField +from orm.exceptions import ModelDefinitionError, NoMatch, MultipleMatches +from orm.fields import BaseField, ForeignKey from orm.relations import RelationshipManager +relationship_manager = RelationshipManager() + def parse_pydantic_field_from_model_fields(object_dict: dict): pydantic_fields = {field_name: ( @@ -25,6 +26,24 @@ def parse_pydantic_field_from_model_fields(object_dict: dict): return pydantic_fields +def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict): + pkname = None + columns: List[sqlalchemy.Column] = [] + model_fields: Dict[str, BaseField] = {} + + for field_name, field in object_dict.items(): + if isinstance(field, BaseField): + model_fields[field_name] = field + if not field.pydantic_only: + if field.primary_key: + pkname = field_name + if isinstance(field, ForeignKey): + reverse_name = field.related_name or field.to.__name__.title() + '_' + name.lower() + 's' + relationship_manager.add_relation_type(name + '_' + field.to.__name__.lower(), reverse_name) + columns.append(field.get_column(field_name)) + return pkname, columns, model_fields + + FILTER_OPERATORS = { "exact": "__eq__", "iexact": "ilike", @@ -272,19 +291,9 @@ class ModelMetaclass(type): tablename = attrs["__tablename__"] metadata = attrs["__metadata__"] - pkname = None - - columns = [] - model_fields = {} - for field_name, field in attrs.items(): - if isinstance(field, BaseField): - model_fields[field_name] = field - if not field.pydantic_only: - if field.primary_key: - pkname = field_name - columns.append(field.get_column(field_name)) # sqlalchemy table creation + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs) attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns) attrs['__columns__'] = columns attrs['__pkname__'] = pkname @@ -311,18 +320,28 @@ class ModelMetaclass(type): class Model(list, metaclass=ModelMetaclass): + # Model inherits from list in order to be treated as request.Body parameter in fastapi routes, + # inheriting from pydantic.BaseModel causes metaclass conflicts __abstract__ = True + if TYPE_CHECKING: # pragma no cover + __model_fields__: Dict[str, TypeVar[BaseField]] + __table__: sqlalchemy.Table + __fields__: Dict[str, pydantic.fields.ModelField] + __pydantic_model__: Type[BaseModel] + __pkname__: str objects = QuerySet() def __init__(self, *args, **kwargs) -> None: - self._orm_id = uuid.uuid4().hex - self._orm_saved = False - self._orm_relationship_manager = RelationshipManager(self) - self._orm_observers = [] + self._orm_id: str = uuid.uuid4().hex + self._orm_saved: bool = False + self._orm_relationship_manager: RelationshipManager = relationship_manager + self._orm_observers: List['Model'] = [] + self.values: Optional[BaseModel] = None if "pk" in kwargs: kwargs[self.__pkname__] = kwargs.pop("pk") + # breakpoint() kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()} self.values = self.__pydantic_model__(**kwargs) @@ -340,9 +359,9 @@ class Model(list, metaclass=ModelMetaclass): def __getattribute__(self, key: str) -> Any: if key != '__fields__' and key in self.__fields__: - if key in self._orm_relationship_manager: - parent_item = self._orm_relationship_manager.get(key) - return parent_item + relation_key = self.__class__.__name__.title() + '_' + key + if self._orm_relationship_manager.contains(relation_key, self): + return self._orm_relationship_manager.get(relation_key, self) item = getattr(self.values, key, None) if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str): @@ -393,11 +412,12 @@ class Model(list, metaclass=ModelMetaclass): if column.name not in item: item[column.name] = row[column] + # breakpoint() return cls(**item) @classmethod - def validate(cls: Type['Model'], value: Any) -> 'Model': # pragma no cover - return cls.__pydantic_model__.validate(cls.__pydantic_model__.__class__, value) + def validate(cls, value: Any) -> 'BaseModel': # pragma no cover + return cls.__pydantic_model__.validate(value=value) @classmethod def __get_validators__(cls): # pragma no cover @@ -405,7 +425,7 @@ class Model(list, metaclass=ModelMetaclass): @classmethod def schema(cls, by_alias: bool = True): # pragma no cover - return cls.__pydantic_model__.schame(cls.__pydantic_model__, by_alias=by_alias) + return cls.__pydantic_model__.schema(by_alias=by_alias) def is_conversion_to_json_needed(self, column_name: str) -> bool: return self.__model_fields__.get(column_name).__type__ == pydantic.Json diff --git a/orm/relations.py b/orm/relations.py index 8e5f3dd..583f158 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -1,44 +1,46 @@ -from typing import Dict, Union, List +from typing import TYPE_CHECKING -from orm.exceptions import RelationshipNotFound - - -class Relationship: - - def __init__(self, name: str, parent: 'Model', child: 'Model', fk_side: str = 'child'): - self.fk_side = fk_side - self.child = child - self.parent = parent - self.name = name +if TYPE_CHECKING: # pragma no cover + from orm.models import Model class RelationshipManager: - def __init__(self, model: 'Model'): - self._orm_id: str = model._orm_id - self._relations: Dict[str, Union[Relationship, List[Relationship]]] = dict() + def __init__(self): + self._relations = dict() - def __contains__(self, item): - return item in self._relations + def add_relation_type(self, relations_key, reverse_key): + print(relations_key, reverse_key) + if relations_key not in self._relations: + self._relations[relations_key] = {'type': 'primary'} + if reverse_key not in self._relations: + self._relations[reverse_key] = {'type': 'reverse'} - def add_related(self, relation: Relationship): - if relation.fk_side == 'child' and relation.parent._orm_id == self._orm_id: - new_relation = Relationship(name=relation.parent.__class__.__name__.lower(), - child=relation.parent, - parent=relation.child, - fk_side='parent') - relation.child._orm_relationship_manager.add(new_relation) + def add_relation(self, parent_name: str, child_name: str, parent: 'Model', child: 'Model', virtual: bool = False): + parent_id = parent._orm_id + child_id = child._orm_id + if virtual: + child_name, parent_name = parent_name, child_name + child_id, parent_id = parent_id, child_id + child, parent = parent, child + self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append( + child) + self._relations[child_name.title() + '_' + parent_name].setdefault(child_id, []).append(parent) - def add(self, relation: Relationship): - if relation.name in self._relations: - self._relations[relation.name].append(relation) - else: - self._relations[relation.name] = [relation] - self.add_related(relation) + def contains(self, relations_key: str, object: 'Model'): + if relations_key in self._relations: + return object._orm_id in self._relations[relations_key] + return False - def get(self, name: str): - for rel, relations in self._relations.items(): - if rel == name: - if relations and relations[0].fk_side == 'parent': - return relations[0].child - return [rela.child for rela in relations] + def get(self, relations_key: str, object: 'Model'): + if relations_key in self._relations: + if object._orm_id in self._relations[relations_key]: + if self._relations[relations_key]['type'] == 'primary': + return self._relations[relations_key][object._orm_id][0] + return self._relations[relations_key][object._orm_id] + + def __str__(self): # pragma no cover + return ''.join(self._relations[rel].__str__() for rel in self._relations) + + def __repr__(self): # pragma no cover + return self.__str__() diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index 67a4c5f..00c0672 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -1,17 +1,13 @@ -import json -from typing import Optional - import databases -import pydantic import sqlalchemy from fastapi import FastAPI from fastapi.testclient import TestClient -app = FastAPI() - import orm from tests.settings import DATABASE_URL +app = FastAPI() + database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index d3d95bf..c222cfa 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -3,7 +3,7 @@ import pytest import sqlalchemy import orm -from orm.exceptions import NoMatch, MultipleMatches +from orm.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -229,3 +229,10 @@ async def test_get_exceptions(): await Track.objects.create(album=fantasies, title="Test3", position=3) with pytest.raises(MultipleMatches): await Track.objects.select_related("album").get(album=fantasies) + + +@pytest.mark.asyncio +async def test_wrong_model_passed_as_fk(): + with pytest.raises(RelationshipInstanceError): + org = await Organisation.objects.create(ident="ACME Ltd") + await Track.objects.create(album=org, title="Test1", position=1) diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py new file mode 100644 index 0000000..cadbfc8 --- /dev/null +++ b/tests/test_same_table_joins.py @@ -0,0 +1,77 @@ +import databases +import pytest +import sqlalchemy + +import orm +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class SchoolClass(orm.Model): + __tablename__ = "schoolclasses" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Category(orm.Model): + __tablename__ = "cateogories" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Student(orm.Model): + __tablename__ = "students" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + schoolclass = orm.ForeignKey(SchoolClass) + category = orm.ForeignKey(Category, nullable=True) + + +class Teacher(orm.Model): + __tablename__ = "teachers" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + schoolclass = orm.ForeignKey(SchoolClass) + category = orm.ForeignKey(Category, nullable=True) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_model_multiple_instances_of_same_table_in_schema(): + async with database: + class1 = await SchoolClass.objects.create(name="Math") + category = await Category.objects.create(name="Foreign") + category2 = await Category.objects.create(name="Domestic") + await Student.objects.create(name="Jane", category=category, schoolclass=class1) + await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) + + classes = await SchoolClass.objects.select_related(['teachers', 'students']).all() + assert classes[0].name == 'Math' + assert classes[0].students[0].name == 'Jane' + + # related fields of main model are only populated by pk + # but you can load them anytime + assert classes[0].students[0].schoolclass.name is None + await classes[0].students[0].schoolclass.load() + assert classes[0].students[0].schoolclass.name == 'Math'