rebuild the registry of relationships

This commit is contained in:
collerek
2020-08-05 18:32:13 +02:00
parent a371c48959
commit 475dafb6c9
8 changed files with 204 additions and 81 deletions

BIN
.coverage

Binary file not shown.

View File

@ -18,5 +18,5 @@ class MultipleMatches(AsyncOrmException):
pass pass
class RelationshipNotFound(AsyncOrmException): class RelationshipInstanceError(AsyncOrmException):
pass pass

View File

@ -2,11 +2,11 @@ import datetime
import decimal import decimal
from typing import Optional, List from typing import Optional, List
import pydantic
import sqlalchemy import sqlalchemy
from pydantic import Json
from pydantic.fields import ModelField
from orm.exceptions import ModelDefinitionError from orm.exceptions import ModelDefinitionError, RelationshipInstanceError
from orm.relations import Relationship
class BaseField: class BaseField:
@ -79,7 +79,7 @@ class BaseField:
def get_constraints(self) -> Optional[List]: def get_constraints(self) -> Optional[List]:
return [] return []
def expand_relationship(self, value, parent): def expand_relationship(self, value, child):
return value return value
@ -145,7 +145,7 @@ class Time(BaseField):
class JSON(BaseField): class JSON(BaseField):
__type__ = pydantic.Json __type__ = Json
def get_column_type(self): def get_column_type(self):
return sqlalchemy.JSON() return sqlalchemy.JSON()
@ -173,8 +173,9 @@ class Decimal(BaseField):
class ForeignKey(BaseField): class ForeignKey(BaseField):
def __init__(self, to, related_name: str = None, nullable: bool = False): def __init__(self, to, related_name: str = None, nullable: bool = False, virtual: bool = False):
super().__init__(nullable=nullable) super().__init__(nullable=nullable)
self.virtual = virtual
self.related_name = related_name self.related_name = related_name
self.to = to self.to = to
@ -191,6 +192,9 @@ class ForeignKey(BaseField):
return to_column.get_column_type() return to_column.get_column_type()
def expand_relationship(self, value, child): def expand_relationship(self, value, child):
if not isinstance(value, (self.to, dict, int, str)):
raise RelationshipInstanceError(
'Relationship model can be build only from orm.Model, dict and integer or string (pk).')
if isinstance(value, self.to): if isinstance(value, self.to):
model = value model = value
elif isinstance(value, dict): elif isinstance(value, dict):
@ -199,10 +203,27 @@ class ForeignKey(BaseField):
model = self.to(**{self.to.__pkname__: value}) model = self.to(**{self.to.__pkname__: value})
child_model_name = self.related_name or child.__class__.__name__.lower() + 's' child_model_name = self.related_name or child.__class__.__name__.lower() + 's'
model._orm_relationship_manager.add( model._orm_relationship_manager.add_relation(model.__class__.__name__.lower(),
Relationship(name=child_model_name, child=child, parent=model, fk_side='child')) child.__class__.__name__.lower(),
model.__fields__[child_model_name] = pydantic.fields.ModelField(name=child_model_name, model, child, virtual=self.virtual)
type_=child.__pydantic_model__,
if child_model_name not in model.__fields__:
model.__fields__[child_model_name] = ModelField(name=child_model_name,
type_=Optional[child.__pydantic_model__],
model_config=child.__pydantic_model__.__config__, model_config=child.__pydantic_model__.__config__,
class_validators=child.__pydantic_model__.__validators__) class_validators=child.__pydantic_model__.__validators__)
model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True)
return model return model
# def register_relationship(self):
# child_model_name = self.related_name or child.__class__.__name__.lower() + 's'
# if not child_model_name in model._orm_relationship_manager:
# model._orm_relationship_manager.add(
# Relationship(name=child_model_name, child=child, parent=model, fk_side='child'))
# model.__fields__[child_model_name] = ModelField(name=child_model_name,
# type_=Optional[child.__pydantic_model__],
# model_config=child.__pydantic_model__.__config__,
# class_validators=child.__pydantic_model__.__validators__)
# model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True)
# breakpoint()

View File

@ -2,18 +2,19 @@ import copy
import inspect import inspect
import json import json
import uuid import uuid
from abc import ABCMeta from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar
from typing import Any, List, Type
from typing import Set, Dict from typing import Set, Dict
import pydantic import pydantic
import sqlalchemy import sqlalchemy
from pydantic import BaseConfig, create_model from pydantic import BaseModel, BaseConfig, create_model
from orm.exceptions import ModelDefinitionError, MultipleMatches, NoMatch from orm.exceptions import ModelDefinitionError, NoMatch, MultipleMatches
from orm.fields import BaseField from orm.fields import BaseField, ForeignKey
from orm.relations import RelationshipManager from orm.relations import RelationshipManager
relationship_manager = RelationshipManager()
def parse_pydantic_field_from_model_fields(object_dict: dict): def parse_pydantic_field_from_model_fields(object_dict: dict):
pydantic_fields = {field_name: ( pydantic_fields = {field_name: (
@ -25,6 +26,24 @@ def parse_pydantic_field_from_model_fields(object_dict: dict):
return pydantic_fields return pydantic_fields
def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict):
pkname = None
columns: List[sqlalchemy.Column] = []
model_fields: Dict[str, BaseField] = {}
for field_name, field in object_dict.items():
if isinstance(field, BaseField):
model_fields[field_name] = field
if not field.pydantic_only:
if field.primary_key:
pkname = field_name
if isinstance(field, ForeignKey):
reverse_name = field.related_name or field.to.__name__.title() + '_' + name.lower() + 's'
relationship_manager.add_relation_type(name + '_' + field.to.__name__.lower(), reverse_name)
columns.append(field.get_column(field_name))
return pkname, columns, model_fields
FILTER_OPERATORS = { FILTER_OPERATORS = {
"exact": "__eq__", "exact": "__eq__",
"iexact": "ilike", "iexact": "ilike",
@ -272,19 +291,9 @@ class ModelMetaclass(type):
tablename = attrs["__tablename__"] tablename = attrs["__tablename__"]
metadata = attrs["__metadata__"] metadata = attrs["__metadata__"]
pkname = None
columns = []
model_fields = {}
for field_name, field in attrs.items():
if isinstance(field, BaseField):
model_fields[field_name] = field
if not field.pydantic_only:
if field.primary_key:
pkname = field_name
columns.append(field.get_column(field_name))
# sqlalchemy table creation # sqlalchemy table creation
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs)
attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns) attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns)
attrs['__columns__'] = columns attrs['__columns__'] = columns
attrs['__pkname__'] = pkname attrs['__pkname__'] = pkname
@ -311,18 +320,28 @@ class ModelMetaclass(type):
class Model(list, metaclass=ModelMetaclass): class Model(list, metaclass=ModelMetaclass):
# Model inherits from list in order to be treated as request.Body parameter in fastapi routes,
# inheriting from pydantic.BaseModel causes metaclass conflicts
__abstract__ = True __abstract__ = True
if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, TypeVar[BaseField]]
__table__: sqlalchemy.Table
__fields__: Dict[str, pydantic.fields.ModelField]
__pydantic_model__: Type[BaseModel]
__pkname__: str
objects = QuerySet() objects = QuerySet()
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
self._orm_id = uuid.uuid4().hex self._orm_id: str = uuid.uuid4().hex
self._orm_saved = False self._orm_saved: bool = False
self._orm_relationship_manager = RelationshipManager(self) self._orm_relationship_manager: RelationshipManager = relationship_manager
self._orm_observers = [] self._orm_observers: List['Model'] = []
self.values: Optional[BaseModel] = None
if "pk" in kwargs: if "pk" in kwargs:
kwargs[self.__pkname__] = kwargs.pop("pk") kwargs[self.__pkname__] = kwargs.pop("pk")
# breakpoint()
kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()} kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()}
self.values = self.__pydantic_model__(**kwargs) self.values = self.__pydantic_model__(**kwargs)
@ -340,9 +359,9 @@ class Model(list, metaclass=ModelMetaclass):
def __getattribute__(self, key: str) -> Any: def __getattribute__(self, key: str) -> Any:
if key != '__fields__' and key in self.__fields__: if key != '__fields__' and key in self.__fields__:
if key in self._orm_relationship_manager: relation_key = self.__class__.__name__.title() + '_' + key
parent_item = self._orm_relationship_manager.get(key) if self._orm_relationship_manager.contains(relation_key, self):
return parent_item return self._orm_relationship_manager.get(relation_key, self)
item = getattr(self.values, key, None) item = getattr(self.values, key, None)
if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str): if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str):
@ -393,11 +412,12 @@ class Model(list, metaclass=ModelMetaclass):
if column.name not in item: if column.name not in item:
item[column.name] = row[column] item[column.name] = row[column]
# breakpoint()
return cls(**item) return cls(**item)
@classmethod @classmethod
def validate(cls: Type['Model'], value: Any) -> 'Model': # pragma no cover def validate(cls, value: Any) -> 'BaseModel': # pragma no cover
return cls.__pydantic_model__.validate(cls.__pydantic_model__.__class__, value) return cls.__pydantic_model__.validate(value=value)
@classmethod @classmethod
def __get_validators__(cls): # pragma no cover def __get_validators__(cls): # pragma no cover
@ -405,7 +425,7 @@ class Model(list, metaclass=ModelMetaclass):
@classmethod @classmethod
def schema(cls, by_alias: bool = True): # pragma no cover def schema(cls, by_alias: bool = True): # pragma no cover
return cls.__pydantic_model__.schame(cls.__pydantic_model__, by_alias=by_alias) return cls.__pydantic_model__.schema(by_alias=by_alias)
def is_conversion_to_json_needed(self, column_name: str) -> bool: def is_conversion_to_json_needed(self, column_name: str) -> bool:
return self.__model_fields__.get(column_name).__type__ == pydantic.Json return self.__model_fields__.get(column_name).__type__ == pydantic.Json

View File

