rebuild the registry of relationships

This commit is contained in:
collerek
2020-08-05 18:32:13 +02:00
parent a371c48959
commit 475dafb6c9
8 changed files with 204 additions and 81 deletions

BIN
.coverage

Binary file not shown.

View File

@ -18,5 +18,5 @@ class MultipleMatches(AsyncOrmException):
pass
class RelationshipNotFound(AsyncOrmException):
class RelationshipInstanceError(AsyncOrmException):
pass

View File

@ -2,11 +2,11 @@ import datetime
import decimal
from typing import Optional, List
import pydantic
import sqlalchemy
from pydantic import Json
from pydantic.fields import ModelField
from orm.exceptions import ModelDefinitionError
from orm.relations import Relationship
from orm.exceptions import ModelDefinitionError, RelationshipInstanceError
class BaseField:
@ -79,7 +79,7 @@ class BaseField:
def get_constraints(self) -> Optional[List]:
return []
def expand_relationship(self, value, parent):
def expand_relationship(self, value, child):
return value
@ -145,7 +145,7 @@ class Time(BaseField):
class JSON(BaseField):
__type__ = pydantic.Json
__type__ = Json
def get_column_type(self):
return sqlalchemy.JSON()
@ -173,8 +173,9 @@ class Decimal(BaseField):
class ForeignKey(BaseField):
def __init__(self, to, related_name: str = None, nullable: bool = False):
def __init__(self, to, related_name: str = None, nullable: bool = False, virtual: bool = False):
super().__init__(nullable=nullable)
self.virtual = virtual
self.related_name = related_name
self.to = to
@ -191,6 +192,9 @@ class ForeignKey(BaseField):
return to_column.get_column_type()
def expand_relationship(self, value, child):
if not isinstance(value, (self.to, dict, int, str)):
raise RelationshipInstanceError(
'Relationship model can be build only from orm.Model, dict and integer or string (pk).')
if isinstance(value, self.to):
model = value
elif isinstance(value, dict):
@ -199,10 +203,27 @@ class ForeignKey(BaseField):
model = self.to(**{self.to.__pkname__: value})
child_model_name = self.related_name or child.__class__.__name__.lower() + 's'
model._orm_relationship_manager.add(
Relationship(name=child_model_name, child=child, parent=model, fk_side='child'))
model.__fields__[child_model_name] = pydantic.fields.ModelField(name=child_model_name,
type_=child.__pydantic_model__,
model._orm_relationship_manager.add_relation(model.__class__.__name__.lower(),
child.__class__.__name__.lower(),
model, child, virtual=self.virtual)
if child_model_name not in model.__fields__:
model.__fields__[child_model_name] = ModelField(name=child_model_name,
type_=Optional[child.__pydantic_model__],
model_config=child.__pydantic_model__.__config__,
class_validators=child.__pydantic_model__.__validators__)
model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True)
return model
# def register_relationship(self):
# child_model_name = self.related_name or child.__class__.__name__.lower() + 's'
# if not child_model_name in model._orm_relationship_manager:
# model._orm_relationship_manager.add(
# Relationship(name=child_model_name, child=child, parent=model, fk_side='child'))
# model.__fields__[child_model_name] = ModelField(name=child_model_name,
# type_=Optional[child.__pydantic_model__],
# model_config=child.__pydantic_model__.__config__,
# class_validators=child.__pydantic_model__.__validators__)
# model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True)
# breakpoint()

View File

@ -2,18 +2,19 @@ import copy
import inspect
import json
import uuid
from abc import ABCMeta
from typing import Any, List, Type
from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar
from typing import Set, Dict
import pydantic
import sqlalchemy
from pydantic import BaseConfig, create_model
from pydantic import BaseModel, BaseConfig, create_model
from orm.exceptions import ModelDefinitionError, MultipleMatches, NoMatch
from orm.fields import BaseField
from orm.exceptions import ModelDefinitionError, NoMatch, MultipleMatches
from orm.fields import BaseField, ForeignKey
from orm.relations import RelationshipManager
relationship_manager = RelationshipManager()
def parse_pydantic_field_from_model_fields(object_dict: dict):
pydantic_fields = {field_name: (
@ -25,6 +26,24 @@ def parse_pydantic_field_from_model_fields(object_dict: dict):
return pydantic_fields
def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict):
pkname = None
columns: List[sqlalchemy.Column] = []
model_fields: Dict[str, BaseField] = {}
for field_name, field in object_dict.items():
if isinstance(field, BaseField):
model_fields[field_name] = field
if not field.pydantic_only:
if field.primary_key:
pkname = field_name
if isinstance(field, ForeignKey):
reverse_name = field.related_name or field.to.__name__.title() + '_' + name.lower() + 's'
relationship_manager.add_relation_type(name + '_' + field.to.__name__.lower(), reverse_name)
columns.append(field.get_column(field_name))
return pkname, columns, model_fields
FILTER_OPERATORS = {
"exact": "__eq__",
"iexact": "ilike",
@ -272,19 +291,9 @@ class ModelMetaclass(type):
tablename = attrs["__tablename__"]
metadata = attrs["__metadata__"]
pkname = None
columns = []
model_fields = {}
for field_name, field in attrs.items():
if isinstance(field, BaseField):
model_fields[field_name] = field
if not field.pydantic_only:
if field.primary_key:
pkname = field_name
columns.append(field.get_column(field_name))
# sqlalchemy table creation
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs)
attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns)
attrs['__columns__'] = columns
attrs['__pkname__'] = pkname
@ -311,18 +320,28 @@ class ModelMetaclass(type):
class Model(list, metaclass=ModelMetaclass):
# Model inherits from list in order to be treated as request.Body parameter in fastapi routes,
# inheriting from pydantic.BaseModel causes metaclass conflicts
__abstract__ = True
if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, TypeVar[BaseField]]
__table__: sqlalchemy.Table
__fields__: Dict[str, pydantic.fields.ModelField]
__pydantic_model__: Type[BaseModel]
__pkname__: str
objects = QuerySet()
def __init__(self, *args, **kwargs) -> None:
self._orm_id = uuid.uuid4().hex
self._orm_saved = False
self._orm_relationship_manager = RelationshipManager(self)
self._orm_observers = []
self._orm_id: str = uuid.uuid4().hex
self._orm_saved: bool = False
self._orm_relationship_manager: RelationshipManager = relationship_manager
self._orm_observers: List['Model'] = []
self.values: Optional[BaseModel] = None
if "pk" in kwargs:
kwargs[self.__pkname__] = kwargs.pop("pk")
# breakpoint()
kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()}
self.values = self.__pydantic_model__(**kwargs)
@ -340,9 +359,9 @@ class Model(list, metaclass=ModelMetaclass):
def __getattribute__(self, key: str) -> Any:
if key != '__fields__' and key in self.__fields__:
if key in self._orm_relationship_manager:
parent_item = self._orm_relationship_manager.get(key)
return parent_item
relation_key = self.__class__.__name__.title() + '_' + key
if self._orm_relationship_manager.contains(relation_key, self):
return self._orm_relationship_manager.get(relation_key, self)
item = getattr(self.values, key, None)
if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str):
@ -393,11 +412,12 @@ class Model(list, metaclass=ModelMetaclass):
if column.name not in item:
item[column.name] = row[column]
# breakpoint()
return cls(**item)
@classmethod
def validate(cls: Type['Model'], value: Any) -> 'Model': # pragma no cover
return cls.__pydantic_model__.validate(cls.__pydantic_model__.__class__, value)
def validate(cls, value: Any) -> 'BaseModel': # pragma no cover
return cls.__pydantic_model__.validate(value=value)
@classmethod
def __get_validators__(cls): # pragma no cover
@ -405,7 +425,7 @@ class Model(list, metaclass=ModelMetaclass):
@classmethod
def schema(cls, by_alias: bool = True): # pragma no cover
return cls.__pydantic_model__.schame(cls.__pydantic_model__, by_alias=by_alias)
return cls.__pydantic_model__.schema(by_alias=by_alias)
def is_conversion_to_json_needed(self, column_name: str) -> bool:
return self.__model_fields__.get(column_name).__type__ == pydantic.Json

View File

