diff --git a/docs/releases.md b/docs/releases.md index 9d72cbd..cb652a7 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,13 @@ +# 0.10.11 + +## ✨ Features + +* + +## 🐛 Fixes + +* Fix creation of auto through model for m2m relation with ForwardRef [#226](https://github.com/collerek/ormar/issues/226) + # 0.10.10 ## ✨ Features diff --git a/ormar/models/helpers/sqlalchemy.py b/ormar/models/helpers/sqlalchemy.py index 9330d6c..113dd98 100644 --- a/ormar/models/helpers/sqlalchemy.py +++ b/ormar/models/helpers/sqlalchemy.py @@ -2,6 +2,7 @@ import logging from typing import Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union import sqlalchemy +from pydantic.typing import ForwardRef import ormar # noqa: I100, I202 from ormar.models.descriptors import RelationDescriptor @@ -203,7 +204,7 @@ def _is_through_model_not_set(field: "BaseField") -> bool: :return: result of the check :rtype: bool """ - return field.is_multi and not field.through + return field.is_multi and not field.through and not field.to.__class__ == ForwardRef def _is_db_field(field: "BaseField") -> bool: diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 67ef315..a046552 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -447,6 +447,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass field = cast(ForeignKeyField, field) field.evaluate_forward_ref(globalns=globalns, localns=localns) field.set_self_reference_flag() + if field.is_multi and not field.through: + field = cast(ormar.ManyToManyField, field) + field.create_default_through_model() expand_reverse_relationship(model_field=field) register_relation_in_alias_manager(field=field) update_column_definition(model=cls, field=field) diff --git a/tests/test_deferred/test_forward_cross_refs.py b/tests/test_deferred/test_forward_cross_refs.py index 79dbf72..5805759 100644 --- a/tests/test_deferred/test_forward_cross_refs.py +++ b/tests/test_deferred/test_forward_cross_refs.py @@ -1,4 +1,5 @@ # type: ignore +from typing import List, Optional import databases import pytest @@ -17,10 +18,14 @@ engine = create_engine(DATABASE_URL) TeacherRef = ForwardRef("Teacher") +class BaseMeta(ormar.ModelMeta): + metadata = metadata + database = db + + class Student(ormar.Model): - class Meta(ModelMeta): - metadata = metadata - database = db + class Meta(BaseMeta): + pass id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) @@ -30,10 +35,8 @@ class Student(ormar.Model): class StudentTeacher(ormar.Model): - class Meta(ModelMeta): + class Meta(BaseMeta): tablename = "students_x_teachers" - metadata = metadata - database = db class Teacher(ormar.Model): @@ -50,6 +53,35 @@ class Teacher(ormar.Model): Student.update_forward_refs() +CityRef = ForwardRef("City") +CountryRef = ForwardRef("Country") + + +class Country(ormar.Model): + class Meta(BaseMeta): + tablename = "countries" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=128) + capital: Optional[CityRef] = ormar.ForeignKey( + CityRef, related_name="capital_city", nullable=True + ) + borders: Optional[List[CountryRef]] = ormar.ManyToMany(CountryRef) + + +class City(ormar.Model): + class Meta(BaseMeta): + tablename = "cities" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=128) + country: Country = ormar.ForeignKey( + Country, related_name="cities", skip_reverse=True + ) + + +Country.update_forward_refs() + @pytest.fixture(autouse=True, scope="module") def create_test_database(): @@ -114,3 +146,23 @@ async def test_double_relations(): assert len(quibble.own_students) == 2 assert quibble.own_students[1].name == "John" assert quibble.own_students[0].name == "Anna" + + +@pytest.mark.asyncio +async def test_auto_through_model(): + async with db: + async with db.transaction(force_rollback=True): + england = await Country(name="England").save() + france = await Country(name="France").save() + london = await City(name="London", country=england).save() + england.capital = london + await england.update() + await england.borders.add(france) + + check = await Country.objects.select_related(["capital", "borders"]).get( + name="England" + ) + assert check.name == "England" + assert check.capital.name == "London" + assert check.capital.country.pk == check.pk + assert check.borders[0] == france