rebuild the registry of relationships
This commit is contained in:
@ -18,5 +18,5 @@ class MultipleMatches(AsyncOrmException):
|
||||
pass
|
||||
|
||||
|
||||
class RelationshipNotFound(AsyncOrmException):
|
||||
class RelationshipInstanceError(AsyncOrmException):
|
||||
pass
|
||||
|
||||
@ -2,11 +2,11 @@ import datetime
|
||||
import decimal
|
||||
from typing import Optional, List
|
||||
|
||||
import pydantic
|
||||
import sqlalchemy
|
||||
from pydantic import Json
|
||||
from pydantic.fields import ModelField
|
||||
|
||||
from orm.exceptions import ModelDefinitionError
|
||||
from orm.relations import Relationship
|
||||
from orm.exceptions import ModelDefinitionError, RelationshipInstanceError
|
||||
|
||||
|
||||
class BaseField:
|
||||
@ -79,7 +79,7 @@ class BaseField:
|
||||
def get_constraints(self) -> Optional[List]:
|
||||
return []
|
||||
|
||||
def expand_relationship(self, value, parent):
|
||||
def expand_relationship(self, value, child):
|
||||
return value
|
||||
|
||||
|
||||
@ -145,7 +145,7 @@ class Time(BaseField):
|
||||
|
||||
|
||||
class JSON(BaseField):
|
||||
__type__ = pydantic.Json
|
||||
__type__ = Json
|
||||
|
||||
def get_column_type(self):
|
||||
return sqlalchemy.JSON()
|
||||
@ -173,8 +173,9 @@ class Decimal(BaseField):
|
||||
|
||||
|
||||
class ForeignKey(BaseField):
|
||||
def __init__(self, to, related_name: str = None, nullable: bool = False):
|
||||
def __init__(self, to, related_name: str = None, nullable: bool = False, virtual: bool = False):
|
||||
super().__init__(nullable=nullable)
|
||||
self.virtual = virtual
|
||||
self.related_name = related_name
|
||||
self.to = to
|
||||
|
||||
@ -191,6 +192,9 @@ class ForeignKey(BaseField):
|
||||
return to_column.get_column_type()
|
||||
|
||||
def expand_relationship(self, value, child):
|
||||
if not isinstance(value, (self.to, dict, int, str)):
|
||||
raise RelationshipInstanceError(
|
||||
'Relationship model can be build only from orm.Model, dict and integer or string (pk).')
|
||||
if isinstance(value, self.to):
|
||||
model = value
|
||||
elif isinstance(value, dict):
|
||||
@ -199,10 +203,27 @@ class ForeignKey(BaseField):
|
||||
model = self.to(**{self.to.__pkname__: value})
|
||||
|
||||
child_model_name = self.related_name or child.__class__.__name__.lower() + 's'
|
||||
model._orm_relationship_manager.add(
|
||||
Relationship(name=child_model_name, child=child, parent=model, fk_side='child'))
|
||||
model.__fields__[child_model_name] = pydantic.fields.ModelField(name=child_model_name,
|
||||
type_=child.__pydantic_model__,
|
||||
model_config=child.__pydantic_model__.__config__,
|
||||
class_validators=child.__pydantic_model__.__validators__)
|
||||
model._orm_relationship_manager.add_relation(model.__class__.__name__.lower(),
|
||||
child.__class__.__name__.lower(),
|
||||
model, child, virtual=self.virtual)
|
||||
|
||||
if child_model_name not in model.__fields__:
|
||||
model.__fields__[child_model_name] = ModelField(name=child_model_name,
|
||||
type_=Optional[child.__pydantic_model__],
|
||||
model_config=child.__pydantic_model__.__config__,
|
||||
class_validators=child.__pydantic_model__.__validators__)
|
||||
model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True)
|
||||
|
||||
return model
|
||||
|
||||
# def register_relationship(self):
|
||||
# child_model_name = self.related_name or child.__class__.__name__.lower() + 's'
|
||||
# if not child_model_name in model._orm_relationship_manager:
|
||||
# model._orm_relationship_manager.add(
|
||||
# Relationship(name=child_model_name, child=child, parent=model, fk_side='child'))
|
||||
# model.__fields__[child_model_name] = ModelField(name=child_model_name,
|
||||
# type_=Optional[child.__pydantic_model__],
|
||||
# model_config=child.__pydantic_model__.__config__,
|
||||
# class_validators=child.__pydantic_model__.__validators__)
|
||||
# model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True)
|
||||
# breakpoint()
|
||||
|
||||
@ -2,18 +2,19 @@ import copy
|
||||
import inspect
|
||||
import json
|
||||
import uuid
|
||||
from abc import ABCMeta
|
||||
from typing import Any, List, Type
|
||||
from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar
|
||||
from typing import Set, Dict
|
||||
|
||||
import pydantic
|
||||
import sqlalchemy
|
||||
from pydantic import BaseConfig, create_model
|
||||
from pydantic import BaseModel, BaseConfig, create_model
|
||||
|
||||
from orm.exceptions import ModelDefinitionError, MultipleMatches, NoMatch
|
||||
from orm.fields import BaseField
|
||||
from orm.exceptions import ModelDefinitionError, NoMatch, MultipleMatches
|
||||
from orm.fields import BaseField, ForeignKey
|
||||
from orm.relations import RelationshipManager
|
||||
|
||||
relationship_manager = RelationshipManager()
|
||||
|
||||
|
||||
def parse_pydantic_field_from_model_fields(object_dict: dict):
|
||||
pydantic_fields = {field_name: (
|
||||
@ -25,6 +26,24 @@ def parse_pydantic_field_from_model_fields(object_dict: dict):
|
||||
return pydantic_fields
|
||||
|
||||
|
||||
def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict):
|
||||
pkname = None
|
||||
columns: List[sqlalchemy.Column] = []
|
||||
model_fields: Dict[str, BaseField] = {}
|
||||
|
||||
for field_name, field in object_dict.items():
|
||||
if isinstance(field, BaseField):
|
||||
model_fields[field_name] = field
|
||||
if not field.pydantic_only:
|
||||
if field.primary_key:
|
||||
pkname = field_name
|
||||
if isinstance(field, ForeignKey):
|
||||
reverse_name = field.related_name or field.to.__name__.title() + '_' + name.lower() + 's'
|
||||
relationship_manager.add_relation_type(name + '_' + field.to.__name__.lower(), reverse_name)
|
||||
columns.append(field.get_column(field_name))
|
||||
return pkname, columns, model_fields
|
||||
|
||||
|
||||
FILTER_OPERATORS = {
|
||||
"exact": "__eq__",
|
||||
"iexact": "ilike",
|
||||
@ -272,19 +291,9 @@ class ModelMetaclass(type):
|
||||
|
||||
tablename = attrs["__tablename__"]
|
||||
metadata = attrs["__metadata__"]
|
||||
pkname = None
|
||||
|
||||
columns = []
|
||||
model_fields = {}
|
||||
for field_name, field in attrs.items():
|
||||
if isinstance(field, BaseField):
|
||||
model_fields[field_name] = field
|
||||
if not field.pydantic_only:
|
||||
if field.primary_key:
|
||||
pkname = field_name
|
||||
columns.append(field.get_column(field_name))
|
||||
|
||||
# sqlalchemy table creation
|
||||
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs)
|
||||
attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns)
|
||||
attrs['__columns__'] = columns
|
||||
attrs['__pkname__'] = pkname
|
||||
@ -311,18 +320,28 @@ class ModelMetaclass(type):
|
||||
|
||||
|
||||
class Model(list, metaclass=ModelMetaclass):
|
||||
# Model inherits from list in order to be treated as request.Body parameter in fastapi routes,
|
||||
# inheriting from pydantic.BaseModel causes metaclass conflicts
|
||||
__abstract__ = True
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
__model_fields__: Dict[str, TypeVar[BaseField]]
|
||||
__table__: sqlalchemy.Table
|
||||
__fields__: Dict[str, pydantic.fields.ModelField]
|
||||
__pydantic_model__: Type[BaseModel]
|
||||
__pkname__: str
|
||||
|
||||
objects = QuerySet()
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
self._orm_id = uuid.uuid4().hex
|
||||
self._orm_saved = False
|
||||
self._orm_relationship_manager = RelationshipManager(self)
|
||||
self._orm_observers = []
|
||||
self._orm_id: str = uuid.uuid4().hex
|
||||
self._orm_saved: bool = False
|
||||
self._orm_relationship_manager: RelationshipManager = relationship_manager
|
||||
self._orm_observers: List['Model'] = []
|
||||
self.values: Optional[BaseModel] = None
|
||||
|
||||
if "pk" in kwargs:
|
||||
kwargs[self.__pkname__] = kwargs.pop("pk")
|
||||
# breakpoint()
|
||||
kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()}
|
||||
self.values = self.__pydantic_model__(**kwargs)
|
||||
|
||||
@ -340,9 +359,9 @@ class Model(list, metaclass=ModelMetaclass):
|
||||
|
||||
def __getattribute__(self, key: str) -> Any:
|
||||
if key != '__fields__' and key in self.__fields__:
|
||||
if key in self._orm_relationship_manager:
|
||||
parent_item = self._orm_relationship_manager.get(key)
|
||||
return parent_item
|
||||
relation_key = self.__class__.__name__.title() + '_' + key
|
||||
if self._orm_relationship_manager.contains(relation_key, self):
|
||||
return self._orm_relationship_manager.get(relation_key, self)
|
||||
|
||||
item = getattr(self.values, key, None)
|
||||
if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str):
|
||||
@ -393,11 +412,12 @@ class Model(list, metaclass=ModelMetaclass):
|
||||
if column.name not in item:
|
||||
item[column.name] = row[column]
|
||||
|
||||
# breakpoint()
|
||||
return cls(**item)
|
||||
|
||||
@classmethod
|
||||
def validate(cls: Type['Model'], value: Any) -> 'Model': # pragma no cover
|
||||
return cls.__pydantic_model__.validate(cls.__pydantic_model__.__class__, value)
|
||||
def validate(cls, value: Any) -> 'BaseModel': # pragma no cover
|
||||
return cls.__pydantic_model__.validate(value=value)
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls): # pragma no cover
|
||||
@ -405,7 +425,7 @@ class Model(list, metaclass=ModelMetaclass):
|
||||
|
||||
@classmethod
|
||||
def schema(cls, by_alias: bool = True): # pragma no cover
|
||||
return cls.__pydantic_model__.schame(cls.__pydantic_model__, by_alias=by_alias)
|
||||
return cls.__pydantic_model__.schema(by_alias=by_alias)
|
||||
|
||||
def is_conversion_to_json_needed(self, column_name: str) -> bool:
|
||||
return self.__model_fields__.get(column_name).__type__ == pydantic.Json
|
||||
|
||||
@ -1,44 +1,46 @@
|
||||
from typing import Dict, Union, List
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from orm.exceptions import RelationshipNotFound
|
||||
|
||||
|
||||
class Relationship:
|
||||
|
||||
def __init__(self, name: str, parent: 'Model', child: 'Model', fk_side: str = 'child'):
|
||||
self.fk_side = fk_side
|
||||
self.child = child
|
||||
self.parent = parent
|
||||
self.name = name
|
||||
if TYPE_CHECKING: # pragma no cover
|
||||
from orm.models import Model
|
||||
|
||||
|
||||
class RelationshipManager:
|
||||
|
||||
def __init__(self, model: 'Model'):
|
||||
self._orm_id: str = model._orm_id
|
||||
self._relations: Dict[str, Union[Relationship, List[Relationship]]] = dict()
|
||||
def __init__(self):
|
||||
self._relations = dict()
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self._relations
|
||||
def add_relation_type(self, relations_key, reverse_key):
|
||||
print(relations_key, reverse_key)
|
||||
if relations_key not in self._relations:
|
||||
self._relations[relations_key] = {'type': 'primary'}
|
||||
if reverse_key not in self._relations:
|
||||
self._relations[reverse_key] = {'type': 'reverse'}
|
||||
|
||||
def add_related(self, relation: Relationship):
|
||||
if relation.fk_side == 'child' and relation.parent._orm_id == self._orm_id:
|
||||
new_relation = Relationship(name=relation.parent.__class__.__name__.lower(),
|
||||
child=relation.parent,
|
||||
parent=relation.child,
|
||||
fk_side='parent')
|
||||
relation.child._orm_relationship_manager.add(new_relation)
|
||||
def add_relation(self, parent_name: str, child_name: str, parent: 'Model', child: 'Model', virtual: bool = False):
|
||||
parent_id = parent._orm_id
|
||||
child_id = child._orm_id
|
||||
if virtual:
|
||||
child_name, parent_name = parent_name, child_name
|
||||
child_id, parent_id = parent_id, child_id
|
||||
child, parent = parent, child
|
||||
self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append(
|
||||
child)
|
||||
self._relations[child_name.title() + '_' + parent_name].setdefault(child_id, []).append(parent)
|
||||
|
||||
def add(self, relation: Relationship):
|
||||
if relation.name in self._relations:
|
||||
self._relations[relation.name].append(relation)
|
||||
else:
|
||||
self._relations[relation.name] = [relation]
|
||||
self.add_related(relation)
|
||||
def contains(self, relations_key: str, object: 'Model'):
|
||||
if relations_key in self._relations:
|
||||
return object._orm_id in self._relations[relations_key]
|
||||
return False
|
||||
|
||||
def get(self, name: str):
|
||||
for rel, relations in self._relations.items():
|
||||
if rel == name:
|
||||
if relations and relations[0].fk_side == 'parent':
|
||||
return relations[0].child
|
||||
return [rela.child for rela in relations]
|
||||
def get(self, relations_key: str, object: 'Model'):
|
||||
if relations_key in self._relations:
|
||||
if object._orm_id in self._relations[relations_key]:
|
||||
if self._relations[relations_key]['type'] == 'primary':
|
||||
return self._relations[relations_key][object._orm_id][0]
|
||||
return self._relations[relations_key][object._orm_id]
|
||||
|
||||
def __str__(self): # pragma no cover
|
||||
return ''.join(self._relations[rel].__str__() for rel in self._relations)
|
||||
|
||||
def __repr__(self): # pragma no cover
|
||||
return self.__str__()
|
||||
|
||||
@ -1,17 +1,13 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import databases
|
||||
import pydantic
|
||||
import sqlalchemy
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
import orm
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import orm
|
||||
from orm.exceptions import NoMatch, MultipleMatches
|
||||
from orm.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
@ -229,3 +229,10 @@ async def test_get_exceptions():
|
||||
await Track.objects.create(album=fantasies, title="Test3", position=3)
|
||||
with pytest.raises(MultipleMatches):
|
||||
await Track.objects.select_related("album").get(album=fantasies)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrong_model_passed_as_fk():
|
||||
with pytest.raises(RelationshipInstanceError):
|
||||
org = await Organisation.objects.create(ident="ACME Ltd")
|
||||
await Track.objects.create(album=org, title="Test1", position=1)
|
||||
|
||||
77
tests/test_same_table_joins.py
Normal file
77
tests/test_same_table_joins.py
Normal file
@ -0,0 +1,77 @@
|
||||
import databases
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
|
||||
import orm
|
||||
from tests.settings import DATABASE_URL
|
||||
|
||||
database = databases.Database(DATABASE_URL, force_rollback=True)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
||||
class SchoolClass(orm.Model):
|
||||
__tablename__ = "schoolclasses"
|
||||
__metadata__ = metadata
|
||||
__database__ = database
|
||||
|
||||
id = orm.Integer(primary_key=True)
|
||||
name = orm.String(length=100)
|
||||
|
||||
|
||||
class Category(orm.Model):
|
||||
__tablename__ = "cateogories"
|
||||
__metadata__ = metadata
|
||||
__database__ = database
|
||||
|
||||
id = orm.Integer(primary_key=True)
|
||||
name = orm.String(length=100)
|
||||
|
||||
|
||||
class Student(orm.Model):
|
||||
__tablename__ = "students"
|
||||
__metadata__ = metadata
|
||||
__database__ = database
|
||||
|
||||
id = orm.Integer(primary_key=True)
|
||||
name = orm.String(length=100)
|
||||
schoolclass = orm.ForeignKey(SchoolClass)
|
||||
category = orm.ForeignKey(Category, nullable=True)
|
||||
|
||||
|
||||
class Teacher(orm.Model):
|
||||
__tablename__ = "teachers"
|
||||
__metadata__ = metadata
|
||||
__database__ = database
|
||||
|
||||
id = orm.Integer(primary_key=True)
|
||||
name = orm.String(length=100)
|
||||
schoolclass = orm.ForeignKey(SchoolClass)
|
||||
category = orm.ForeignKey(Category, nullable=True)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def create_test_database():
|
||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||
metadata.create_all(engine)
|
||||
yield
|
||||
metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_multiple_instances_of_same_table_in_schema():
|
||||
async with database:
|
||||
class1 = await SchoolClass.objects.create(name="Math")
|
||||
category = await Category.objects.create(name="Foreign")
|
||||
category2 = await Category.objects.create(name="Domestic")
|
||||
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
|
||||
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
|
||||
|
||||
classes = await SchoolClass.objects.select_related(['teachers', 'students']).all()
|
||||
assert classes[0].name == 'Math'
|
||||
assert classes[0].students[0].name == 'Jane'
|
||||
|
||||
# related fields of main model are only populated by pk
|
||||
# but you can load them anytime
|
||||
assert classes[0].students[0].schoolclass.name is None
|
||||
await classes[0].students[0].schoolclass.load()
|
||||
assert classes[0].students[0].schoolclass.name == 'Math'
|
||||
Reference in New Issue
Block a user