all tests passes - creating dummy models if fk not nullable

This commit is contained in:
collerek
2020-08-07 15:21:37 +02:00
parent 62475a1949
commit 3929dd6d73
4 changed files with 30 additions and 12 deletions

BIN
.coverage

Binary file not shown.

View File

@ -1,14 +1,17 @@
import datetime import datetime
import decimal import decimal
from typing import Optional, List from typing import Optional, List, Type, TYPE_CHECKING
import orm
import sqlalchemy import sqlalchemy
from pydantic import Json from pydantic import Json
from pydantic.fields import ModelField from pydantic.fields import ModelField
import orm
from orm.exceptions import ModelDefinitionError, RelationshipInstanceError from orm.exceptions import ModelDefinitionError, RelationshipInstanceError
if TYPE_CHECKING: # pragma no cover
from orm.models import Model
class BaseField: class BaseField:
__type__ = None __type__ = None
@ -173,8 +176,16 @@ class Decimal(BaseField):
return sqlalchemy.DECIMAL(self.length, self.precision) return sqlalchemy.DECIMAL(self.length, self.precision)
def create_dummy_instance(fk: Type['Model'], pk: int = None):
init_dict = {fk.__pkname__: pk or -1}
init_dict = {**init_dict, **{k: create_dummy_instance(v.to)
for k, v in fk.__model_fields__.items()
if isinstance(v, ForeignKey) and not v.nullable and not v.virtual}}
return fk(**init_dict)
class ForeignKey(BaseField): class ForeignKey(BaseField):
def __init__(self, to, related_name: str = None, nullable: bool = False, virtual: bool = False): def __init__(self, to, related_name: str = None, nullable: bool = True, virtual: bool = False):
super().__init__(nullable=nullable) super().__init__(nullable=nullable)
self.virtual = virtual self.virtual = virtual
self.related_name = related_name self.related_name = related_name
@ -206,14 +217,15 @@ class ForeignKey(BaseField):
elif isinstance(value, dict): elif isinstance(value, dict):
model = self.to(**value) model = self.to(**value)
else: else:
model = self.to(**{self.to.__pkname__: value}) model = create_dummy_instance(fk=self.to, pk=value)
child_model_name = self.related_name or child.__class__.__name__.lower() + 's' child_model_name = self.related_name or child.__class__.__name__.lower() + 's'
model._orm_relationship_manager.add_relation(model.__class__.__name__.lower(), model._orm_relationship_manager.add_relation(model.__class__.__name__.lower(),
child.__class__.__name__.lower(), child.__class__.__name__.lower(),
model, child, virtual=self.virtual) model, child, virtual=self.virtual)
if child_model_name not in model.__fields__: if child_model_name not in model.__fields__ \
and child.__class__.__name__.lower() not in model.__fields__:
model.__fields__[child_model_name] = ModelField(name=child_model_name, model.__fields__[child_model_name] = ModelField(name=child_model_name,
type_=Optional[child.__pydantic_model__], type_=Optional[child.__pydantic_model__],
model_config=child.__pydantic_model__.__config__, model_config=child.__pydantic_model__.__config__,

View File

@ -57,11 +57,17 @@ class QuerySet:
f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}') f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}')
def build_select_expression(self): def build_select_expression(self):
tables = [self.table] # tables = [self.table]
columns = list(self.table.columns) columns = list(self.table.columns)
order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')] order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')]
select_from = self.table select_from = self.table
for key in self.model_cls.__model_fields__:
if not self.model_cls.__model_fields__[key].nullable \
and isinstance(self.model_cls.__model_fields__[key], orm.fields.ForeignKey) \
and key not in self._select_related:
self._select_related.append(key)
for item in self._select_related: for item in self._select_related:
previous_alias = '' previous_alias = ''
from_table = self.table.name from_table = self.table.name
@ -86,7 +92,7 @@ class QuerySet:
on_clause = self.on_clause(from_table, to_table, previous_alias, alias, to_key, from_key) on_clause = self.on_clause(from_table, to_table, previous_alias, alias, to_key, from_key)
target_table = self.prefixed_table_name(alias, to_table) target_table = self.prefixed_table_name(alias, to_table)
select_from = sqlalchemy.sql.outerjoin(select_from, target_table, on_clause) select_from = sqlalchemy.sql.outerjoin(select_from, target_table, on_clause)
tables.append(model_cls.__table__) # tables.append(model_cls.__table__)
order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}')) order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}'))
columns.extend(self.prefixed_columns(alias, model_cls.__table__)) columns.extend(self.prefixed_columns(alias, model_cls.__table__))

View File

@ -14,7 +14,7 @@ class Department(orm.Model):
__metadata__ = metadata __metadata__ = metadata
__database__ = database __database__ = database
id = orm.Integer(primary_key=True) id = orm.Integer(primary_key=True, autoincrement=False)
name = orm.String(length=100) name = orm.String(length=100)
@ -25,7 +25,7 @@ class SchoolClass(orm.Model):
id = orm.Integer(primary_key=True) id = orm.Integer(primary_key=True)
name = orm.String(length=100) name = orm.String(length=100)
department = orm.ForeignKey(Department) department = orm.ForeignKey(Department, nullable=False)
class Category(orm.Model): class Category(orm.Model):
@ -70,7 +70,7 @@ def create_test_database():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema(): async def test_model_multiple_instances_of_same_table_in_schema():
async with database: async with database:
department = await Department.objects.create(name='Math Department') department = await Department.objects.create(id=1, name='Math Department')
class1 = await SchoolClass.objects.create(name="Math", department=department) class1 = await SchoolClass.objects.create(name="Math", department=department)
category = await Category.objects.create(name="Foreign") category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic") category2 = await Category.objects.create(name="Domestic")
@ -91,7 +91,7 @@ async def test_model_multiple_instances_of_same_table_in_schema():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_right_tables_join(): async def test_right_tables_join():
async with database: async with database:
department = await Department.objects.create(name='Math Department') department = await Department.objects.create(id=1, name='Math Department')
class1 = await SchoolClass.objects.create(name="Math", department=department) class1 = await SchoolClass.objects.create(name="Math", department=department)
category = await Category.objects.create(name="Foreign") category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic") category2 = await Category.objects.create(name="Domestic")
@ -111,7 +111,7 @@ async def test_right_tables_join():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_reverse_related_objects(): async def test_multiple_reverse_related_objects():
async with database: async with database:
department = await Department.objects.create(name='Math Department') department = await Department.objects.create(id=1, name='Math Department')
class1 = await SchoolClass.objects.create(name="Math", department=department) class1 = await SchoolClass.objects.create(name="Math", department=department)
category = await Category.objects.create(name="Foreign") category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic") category2 = await Category.objects.create(name="Domestic")