@ -1,44 +1,46 @@
from typing import Dict, Union, List from typing import TYPE_CHECKING
from orm.exceptions import RelationshipNotFound if TYPE_CHECKING: # pragma no cover
from orm.models import Model
class Relationship:
def __init__(self, name: str, parent: 'Model', child: 'Model', fk_side: str = 'child'):
self.fk_side = fk_side
self.child = child
self.parent = parent
self.name = name
class RelationshipManager: class RelationshipManager:
def __init__(self, model: 'Model'): def __init__(self):
self._orm_id: str = model._orm_id self._relations = dict()
self._relations: Dict[str, Union[Relationship, List[Relationship]]] = dict()
def __contains__(self, item): def add_relation_type(self, relations_key, reverse_key):
return item in self._relations print(relations_key, reverse_key)
if relations_key not in self._relations:
self._relations[relations_key] = {'type': 'primary'}
if reverse_key not in self._relations:
self._relations[reverse_key] = {'type': 'reverse'}
def add_related(self, relation: Relationship): def add_relation(self, parent_name: str, child_name: str, parent: 'Model', child: 'Model', virtual: bool = False):
if relation.fk_side == 'child' and relation.parent._orm_id == self._orm_id: parent_id = parent._orm_id
new_relation = Relationship(name=relation.parent.__class__.__name__.lower(), child_id = child._orm_id
child=relation.parent, if virtual:
parent=relation.child, child_name, parent_name = parent_name, child_name
fk_side='parent') child_id, parent_id = parent_id, child_id
relation.child._orm_relationship_manager.add(new_relation) child, parent = parent, child
self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append(
child)
self._relations[child_name.title() + '_' + parent_name].setdefault(child_id, []).append(parent)
def add(self, relation: Relationship): def contains(self, relations_key: str, object: 'Model'):
if relation.name in self._relations: if relations_key in self._relations:
self._relations[relation.name].append(relation) return object._orm_id in self._relations[relations_key]
else: return False
self._relations[relation.name] = [relation]
self.add_related(relation)
def get(self, name: str): def get(self, relations_key: str, object: 'Model'):
for rel, relations in self._relations.items(): if relations_key in self._relations:
if rel == name: if object._orm_id in self._relations[relations_key]:
if relations and relations[0].fk_side == 'parent': if self._relations[relations_key]['type'] == 'primary':
return relations[0].child return self._relations[relations_key][object._orm_id][0]
return [rela.child for rela in relations] return self._relations[relations_key][object._orm_id]
def __str__(self): # pragma no cover
return ''.join(self._relations[rel].__str__() for rel in self._relations)
def __repr__(self): # pragma no cover
return self.__str__()

View File

@ -1,17 +1,13 @@
import json
from typing import Optional
import databases import databases
import pydantic
import sqlalchemy import sqlalchemy
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
app = FastAPI()
import orm import orm
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
app = FastAPI()
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()

View File

@ -3,7 +3,7 @@ import pytest
import sqlalchemy import sqlalchemy
import orm import orm
from orm.exceptions import NoMatch, MultipleMatches from orm.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
@ -229,3 +229,10 @@ async def test_get_exceptions():
await Track.objects.create(album=fantasies, title="Test3", position=3) await Track.objects.create(album=fantasies, title="Test3", position=3)
with pytest.raises(MultipleMatches): with pytest.raises(MultipleMatches):
await Track.objects.select_related("album").get(album=fantasies) await Track.objects.select_related("album").get(album=fantasies)
@pytest.mark.asyncio
async def test_wrong_model_passed_as_fk():
with pytest.raises(RelationshipInstanceError):
org = await Organisation.objects.create(ident="ACME Ltd")
await Track.objects.create(album=org, title="Test1", position=1)

View File

@ -0,0 +1,77 @@
import databases
import pytest
import sqlalchemy
import orm
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class SchoolClass(orm.Model):
__tablename__ = "schoolclasses"
__metadata__ = metadata
__database__ = database
id = orm.Integer(primary_key=True)
name = orm.String(length=100)
class Category(orm.Model):
__tablename__ = "cateogories"
__metadata__ = metadata
__database__ = database
id = orm.Integer(primary_key=True)
name = orm.String(length=100)
class Student(orm.Model):
__tablename__ = "students"
__metadata__ = metadata
__database__ = database
id = orm.Integer(primary_key=True)
name = orm.String(length=100)
schoolclass = orm.ForeignKey(SchoolClass)
category = orm.ForeignKey(Category, nullable=True)
class Teacher(orm.Model):
__tablename__ = "teachers"
__metadata__ = metadata
__database__ = database
id = orm.Integer(primary_key=True)
name = orm.String(length=100)
schoolclass = orm.ForeignKey(SchoolClass)
category = orm.ForeignKey(Category, nullable=True)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema():
async with database:
class1 = await SchoolClass.objects.create(name="Math")
category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic")
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
classes = await SchoolClass.objects.select_related(['teachers', 'students']).all()
assert classes[0].name == 'Math'
assert classes[0].students[0].name == 'Jane'
# related fields of main model are only populated by pk
# but you can load them anytime
assert classes[0].students[0].schoolclass.name is None
await classes[0].students[0].schoolclass.load()
assert classes[0].students[0].schoolclass.name == 'Math'