WIP changes up to join redefinition pending - use fields instead of join_params

This commit is contained in:
collerek
2021-01-10 17:27:52 +01:00
parent 4071ff7d11
commit 8b67c83d0c
16 changed files with 151 additions and 84 deletions

View File

@ -267,7 +267,7 @@ class BaseField(FieldInfo):
return value return value
@classmethod @classmethod
def set_self_reference_flag(cls): def set_self_reference_flag(cls) -> None:
""" """
Sets `self_reference` to True if field to and owner are same model. Sets `self_reference` to True if field to and owner are same model.
:return: None :return: None
@ -301,3 +301,13 @@ class BaseField(FieldInfo):
:return: None :return: None
:rtype: None :rtype: None
""" """
@classmethod
def get_related_name(cls) -> str:
"""
Returns name to use for reverse relation.
It's either set as `related_name` or by default it's owner model. get_name + 's'
:return: name of the related_name or default related name.
:rtype: str
"""
return "" # pragma: no cover

View File

@ -167,6 +167,8 @@ def ForeignKey( # noqa CFQ002
""" """
owner = kwargs.pop("owner", None) owner = kwargs.pop("owner", None)
self_reference = kwargs.pop("self_reference", False)
if isinstance(to, ForwardRef): if isinstance(to, ForwardRef):
__type__ = to if not nullable else Optional[to] __type__ = to if not nullable else Optional[to]
constraints: List = [] constraints: List = []
@ -196,6 +198,7 @@ def ForeignKey( # noqa CFQ002
onupdate=onupdate, onupdate=onupdate,
ondelete=ondelete, ondelete=ondelete,
owner=owner, owner=owner,
self_reference=self_reference,
) )
return type("ForeignKey", (ForeignKeyField, BaseField), namespace) return type("ForeignKey", (ForeignKeyField, BaseField), namespace)

View File

@ -1,8 +1,7 @@
from typing import Any, ForwardRef, List, Optional, TYPE_CHECKING, Tuple, Type, Union from typing import Any, ForwardRef, List, Optional, TYPE_CHECKING, Tuple, Type, Union
from pydantic.typing import evaluate_forwardref from pydantic.typing import evaluate_forwardref
import ormar # noqa: I100
import ormar
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
@ -71,6 +70,7 @@ def ManyToMany(
related_name = kwargs.pop("related_name", None) related_name = kwargs.pop("related_name", None)
nullable = kwargs.pop("nullable", True) nullable = kwargs.pop("nullable", True)
owner = kwargs.pop("owner", None) owner = kwargs.pop("owner", None)
self_reference = kwargs.pop("self_reference", False)
if isinstance(to, ForwardRef): if isinstance(to, ForwardRef):
__type__ = to if not nullable else Optional[to] __type__ = to if not nullable else Optional[to]
@ -96,6 +96,7 @@ def ManyToMany(
default=None, default=None,
server_default=None, server_default=None,
owner=owner, owner=owner,
self_reference=self_reference,
) )
return type("ManyToMany", (ManyToManyField, BaseField), namespace) return type("ManyToMany", (ManyToManyField, BaseField), namespace)

View File

@ -109,6 +109,7 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None:
virtual=True, virtual=True,
related_name=model_field.name, related_name=model_field.name,
owner=model_field.to, owner=model_field.to,
self_reference=model_field.self_reference,
) )
# register foreign keys on through model # register foreign keys on through model
adjust_through_many_to_many_model(model_field=model_field) adjust_through_many_to_many_model(model_field=model_field)
@ -119,12 +120,11 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None:
virtual=True, virtual=True,
related_name=model_field.name, related_name=model_field.name,
owner=model_field.to, owner=model_field.to,
self_reference=model_field.self_reference,
) )
def register_relation_in_alias_manager( def register_relation_in_alias_manager(field: Type[ForeignKeyField]) -> None:
field: Type[ForeignKeyField], field_name: str
) -> None:
""" """
Registers the relation (and reverse relation) in alias manager. Registers the relation (and reverse relation) in alias manager.
The m2m relations require registration of through model between The m2m relations require registration of through model between
@ -136,8 +136,6 @@ def register_relation_in_alias_manager(
:param field: relation field :param field: relation field
:type field: ForeignKey or ManyToManyField class :type field: ForeignKey or ManyToManyField class
:param field_name: name of the relation key
:type field_name: str
""" """
if issubclass(field, ManyToManyField): if issubclass(field, ManyToManyField):
if field.has_unresolved_forward_refs(): if field.has_unresolved_forward_refs():

View File

@ -587,9 +587,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
populate_meta_sqlalchemy_table_if_required(new_model.Meta) populate_meta_sqlalchemy_table_if_required(new_model.Meta)
expand_reverse_relationships(new_model) expand_reverse_relationships(new_model)
for field_name, field in new_model.Meta.model_fields.items(): for field_name, field in new_model.Meta.model_fields.items():
register_relation_in_alias_manager( register_relation_in_alias_manager(field=field)
field=field, field_name=field_name
)
if new_model.Meta.pkname not in attrs["__annotations__"]: if new_model.Meta.pkname not in attrs["__annotations__"]:
field_name = new_model.Meta.pkname field_name = new_model.Meta.pkname

View File

@ -1,7 +1,7 @@
from typing import Callable, Dict, List, TYPE_CHECKING, Tuple, Type from typing import Callable, Dict, List, TYPE_CHECKING, Tuple, Type
import ormar import ormar
from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.mixins.relation_mixin import RelationMixin from ormar.models.mixins.relation_mixin import RelationMixin
@ -37,10 +37,7 @@ class PrefetchQueryMixin(RelationMixin):
:rtype: Tuple[Type[Model], str] :rtype: Tuple[Type[Model], str]
""" """
if reverse: if reverse:
field_name = ( field_name = parent_model.Meta.model_fields[related].get_related_name()
parent_model.Meta.model_fields[related].related_name
or parent_model.get_name() + "s"
)
field = target_model.Meta.model_fields[field_name] field = target_model.Meta.model_fields[field_name]
if issubclass(field, ormar.fields.ManyToManyField): if issubclass(field, ormar.fields.ManyToManyField):
field_name = field.default_target_field_name() field_name = field.default_target_field_name()
@ -79,7 +76,7 @@ class PrefetchQueryMixin(RelationMixin):
return column.get_alias() if use_raw else column.name return column.get_alias() if use_raw else column.name
@classmethod @classmethod
def get_related_field_name(cls, target_field: Type["BaseField"]) -> str: def get_related_field_name(cls, target_field: Type["ForeignKeyField"]) -> str:
""" """
Returns name of the relation field that should be used in prefetch query. Returns name of the relation field that should be used in prefetch query.
This field is later used to register relation in prefetch query, This field is later used to register relation in prefetch query,
@ -93,7 +90,7 @@ class PrefetchQueryMixin(RelationMixin):
if issubclass(target_field, ormar.fields.ManyToManyField): if issubclass(target_field, ormar.fields.ManyToManyField):
return cls.get_name() return cls.get_name()
if target_field.virtual: if target_field.virtual:
return target_field.related_name or cls.get_name() + "s" return target_field.get_related_name()
return target_field.to.Meta.pkname return target_field.to.Meta.pkname
@classmethod @classmethod

View File

@ -30,7 +30,7 @@ import sqlalchemy
from pydantic import BaseModel from pydantic import BaseModel
import ormar # noqa I100 import ormar # noqa I100
from ormar.exceptions import ModelError from ormar.exceptions import ModelError, ModelPersistenceError
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.helpers import register_relation_in_alias_manager from ormar.models.helpers import register_relation_in_alias_manager
@ -452,9 +452,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
field.evaluate_forward_ref(globalns=globalns, localns=localns) field.evaluate_forward_ref(globalns=globalns, localns=localns)
field.set_self_reference_flag() field.set_self_reference_flag()
expand_reverse_relationship(model_field=field) expand_reverse_relationship(model_field=field)
register_relation_in_alias_manager( register_relation_in_alias_manager(field=field)
field=field, field_name=field_name,
)
update_column_definition(model=cls, field=field) update_column_definition(model=cls, field=field)
populate_meta_sqlalchemy_table_if_required(meta=cls.Meta) populate_meta_sqlalchemy_table_if_required(meta=cls.Meta)
super().update_forward_refs(**localns) super().update_forward_refs(**localns)
@ -731,9 +729,15 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if self.get_column_alias(k) in self.Meta.table.columns if self.get_column_alias(k) in self.Meta.table.columns
} }
for field in self._extract_db_related_names(): for field in self._extract_db_related_names():
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname relation_field = self.Meta.model_fields[field]
target_pk_name = relation_field.to.Meta.pkname
target_field = getattr(self, field) target_field = getattr(self, field)
self_fields[field] = getattr(target_field, target_pk_name, None) self_fields[field] = getattr(target_field, target_pk_name, None)
if not relation_field.nullable and not self_fields[field]:
raise ModelPersistenceError(
f"You cannot save {relation_field.to.get_name()} "
f"model without pk set!"
)
return self_fields return self_fields
def get_relation_model_id(self, target_field: Type["BaseField"]) -> Optional[int]: def get_relation_model_id(self, target_field: Type["BaseField"]) -> Optional[int]:

View File

@ -135,7 +135,7 @@ class SqlJoin:
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
): ):
_fields = join_parameters.model_cls.Meta.model_fields _fields = join_parameters.model_cls.Meta.model_fields
new_part = _fields[part].to.get_name() new_part = _fields[part].default_target_field_name()
self._switch_many_to_many_order_columns(part, new_part) self._switch_many_to_many_order_columns(part, new_part)
if index > 0: # nested joins if index > 0: # nested joins
fields, exclude_fields = SqlJoin.update_inclusions( fields, exclude_fields = SqlJoin.update_inclusions(
@ -435,11 +435,9 @@ class SqlJoin:
from_key = join_params.prev_model.get_column_alias( from_key = join_params.prev_model.get_column_alias(
join_params.prev_model.Meta.pkname join_params.prev_model.Meta.pkname
) )
breakpoint()
elif join_params.prev_model.Meta.model_fields[part].virtual: elif join_params.prev_model.Meta.model_fields[part].virtual:
to_field = ( to_field = join_params.prev_model.Meta.model_fields[part].get_related_name()
join_params.prev_model.Meta.model_fields[part].related_name
or join_params.prev_model.get_name() + "s"
)
to_key = model_cls.get_column_alias(to_field) to_key = model_cls.get_column_alias(to_field)
from_key = join_params.prev_model.get_column_alias( from_key = join_params.prev_model.get_column_alias(
join_params.prev_model.Meta.pkname join_params.prev_model.Meta.pkname

View File

@ -9,10 +9,12 @@ from typing import (
Tuple, Tuple,
Type, Type,
Union, Union,
cast,
) )
import ormar import ormar
from ormar.fields import BaseField, ManyToManyField from ormar.fields import BaseField, ManyToManyField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.queryset.clause import QueryClause from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query from ormar.queryset.query import Query
from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict
@ -314,6 +316,7 @@ class PrefetchQuery:
for related in related_to_extract: for related in related_to_extract:
target_field = model.Meta.model_fields[related] target_field = model.Meta.model_fields[related]
target_field = cast(Type[ForeignKeyField], target_field)
target_model = target_field.to.get_name() target_model = target_field.to.get_name()
model_id = model.get_relation_model_id(target_field=target_field) model_id = model.get_relation_model_id(target_field=target_field)
@ -421,6 +424,7 @@ class PrefetchQuery:
fields = target_model.get_included(fields, related) fields = target_model.get_included(fields, related)
exclude_fields = target_model.get_excluded(exclude_fields, related) exclude_fields = target_model.get_excluded(exclude_fields, related)
target_field = target_model.Meta.model_fields[related] target_field = target_model.Meta.model_fields[related]
target_field = cast(Type[ForeignKeyField], target_field)
reverse = False reverse = False
if target_field.virtual or issubclass(target_field, ManyToManyField): if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True reverse = True
@ -585,7 +589,7 @@ class PrefetchQuery:
def _populate_rows( # noqa: CFQ002 def _populate_rows( # noqa: CFQ002
self, self,
rows: List, rows: List,
target_field: Type["BaseField"], target_field: Type["ForeignKeyField"],
parent_model: Type["Model"], parent_model: Type["Model"],
table_prefix: str, table_prefix: str,
fields: Union[Set[Any], Dict[Any, Any], None], fields: Union[Set[Any], Dict[Any, Any], None],

View File

@ -197,7 +197,7 @@ class QuerySet:
limit_raw_sql=self.limit_sql_raw, limit_raw_sql=self.limit_sql_raw,
) )
exp = qry.build_select_expression() exp = qry.build_select_expression()
# print("\n", exp.compile(compile_kwargs={"literal_binds": True})) print("\n", exp.compile(compile_kwargs={"literal_binds": True}))
return exp return exp
def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003 def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003

View File

@ -39,10 +39,9 @@ class QuerysetProxy(ormar.QuerySetProtocol):
self._queryset: Optional["QuerySet"] = qryset self._queryset: Optional["QuerySet"] = qryset
self.type_: "RelationType" = type_ self.type_: "RelationType" = type_
self._owner: "Model" = self.relation.manager.owner self._owner: "Model" = self.relation.manager.owner
self.related_field_name = ( self.related_field_name = self._owner.Meta.model_fields[
self._owner.Meta.model_fields[self.relation.field_name].related_name self.relation.field_name
or self._owner.get_name() + "s" ].get_related_name()
)
self.related_field = self.relation.to.Meta.model_fields[self.related_field_name] self.related_field = self.relation.to.Meta.model_fields[self.related_field_name]
self.owner_pk_value = self._owner.pk self.owner_pk_value = self._owner.pk
@ -108,8 +107,8 @@ class QuerysetProxy(ormar.QuerySetProtocol):
:type child: Model :type child: Model
""" """
model_cls = self.relation.through model_cls = self.relation.through
owner_column = self._owner.get_name() owner_column = self.related_field.default_target_field_name()
child_column = child.get_name() child_column = self.related_field.default_source_field_name()
kwargs = {owner_column: self._owner.pk, child_column: child.pk} kwargs = {owner_column: self._owner.pk, child_column: child.pk}
if child.pk is None: if child.pk is None:
raise ModelPersistenceError( raise ModelPersistenceError(
@ -129,8 +128,8 @@ class QuerysetProxy(ormar.QuerySetProtocol):
:type child: Model :type child: Model
""" """
queryset = ormar.QuerySet(model_cls=self.relation.through) queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self._owner.get_name() owner_column = self.related_field.default_target_field_name()
child_column = child.get_name() child_column = self.related_field.default_source_field_name()
kwargs = {owner_column: self._owner, child_column: child} kwargs = {owner_column: self._owner, child_column: child}
link_instance = await queryset.filter(**kwargs).get() # type: ignore link_instance = await queryset.filter(**kwargs).get() # type: ignore
await link_instance.delete() await link_instance.delete()

View File

@ -164,8 +164,6 @@ class RelationsManager:
:param name: name of the relation :param name: name of the relation
:type name: str :type name: str
""" """
relation_name = ( relation_name = item.Meta.model_fields[name].get_related_name()
item.Meta.model_fields[name].related_name or item.get_name() + "s"
)
item._orm.remove(name, parent) item._orm.remove(name, parent)
parent._orm.remove(relation_name, item) parent._orm.remove(relation_name, item)

View File

@ -42,9 +42,8 @@ class RelationProxy(list):
if self._related_field_name: if self._related_field_name:
return self._related_field_name return self._related_field_name
owner_field = self._owner.Meta.model_fields[self.field_name] owner_field = self._owner.Meta.model_fields[self.field_name]
self._related_field_name = ( self._related_field_name = owner_field.get_related_name()
owner_field.related_name or self._owner.get_name() + "s"
)
return self._related_field_name return self._related_field_name
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:

View File

@ -35,34 +35,34 @@ Game = ForwardRef("Game")
Child = ForwardRef("Child") Child = ForwardRef("Child")
class ChildFriends(ormar.Model): class ChildFriend(ormar.Model):
class Meta(ModelMeta): class Meta(ModelMeta):
metadata = metadata metadata = metadata
database = db database = db
# class Child(ormar.Model): class Child(ormar.Model):
# class Meta(ModelMeta): class Meta(ModelMeta):
# metadata = metadata metadata = metadata
# database = db database = db
#
# id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
# name: str = ormar.String(max_length=100) name: str = ormar.String(max_length=100)
# favourite_game: Game = ormar.ForeignKey(Game, related_name="liked_by") favourite_game: Game = ormar.ForeignKey(Game, related_name="liked_by")
# least_favourite_game: Game = ormar.ForeignKey(Game, related_name="not_liked_by") least_favourite_game: Game = ormar.ForeignKey(Game, related_name="not_liked_by")
# friends: List[Child] = ormar.ManyToMany(Child, through=ChildFriends) friends = ormar.ManyToMany(Child, through=ChildFriend, related_name="also_friends")
#
#
# class Game(ormar.Model):
# class Meta(ModelMeta):
# metadata = metadata
# database = db
#
# id: int = ormar.Integer(primary_key=True)
# name: str = ormar.String(max_length=100)
# Child.update_forward_refs() class Game(ormar.Model):
class Meta(ModelMeta):
metadata = metadata
database = db
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
Child.update_forward_refs()
@pytest.fixture(autouse=True, scope="module") @pytest.fixture(autouse=True, scope="module")
@ -125,22 +125,56 @@ async def test_self_relation():
assert sam_check.employees[0].name == "Joe" assert sam_check.employees[0].name == "Joe"
# @pytest.mark.asyncio @pytest.mark.asyncio
# async def test_other_forwardref_relation(): async def test_other_forwardref_relation():
# checkers = await Game.objects.create(name="checkers") checkers = await Game.objects.create(name="checkers")
# uno = await Game(name="Uno").save() uno = await Game(name="Uno").save()
#
# await Child(name="Billy", favourite_game=uno, least_favourite_game=checkers).save() await Child(name="Billy", favourite_game=uno, least_favourite_game=checkers).save()
# await Child(name="Kate", favourite_game=checkers, least_favourite_game=uno).save() await Child(name="Kate", favourite_game=checkers, least_favourite_game=uno).save()
#
# billy_check = await Child.objects.select_related( billy_check = await Child.objects.select_related(
# ["favourite_game", "least_favourite_game"] ["favourite_game", "least_favourite_game"]
# ).get(name="Billy") ).get(name="Billy")
# assert billy_check.favourite_game == uno assert billy_check.favourite_game == uno
# assert billy_check.least_favourite_game == checkers assert billy_check.least_favourite_game == checkers
#
# uno_check = await Game.objects.select_related(["liked_by", "not_liked_by"]).get( uno_check = await Game.objects.select_related(["liked_by", "not_liked_by"]).get(
# name="Uno" name="Uno"
# ) )
# assert uno_check.liked_by[0].name == "Billy" assert uno_check.liked_by[0].name == "Billy"
# assert uno_check.not_liked_by[0].name == "Kate" assert uno_check.not_liked_by[0].name == "Kate"
@pytest.mark.asyncio
async def test_m2m_self_forwardref_relation():
checkers = await Game.objects.create(name="checkers")
uno = await Game(name="Uno").save()
jenga = await Game(name="Jenga").save()
billy = await Child(
name="Billy", favourite_game=uno, least_favourite_game=checkers
).save()
kate = await Child(
name="Kate", favourite_game=checkers, least_favourite_game=uno
).save()
steve = await Child(
name="Steve", favourite_game=jenga, least_favourite_game=uno
).save()
await billy.friends.add(kate)
await billy.friends.add(steve)
await steve.friends.add(kate)
await steve.friends.add(billy)
billy_check = await Child.objects.select_related(
[
"friends",
"favourite_game",
"least_favourite_game",
"friends__favourite_game",
"friends__least_favourite_game",
]
).get(name="Billy")
assert len(billy_check.friends) == 2

View File

@ -80,6 +80,17 @@ async def cleanup():
await Author.objects.delete(each=True) await Author.objects.delete(each=True)
@pytest.mark.asyncio
async def test_not_saved_raises_error(cleanup):
async with database:
guido = await Author(first_name="Guido", last_name="Van Rossum").save()
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = Category(name="News")
with pytest.raises(ModelPersistenceError):
await post.categories.add(news)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_assigning_related_objects(cleanup): async def test_assigning_related_objects(cleanup):
async with database: async with database:

View File

@ -6,6 +6,7 @@ import sqlalchemy as sa
from sqlalchemy import create_engine from sqlalchemy import create_engine
import ormar import ormar
from ormar.exceptions import ModelPersistenceError
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
metadata = sa.MetaData() metadata = sa.MetaData()
@ -61,3 +62,15 @@ async def test_model_relationship():
assert ws.id == 1 assert ws.id == 1
assert ws.topic == "Topic 2" assert ws.topic == "Topic 2"
assert ws.category.name == "Foo" assert ws.category.name == "Foo"
@pytest.mark.asyncio
async def test_model_relationship_with_not_saved():
async with db:
async with db.transaction(force_rollback=True):
cat = Category(name="Foo", code=123)
with pytest.raises(ModelPersistenceError):
await Workshop(topic="Topic 1", category=cat).save()
with pytest.raises(ModelPersistenceError):
await Workshop.objects.create(topic="Topic 1", category=cat)