refactored reverse relation registration into the metaclass

This commit is contained in:
collerek
2020-08-11 15:27:10 +02:00
parent becb914e55
commit ace348e172
4 changed files with 65 additions and 55 deletions

BIN
.coverage

Binary file not shown.

View File

@ -261,7 +261,6 @@ class ForeignKey(BaseField):
return model return model
def add_to_relationship_registry(self, model: "Model", child: "Model") -> None: def add_to_relationship_registry(self, model: "Model", child: "Model") -> None:
child_model_name = self.related_name or child.get_name() + "s"
model._orm_relationship_manager.add_relation( model._orm_relationship_manager.add_relation(
model.__class__.__name__.lower(), model.__class__.__name__.lower(),
child.__class__.__name__.lower(), child.__class__.__name__.lower(),
@ -269,23 +268,3 @@ class ForeignKey(BaseField):
child, child,
virtual=self.virtual, virtual=self.virtual,
) )
if (
child_model_name not in model.__fields__
and child.get_name() not in model.__fields__
):
self.register_reverse_model_fields(model, child, child_model_name)
@staticmethod
def register_reverse_model_fields(
model: "Model", child: "Model", child_model_name: str
) -> None:
model.__fields__[child_model_name] = ModelField(
name=child_model_name,
type_=Optional[child.__pydantic_model__],
model_config=child.__pydantic_model__.__config__,
class_validators=child.__pydantic_model__.__validators__,
)
model.__model_fields__[child_model_name] = ForeignKey(
child.__class__, name=child_model_name, virtual=True
)

View File

@ -6,6 +6,7 @@ from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
from typing import Callable, Dict, Set from typing import Callable, Dict, Set
import databases import databases
from pydantic.fields import ModelField
import orm.queryset as qry import orm.queryset as qry
from orm.exceptions import ModelDefinitionError from orm.exceptions import ModelDefinitionError
@ -41,6 +42,33 @@ def register_relation_on_build(table_name: str, field: ForeignKey, name: str) ->
) )
def expand_reverse_relationships(model: Type["Model"]):
for field_name, model_field in model.__model_fields__.items():
if isinstance(model_field, ForeignKey):
child_model_name = model_field.related_name or model.__name__.lower() + 's'
parent_model = model_field.to
child = model
if (
child_model_name not in parent_model.__fields__
and child.get_name() not in parent_model.__fields__
):
register_reverse_model_fields(parent_model, child, child_model_name)
def register_reverse_model_fields(
model: Type["Model"], child: Type["Model"], child_model_name: str
) -> None:
model.__fields__[child_model_name] = ModelField(
name=child_model_name,
type_=Optional[child.__pydantic_model__],
model_config=child.__pydantic_model__.__config__,
class_validators=child.__pydantic_model__.__validators__,
)
model.__model_fields__[child_model_name] = ForeignKey(
child, name=child_model_name, virtual=True
)
def sqlalchemy_columns_from_model_fields( def sqlalchemy_columns_from_model_fields(
name: str, object_dict: Dict, table_name: str name: str, object_dict: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
@ -100,14 +128,16 @@ class ModelMetaclass(type):
attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__) attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__)
attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__) attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__)
attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__) attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__)
attrs["__model_fields__"] = model_fields
attrs["__model_fields__"] = model_fields
attrs["_orm_relationship_manager"] = relationship_manager attrs["_orm_relationship_manager"] = relationship_manager
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
) )
expand_reverse_relationships(new_model)
return new_model return new_model

View File

@ -1,3 +1,5 @@
import asyncio
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
@ -59,16 +61,17 @@ class Teacher(orm.Model):
category = orm.ForeignKey(Category, nullable=True) category = orm.ForeignKey(Category, nullable=True)
@pytest.fixture(scope='module')
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
def create_test_database(): async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL) engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine) metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.fixture()
async def init_relation():
department = await Department.objects.create(id=1, name="Math Department") department = await Department.objects.create(id=1, name="Math Department")
class1 = await SchoolClass.objects.create(name="Math", department=department) class1 = await SchoolClass.objects.create(name="Math", department=department)
category = await Category.objects.create(name="Foreign") category = await Category.objects.create(name="Foreign")
@ -77,13 +80,11 @@ async def init_relation():
await Student.objects.create(name="Jack", category=category2, schoolclass=class1) await Student.objects.create(name="Jack", category=category2, schoolclass=class1)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
yield yield
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine) metadata.drop_all(engine)
metadata.create_all(engine)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema(init_relation): async def test_model_multiple_instances_of_same_table_in_schema():
async with database: async with database:
classes = await SchoolClass.objects.select_related( classes = await SchoolClass.objects.select_related(
["teachers__category", "students"] ["teachers__category", "students"]
@ -102,7 +103,7 @@ async def test_model_multiple_instances_of_same_table_in_schema(init_relation):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_right_tables_join(init_relation): async def test_right_tables_join():
async with database: async with database:
classes = await SchoolClass.objects.select_related( classes = await SchoolClass.objects.select_related(
["teachers__category", "students"] ["teachers__category", "students"]
@ -115,7 +116,7 @@ async def test_right_tables_join(init_relation):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_reverse_related_objects(init_relation): async def test_multiple_reverse_related_objects():
async with database: async with database:
classes = await SchoolClass.objects.select_related( classes = await SchoolClass.objects.select_related(
["teachers__category", "students"] ["teachers__category", "students"]