refactored reverse relation registration into the metaclass

This commit is contained in:
collerek
2020-08-11 15:27:10 +02:00
parent becb914e55
commit ace348e172
4 changed files with 65 additions and 55 deletions

BIN
.coverage

Binary file not shown.

View File

@ -51,7 +51,7 @@ class BaseField:
@property @property
def is_required(self) -> bool: def is_required(self) -> bool:
return ( return (
not self.nullable and not self.has_default and not self.is_auto_primary_key not self.nullable and not self.has_default and not self.is_auto_primary_key
) )
@property @property
@ -204,12 +204,12 @@ def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model":
class ForeignKey(BaseField): class ForeignKey(BaseField):
def __init__( def __init__(
self, self,
to: Type["Model"], to: Type["Model"],
name: str = None, name: str = None,
related_name: str = None, related_name: str = None,
nullable: bool = True, nullable: bool = True,
virtual: bool = False, virtual: bool = False,
) -> None: ) -> None:
super().__init__(nullable=nullable, name=name) super().__init__(nullable=nullable, name=name)
self.virtual = virtual self.virtual = virtual
@ -229,7 +229,7 @@ class ForeignKey(BaseField):
return to_column.get_column_type() return to_column.get_column_type()
def expand_relationship( def expand_relationship(
self, value: Any, child: "Model" self, value: Any, child: "Model"
) -> Union["Model", List["Model"]]: ) -> Union["Model", List["Model"]]:
if isinstance(value, orm.models.Model) and not isinstance(value, self.to): if isinstance(value, orm.models.Model) and not isinstance(value, self.to):
@ -261,7 +261,6 @@ class ForeignKey(BaseField):
return model return model
def add_to_relationship_registry(self, model: "Model", child: "Model") -> None: def add_to_relationship_registry(self, model: "Model", child: "Model") -> None:
child_model_name = self.related_name or child.get_name() + "s"
model._orm_relationship_manager.add_relation( model._orm_relationship_manager.add_relation(
model.__class__.__name__.lower(), model.__class__.__name__.lower(),
child.__class__.__name__.lower(), child.__class__.__name__.lower(),
@ -269,23 +268,3 @@ class ForeignKey(BaseField):
child, child,
virtual=self.virtual, virtual=self.virtual,
) )
if (
child_model_name not in model.__fields__
and child.get_name() not in model.__fields__
):
self.register_reverse_model_fields(model, child, child_model_name)
@staticmethod
def register_reverse_model_fields(
model: "Model", child: "Model", child_model_name: str
) -> None:
model.__fields__[child_model_name] = ModelField(
name=child_model_name,
type_=Optional[child.__pydantic_model__],
model_config=child.__pydantic_model__.__config__,
class_validators=child.__pydantic_model__.__validators__,
)
model.__model_fields__[child_model_name] = ForeignKey(
child.__class__, name=child_model_name, virtual=True
)

View File

@ -6,6 +6,7 @@ from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
from typing import Callable, Dict, Set from typing import Callable, Dict, Set
import databases import databases
from pydantic.fields import ModelField
import orm.queryset as qry import orm.queryset as qry
from orm.exceptions import ModelDefinitionError from orm.exceptions import ModelDefinitionError
@ -41,8 +42,35 @@ def register_relation_on_build(table_name: str, field: ForeignKey, name: str) ->
) )
def expand_reverse_relationships(model: Type["Model"]):
for field_name, model_field in model.__model_fields__.items():
if isinstance(model_field, ForeignKey):
child_model_name = model_field.related_name or model.__name__.lower() + 's'
parent_model = model_field.to
child = model
if (
child_model_name not in parent_model.__fields__
and child.get_name() not in parent_model.__fields__
):
register_reverse_model_fields(parent_model, child, child_model_name)
def register_reverse_model_fields(
model: Type["Model"], child: Type["Model"], child_model_name: str
) -> None:
model.__fields__[child_model_name] = ModelField(
name=child_model_name,
type_=Optional[child.__pydantic_model__],
model_config=child.__pydantic_model__.__config__,
class_validators=child.__pydantic_model__.__validators__,
)
model.__model_fields__[child_model_name] = ForeignKey(
child, name=child_model_name, virtual=True
)
def sqlalchemy_columns_from_model_fields( def sqlalchemy_columns_from_model_fields(
name: str, object_dict: Dict, table_name: str name: str, object_dict: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
pkname: Optional[str] = None pkname: Optional[str] = None
columns: List[sqlalchemy.Column] = [] columns: List[sqlalchemy.Column] = []
@ -100,14 +128,16 @@ class ModelMetaclass(type):
attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__) attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__)
attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__) attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__)
attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__) attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__)
attrs["__model_fields__"] = model_fields
attrs["__model_fields__"] = model_fields
attrs["_orm_relationship_manager"] = relationship_manager attrs["_orm_relationship_manager"] = relationship_manager
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
) )
expand_reverse_relationships(new_model)
return new_model return new_model
@ -168,9 +198,9 @@ class FakePydantic(list, metaclass=ModelMetaclass):
item = getattr(self.values, key, None) item = getattr(self.values, key, None)
if ( if (
item is not None item is not None
and self._is_conversion_to_json_needed(key) and self._is_conversion_to_json_needed(key)
and isinstance(item, str) and isinstance(item, str)
): ):
try: try:
item = json.loads(item) item = json.loads(item)
@ -186,7 +216,7 @@ class FakePydantic(list, metaclass=ModelMetaclass):
if self.__class__ != other.__class__: # pragma no cover if self.__class__ != other.__class__: # pragma no cover
return False return False
return self._orm_id == other._orm_id or ( return self._orm_id == other._orm_id or (
self.values is not None and other.values is not None and self.pk == other.pk self.values is not None and other.values is not None and self.pk == other.pk
) )
def __repr__(self) -> str: # pragma no cover def __repr__(self) -> str: # pragma no cover
@ -242,7 +272,7 @@ class FakePydantic(list, metaclass=ModelMetaclass):
related_names = set() related_names = set()
for name, field in cls.__fields__.items(): for name, field in cls.__fields__.items():
if inspect.isclass(field.type_) and issubclass( if inspect.isclass(field.type_) and issubclass(
field.type_, pydantic.BaseModel field.type_, pydantic.BaseModel
): ):
related_names.add(name) related_names.add(name)
return related_names return related_names
@ -274,7 +304,7 @@ class FakePydantic(list, metaclass=ModelMetaclass):
for field in one.__model_fields__.keys(): for field in one.__model_fields__.keys():
# print(field, one.dict(), other.dict()) # print(field, one.dict(), other.dict())
if isinstance(getattr(one, field), list) and not isinstance( if isinstance(getattr(one, field), list) and not isinstance(
getattr(one, field), Model getattr(one, field), Model
): ):
setattr(other, field, getattr(one, field) + getattr(other, field)) setattr(other, field, getattr(one, field) + getattr(other, field))
elif isinstance(getattr(one, field), Model): elif isinstance(getattr(one, field), Model):
@ -296,10 +326,10 @@ class Model(FakePydantic):
@classmethod @classmethod
def from_row( def from_row(
cls, cls,
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
select_related: List = None, select_related: List = None,
previous_table: str = None, previous_table: str = None,
) -> "Model": ) -> "Model":
item = {} item = {}
@ -357,8 +387,8 @@ class Model(FakePydantic):
self_fields.pop(self.__pkname__) self_fields.pop(self.__pkname__)
expr = ( expr = (
self.__table__.update() self.__table__.update()
.values(**self_fields) .values(**self_fields)
.where(self.pk_column == getattr(self, self.__pkname__)) .where(self.pk_column == getattr(self, self.__pkname__))
) )
result = await self.__database__.execute(expr) result = await self.__database__.execute(expr)
return result return result

View File

@ -1,3 +1,5 @@
import asyncio
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
@ -59,16 +61,17 @@ class Teacher(orm.Model):
category = orm.ForeignKey(Category, nullable=True) category = orm.ForeignKey(Category, nullable=True)
@pytest.fixture(scope='module')
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
def create_test_database(): async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL) engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine) metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.fixture()
async def init_relation():
department = await Department.objects.create(id=1, name="Math Department") department = await Department.objects.create(id=1, name="Math Department")
class1 = await SchoolClass.objects.create(name="Math", department=department) class1 = await SchoolClass.objects.create(name="Math", department=department)
category = await Category.objects.create(name="Foreign") category = await Category.objects.create(name="Foreign")
@ -77,13 +80,11 @@ async def init_relation():
await Student.objects.create(name="Jack", category=category2, schoolclass=class1) await Student.objects.create(name="Jack", category=category2, schoolclass=class1)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
yield yield
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine) metadata.drop_all(engine)
metadata.create_all(engine)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema(init_relation): async def test_model_multiple_instances_of_same_table_in_schema():
async with database: async with database:
classes = await SchoolClass.objects.select_related( classes = await SchoolClass.objects.select_related(
["teachers__category", "students"] ["teachers__category", "students"]
@ -102,7 +103,7 @@ async def test_model_multiple_instances_of_same_table_in_schema(init_relation):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_right_tables_join(init_relation): async def test_right_tables_join():
async with database: async with database:
classes = await SchoolClass.objects.select_related( classes = await SchoolClass.objects.select_related(
["teachers__category", "students"] ["teachers__category", "students"]
@ -115,7 +116,7 @@ async def test_right_tables_join(init_relation):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_reverse_related_objects(init_relation): async def test_multiple_reverse_related_objects():
async with database: async with database:
classes = await SchoolClass.objects.select_related( classes = await SchoolClass.objects.select_related(
["teachers__category", "students"] ["teachers__category", "students"]