@ -1,44 +1,46 @@
from typing import Dict, Union, List
from typing import TYPE_CHECKING
from orm.exceptions import RelationshipNotFound
class Relationship:
def __init__(self, name: str, parent: 'Model', child: 'Model', fk_side: str = 'child'):
self.fk_side = fk_side
self.child = child
self.parent = parent
self.name = name
if TYPE_CHECKING: # pragma no cover
from orm.models import Model
class RelationshipManager:
def __init__(self, model: 'Model'):
self._orm_id: str = model._orm_id
self._relations: Dict[str, Union[Relationship, List[Relationship]]] = dict()
def __init__(self):
self._relations = dict()
def __contains__(self, item):
return item in self._relations
def add_relation_type(self, relations_key, reverse_key):
print(relations_key, reverse_key)
if relations_key not in self._relations:
self._relations[relations_key] = {'type': 'primary'}
if reverse_key not in self._relations:
self._relations[reverse_key] = {'type': 'reverse'}
def add_related(self, relation: Relationship):
if relation.fk_side == 'child' and relation.parent._orm_id == self._orm_id:
new_relation = Relationship(name=relation.parent.__class__.__name__.lower(),
child=relation.parent,
parent=relation.child,
fk_side='parent')
relation.child._orm_relationship_manager.add(new_relation)
def add_relation(self, parent_name: str, child_name: str, parent: 'Model', child: 'Model', virtual: bool = False):
parent_id = parent._orm_id
child_id = child._orm_id
if virtual:
child_name, parent_name = parent_name, child_name
child_id, parent_id = parent_id, child_id
child, parent = parent, child
self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append(
child)
self._relations[child_name.title() + '_' + parent_name].setdefault(child_id, []).append(parent)
def add(self, relation: Relationship):
if relation.name in self._relations:
self._relations[relation.name].append(relation)
else:
self._relations[relation.name] = [relation]
self.add_related(relation)
def contains(self, relations_key: str, object: 'Model'):
if relations_key in self._relations:
return object._orm_id in self._relations[relations_key]
return False
def get(self, name: str):
for rel, relations in self._relations.items():
if rel == name:
if relations and relations[0].fk_side == 'parent':
return relations[0].child
return [rela.child for rela in relations]
def get(self, relations_key: str, object: 'Model'):
if relations_key in self._relations:
if object._orm_id in self._relations[relations_key]:
if self._relations[relations_key]['type'] == 'primary':
return self._relations[relations_key][object._orm_id][0]
return self._relations[relations_key][object._orm_id]
def __str__(self): # pragma no cover
return ''.join(self._relations[rel].__str__() for rel in self._relations)
def __repr__(self): # pragma no cover
return self.__str__()

View File

@ -1,17 +1,13 @@
import json
from typing import Optional
import databases
import pydantic
import sqlalchemy
from fastapi import FastAPI
from fastapi.testclient import TestClient
app = FastAPI()
import orm
from tests.settings import DATABASE_URL
app = FastAPI()
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()

View File

@ -3,7 +3,7 @@ import pytest
import sqlalchemy
import orm
from orm.exceptions import NoMatch, MultipleMatches
from orm.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
@ -229,3 +229,10 @@ async def test_get_exceptions():
await Track.objects.create(album=fantasies, title="Test3", position=3)
with pytest.raises(MultipleMatches):
await Track.objects.select_related("album").get(album=fantasies)
@pytest.mark.asyncio
async def test_wrong_model_passed_as_fk():
with pytest.raises(RelationshipInstanceError):
org = await Organisation.objects.create(ident="ACME Ltd")
await Track.objects.create(album=org, title="Test1", position=1)

View File

@ -0,0 +1,77 @@
import databases
import pytest
import sqlalchemy
import orm
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class SchoolClass(orm.Model):
__tablename__ = "schoolclasses"
__metadata__ = metadata
__database__ = database
id = orm.Integer(primary_key=True)
name = orm.String(length=100)
class Category(orm.Model):
__tablename__ = "cateogories"
__metadata__ = metadata
__database__ = database
id = orm.Integer(primary_key=True)
name = orm.String(length=100)
class Student(orm.Model):
__tablename__ = "students"
__metadata__ = metadata
__database__ = database
id = orm.Integer(primary_key=True)
name = orm.String(length=100)
schoolclass = orm.ForeignKey(SchoolClass)
category = orm.ForeignKey(Category, nullable=True)
class Teacher(orm.Model):
__tablename__ = "teachers"
__metadata__ = metadata
__database__ = database
id = orm.Integer(primary_key=True)
name = orm.String(length=100)
schoolclass = orm.ForeignKey(SchoolClass)
category = orm.ForeignKey(Category, nullable=True)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema():
async with database:
class1 = await SchoolClass.objects.create(name="Math")
category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic")
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
classes = await SchoolClass.objects.select_related(['teachers', 'students']).all()
assert classes[0].name == 'Math'
assert classes[0].students[0].name == 'Jane'
# related fields of main model are only populated by pk
# but you can load them anytime
assert classes[0].students[0].schoolclass.name is None
await classes[0].students[0].schoolclass.load()
assert classes[0].students[0].schoolclass.name == 'Math'