diff --git a/.coverage b/.coverage index 955ea24..807412b 100644 Binary files a/.coverage and b/.coverage differ diff --git a/orm/fields.py b/orm/fields.py index 2723b82..69ba85f 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -1,14 +1,17 @@ import datetime import decimal -from typing import Optional, List +from typing import Optional, List, Type, TYPE_CHECKING -import orm import sqlalchemy from pydantic import Json from pydantic.fields import ModelField +import orm from orm.exceptions import ModelDefinitionError, RelationshipInstanceError +if TYPE_CHECKING: # pragma no cover + from orm.models import Model + class BaseField: __type__ = None @@ -173,8 +176,16 @@ class Decimal(BaseField): return sqlalchemy.DECIMAL(self.length, self.precision) +def create_dummy_instance(fk: Type['Model'], pk: int = None): + init_dict = {fk.__pkname__: pk or -1} + init_dict = {**init_dict, **{k: create_dummy_instance(v.to) + for k, v in fk.__model_fields__.items() + if isinstance(v, ForeignKey) and not v.nullable and not v.virtual}} + return fk(**init_dict) + + class ForeignKey(BaseField): - def __init__(self, to, related_name: str = None, nullable: bool = False, virtual: bool = False): + def __init__(self, to, related_name: str = None, nullable: bool = True, virtual: bool = False): super().__init__(nullable=nullable) self.virtual = virtual self.related_name = related_name @@ -206,14 +217,15 @@ class ForeignKey(BaseField): elif isinstance(value, dict): model = self.to(**value) else: - model = self.to(**{self.to.__pkname__: value}) + model = create_dummy_instance(fk=self.to, pk=value) child_model_name = self.related_name or child.__class__.__name__.lower() + 's' model._orm_relationship_manager.add_relation(model.__class__.__name__.lower(), child.__class__.__name__.lower(), model, child, virtual=self.virtual) - if child_model_name not in model.__fields__: + if child_model_name not in model.__fields__ \ + and child.__class__.__name__.lower() not in model.__fields__: model.__fields__[child_model_name] = ModelField(name=child_model_name, type_=Optional[child.__pydantic_model__], model_config=child.__pydantic_model__.__config__, diff --git a/orm/queryset.py b/orm/queryset.py index 1f01d7d..2eba123 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -57,11 +57,17 @@ class QuerySet: f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}') def build_select_expression(self): - tables = [self.table] + # tables = [self.table] columns = list(self.table.columns) order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')] select_from = self.table + for key in self.model_cls.__model_fields__: + if not self.model_cls.__model_fields__[key].nullable \ + and isinstance(self.model_cls.__model_fields__[key], orm.fields.ForeignKey) \ + and key not in self._select_related: + self._select_related.append(key) + for item in self._select_related: previous_alias = '' from_table = self.table.name @@ -86,7 +92,7 @@ class QuerySet: on_clause = self.on_clause(from_table, to_table, previous_alias, alias, to_key, from_key) target_table = self.prefixed_table_name(alias, to_table) select_from = sqlalchemy.sql.outerjoin(select_from, target_table, on_clause) - tables.append(model_cls.__table__) + # tables.append(model_cls.__table__) order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}')) columns.extend(self.prefixed_columns(alias, model_cls.__table__)) diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 30b34d5..f98e01b 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -14,7 +14,7 @@ class Department(orm.Model): __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) + id = orm.Integer(primary_key=True, autoincrement=False) name = orm.String(length=100) @@ -25,7 +25,7 @@ class SchoolClass(orm.Model): id = orm.Integer(primary_key=True) name = orm.String(length=100) - department = orm.ForeignKey(Department) + department = orm.ForeignKey(Department, nullable=False) class Category(orm.Model): @@ -70,7 +70,7 @@ def create_test_database(): @pytest.mark.asyncio async def test_model_multiple_instances_of_same_table_in_schema(): async with database: - department = await Department.objects.create(name='Math Department') + department = await Department.objects.create(id=1, name='Math Department') class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") @@ -91,7 +91,7 @@ async def test_model_multiple_instances_of_same_table_in_schema(): @pytest.mark.asyncio async def test_right_tables_join(): async with database: - department = await Department.objects.create(name='Math Department') + department = await Department.objects.create(id=1, name='Math Department') class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") @@ -111,7 +111,7 @@ async def test_right_tables_join(): @pytest.mark.asyncio async def test_multiple_reverse_related_objects(): async with database: - department = await Department.objects.create(name='Math Department') + department = await Department.objects.create(id=1, name='Math Department') class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic")