WIP changes up to join redefinition pending - use fields instead of join_params

This commit is contained in:
collerek
2021-01-10 17:27:52 +01:00
parent 4071ff7d11
commit 8b67c83d0c
16 changed files with 151 additions and 84 deletions

View File

@ -267,7 +267,7 @@ class BaseField(FieldInfo):
return value
@classmethod
def set_self_reference_flag(cls):
def set_self_reference_flag(cls) -> None:
"""
Sets `self_reference` to True if field to and owner are same model.
:return: None
@ -301,3 +301,13 @@ class BaseField(FieldInfo):
:return: None
:rtype: None
"""
@classmethod
def get_related_name(cls) -> str:
"""
Returns name to use for reverse relation.
It's either set as `related_name` or by default it's owner model. get_name + 's'
:return: name of the related_name or default related name.
:rtype: str
"""
return "" # pragma: no cover

View File

@ -167,6 +167,8 @@ def ForeignKey( # noqa CFQ002
"""
owner = kwargs.pop("owner", None)
self_reference = kwargs.pop("self_reference", False)
if isinstance(to, ForwardRef):
__type__ = to if not nullable else Optional[to]
constraints: List = []
@ -196,6 +198,7 @@ def ForeignKey( # noqa CFQ002
onupdate=onupdate,
ondelete=ondelete,
owner=owner,
self_reference=self_reference,
)
return type("ForeignKey", (ForeignKeyField, BaseField), namespace)

View File

@ -1,8 +1,7 @@
from typing import Any, ForwardRef, List, Optional, TYPE_CHECKING, Tuple, Type, Union
from pydantic.typing import evaluate_forwardref
import ormar
import ormar # noqa: I100
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
@ -71,6 +70,7 @@ def ManyToMany(
related_name = kwargs.pop("related_name", None)
nullable = kwargs.pop("nullable", True)
owner = kwargs.pop("owner", None)
self_reference = kwargs.pop("self_reference", False)
if isinstance(to, ForwardRef):
__type__ = to if not nullable else Optional[to]
@ -96,6 +96,7 @@ def ManyToMany(
default=None,
server_default=None,
owner=owner,
self_reference=self_reference,
)
return type("ManyToMany", (ManyToManyField, BaseField), namespace)

View File

@ -109,6 +109,7 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None:
virtual=True,
related_name=model_field.name,
owner=model_field.to,
self_reference=model_field.self_reference,
)
# register foreign keys on through model
adjust_through_many_to_many_model(model_field=model_field)
@ -119,12 +120,11 @@ def register_reverse_model_fields(model_field: Type["ForeignKeyField"]) -> None:
virtual=True,
related_name=model_field.name,
owner=model_field.to,
self_reference=model_field.self_reference,
)
def register_relation_in_alias_manager(
field: Type[ForeignKeyField], field_name: str
) -> None:
def register_relation_in_alias_manager(field: Type[ForeignKeyField]) -> None:
"""
Registers the relation (and reverse relation) in alias manager.
The m2m relations require registration of through model between
@ -136,8 +136,6 @@ def register_relation_in_alias_manager(
:param field: relation field
:type field: ForeignKey or ManyToManyField class
:param field_name: name of the relation key
:type field_name: str
"""
if issubclass(field, ManyToManyField):
if field.has_unresolved_forward_refs():

View File

@ -587,9 +587,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
populate_meta_sqlalchemy_table_if_required(new_model.Meta)
expand_reverse_relationships(new_model)
for field_name, field in new_model.Meta.model_fields.items():
register_relation_in_alias_manager(
field=field, field_name=field_name
)
register_relation_in_alias_manager(field=field)
if new_model.Meta.pkname not in attrs["__annotations__"]:
field_name = new_model.Meta.pkname

View File

@ -1,7 +1,7 @@
from typing import Callable, Dict, List, TYPE_CHECKING, Tuple, Type
import ormar
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.mixins.relation_mixin import RelationMixin
@ -37,10 +37,7 @@ class PrefetchQueryMixin(RelationMixin):
:rtype: Tuple[Type[Model], str]
"""
if reverse:
field_name = (
parent_model.Meta.model_fields[related].related_name
or parent_model.get_name() + "s"
)
field_name = parent_model.Meta.model_fields[related].get_related_name()
field = target_model.Meta.model_fields[field_name]
if issubclass(field, ormar.fields.ManyToManyField):
field_name = field.default_target_field_name()
@ -79,7 +76,7 @@ class PrefetchQueryMixin(RelationMixin):
return column.get_alias() if use_raw else column.name
@classmethod
def get_related_field_name(cls, target_field: Type["BaseField"]) -> str:
def get_related_field_name(cls, target_field: Type["ForeignKeyField"]) -> str:
"""
Returns name of the relation field that should be used in prefetch query.
This field is later used to register relation in prefetch query,
@ -93,7 +90,7 @@ class PrefetchQueryMixin(RelationMixin):
if issubclass(target_field, ormar.fields.ManyToManyField):
return cls.get_name()
if target_field.virtual:
return target_field.related_name or cls.get_name() + "s"
return target_field.get_related_name()
return target_field.to.Meta.pkname
@classmethod

View File

@ -30,7 +30,7 @@ import sqlalchemy
from pydantic import BaseModel
import ormar # noqa I100
from ormar.exceptions import ModelError
from ormar.exceptions import ModelError, ModelPersistenceError
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.helpers import register_relation_in_alias_manager
@ -452,9 +452,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
field.evaluate_forward_ref(globalns=globalns, localns=localns)
field.set_self_reference_flag()
expand_reverse_relationship(model_field=field)
register_relation_in_alias_manager(
field=field, field_name=field_name,
)
register_relation_in_alias_manager(field=field)
update_column_definition(model=cls, field=field)
populate_meta_sqlalchemy_table_if_required(meta=cls.Meta)
super().update_forward_refs(**localns)
@ -731,9 +729,15 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if self.get_column_alias(k) in self.Meta.table.columns
}
for field in self._extract_db_related_names():
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
relation_field = self.Meta.model_fields[field]
target_pk_name = relation_field.to.Meta.pkname
target_field = getattr(self, field)
self_fields[field] = getattr(target_field, target_pk_name, None)
if not relation_field.nullable and not self_fields[field]:
raise ModelPersistenceError(
f"You cannot save {relation_field.to.get_name()} "
f"model without pk set!"
)
return self_fields
def get_relation_model_id(self, target_field: Type["BaseField"]) -> Optional[int]:

View File

@ -135,7 +135,7 @@ class SqlJoin:
join_parameters.model_cls.Meta.model_fields[part], ManyToManyField
):
_fields = join_parameters.model_cls.Meta.model_fields
new_part = _fields[part].to.get_name()
new_part = _fields[part].default_target_field_name()
self._switch_many_to_many_order_columns(part, new_part)
if index > 0: # nested joins
fields, exclude_fields = SqlJoin.update_inclusions(
@ -435,11 +435,9 @@ class SqlJoin:
from_key = join_params.prev_model.get_column_alias(
join_params.prev_model.Meta.pkname
)
breakpoint()
elif join_params.prev_model.Meta.model_fields[part].virtual:
to_field = (
join_params.prev_model.Meta.model_fields[part].related_name
or join_params.prev_model.get_name() + "s"
)
to_field = join_params.prev_model.Meta.model_fields[part].get_related_name()
to_key = model_cls.get_column_alias(to_field)
from_key = join_params.prev_model.get_column_alias(
join_params.prev_model.Meta.pkname

View File

@ -9,10 +9,12 @@ from typing import (
Tuple,
Type,
Union,
cast,
)
import ormar
from ormar.fields import BaseField, ManyToManyField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.queryset.clause import QueryClause
from ormar.queryset.query import Query
from ormar.queryset.utils import extract_models_to_dict_of_lists, translate_list_to_dict
@ -314,6 +316,7 @@ class PrefetchQuery:
for related in related_to_extract:
target_field = model.Meta.model_fields[related]
target_field = cast(Type[ForeignKeyField], target_field)
target_model = target_field.to.get_name()
model_id = model.get_relation_model_id(target_field=target_field)
@ -421,6 +424,7 @@ class PrefetchQuery:
fields = target_model.get_included(fields, related)
exclude_fields = target_model.get_excluded(exclude_fields, related)
target_field = target_model.Meta.model_fields[related]
target_field = cast(Type[ForeignKeyField], target_field)
reverse = False
if target_field.virtual or issubclass(target_field, ManyToManyField):
reverse = True
@ -585,7 +589,7 @@ class PrefetchQuery:
def _populate_rows( # noqa: CFQ002
self,
rows: List,
target_field: Type["BaseField"],
target_field: Type["ForeignKeyField"],
parent_model: Type["Model"],
table_prefix: str,
fields: Union[Set[Any], Dict[Any, Any], None],

View File

@ -197,7 +197,7 @@ class QuerySet:
limit_raw_sql=self.limit_sql_raw,
)
exp = qry.build_select_expression()
# print("\n", exp.compile(compile_kwargs={"literal_binds": True}))
print("\n", exp.compile(compile_kwargs={"literal_binds": True}))
return exp
def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003

View File

@ -39,10 +39,9 @@ class QuerysetProxy(ormar.QuerySetProtocol):
self._queryset: Optional["QuerySet"] = qryset
self.type_: "RelationType" = type_
self._owner: "Model" = self.relation.manager.owner
self.related_field_name = (
self._owner.Meta.model_fields[self.relation.field_name].related_name
or self._owner.get_name() + "s"
)
self.related_field_name = self._owner.Meta.model_fields[
self.relation.field_name
].get_related_name()
self.related_field = self.relation.to.Meta.model_fields[self.related_field_name]
self.owner_pk_value = self._owner.pk
@ -108,8 +107,8 @@ class QuerysetProxy(ormar.QuerySetProtocol):
:type child: Model
"""
model_cls = self.relation.through
owner_column = self._owner.get_name()
child_column = child.get_name()
owner_column = self.related_field.default_target_field_name()
child_column = self.related_field.default_source_field_name()
kwargs = {owner_column: self._owner.pk, child_column: child.pk}
if child.pk is None:
raise ModelPersistenceError(
@ -129,8 +128,8 @@ class QuerysetProxy(ormar.QuerySetProtocol):
:type child: Model
"""
queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self._owner.get_name()
child_column = child.get_name()
owner_column = self.related_field.default_target_field_name()
child_column = self.related_field.default_source_field_name()
kwargs = {owner_column: self._owner, child_column: child}
link_instance = await queryset.filter(**kwargs).get() # type: ignore
await link_instance.delete()

View File

@ -164,8 +164,6 @@ class RelationsManager:
:param name: name of the relation
:type name: str
"""
relation_name = (
item.Meta.model_fields[name].related_name or item.get_name() + "s"
)
relation_name = item.Meta.model_fields[name].get_related_name()
item._orm.remove(name, parent)
parent._orm.remove(relation_name, item)

View File

@ -42,9 +42,8 @@ class RelationProxy(list):
if self._related_field_name:
return self._related_field_name
owner_field = self._owner.Meta.model_fields[self.field_name]
self._related_field_name = (
owner_field.related_name or self._owner.get_name() + "s"
)
self._related_field_name = owner_field.get_related_name()
return self._related_field_name
def __getattribute__(self, item: str) -> Any:

View File

@ -35,34 +35,34 @@ Game = ForwardRef("Game")
Child = ForwardRef("Child")
class ChildFriends(ormar.Model):
class ChildFriend(ormar.Model):
class Meta(ModelMeta):
metadata = metadata
database = db
# class Child(ormar.Model):
# class Meta(ModelMeta):
# metadata = metadata
# database = db
#
# id: int = ormar.Integer(primary_key=True)
# name: str = ormar.String(max_length=100)
# favourite_game: Game = ormar.ForeignKey(Game, related_name="liked_by")
# least_favourite_game: Game = ormar.ForeignKey(Game, related_name="not_liked_by")
# friends: List[Child] = ormar.ManyToMany(Child, through=ChildFriends)
#
#
# class Game(ormar.Model):
# class Meta(ModelMeta):
# metadata = metadata
# database = db
#
# id: int = ormar.Integer(primary_key=True)
# name: str = ormar.String(max_length=100)
class Child(ormar.Model):
class Meta(ModelMeta):
metadata = metadata
database = db
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
favourite_game: Game = ormar.ForeignKey(Game, related_name="liked_by")
least_favourite_game: Game = ormar.ForeignKey(Game, related_name="not_liked_by")
friends = ormar.ManyToMany(Child, through=ChildFriend, related_name="also_friends")
# Child.update_forward_refs()
class Game(ormar.Model):
class Meta(ModelMeta):
metadata = metadata
database = db
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
Child.update_forward_refs()
@pytest.fixture(autouse=True, scope="module")
@ -125,22 +125,56 @@ async def test_self_relation():
assert sam_check.employees[0].name == "Joe"
# @pytest.mark.asyncio
# async def test_other_forwardref_relation():
# checkers = await Game.objects.create(name="checkers")
# uno = await Game(name="Uno").save()
#
# await Child(name="Billy", favourite_game=uno, least_favourite_game=checkers).save()
# await Child(name="Kate", favourite_game=checkers, least_favourite_game=uno).save()
#
# billy_check = await Child.objects.select_related(
# ["favourite_game", "least_favourite_game"]
# ).get(name="Billy")
# assert billy_check.favourite_game == uno
# assert billy_check.least_favourite_game == checkers
#
# uno_check = await Game.objects.select_related(["liked_by", "not_liked_by"]).get(
# name="Uno"
# )
# assert uno_check.liked_by[0].name == "Billy"
# assert uno_check.not_liked_by[0].name == "Kate"
@pytest.mark.asyncio
async def test_other_forwardref_relation():
checkers = await Game.objects.create(name="checkers")
uno = await Game(name="Uno").save()
await Child(name="Billy", favourite_game=uno, least_favourite_game=checkers).save()
await Child(name="Kate", favourite_game=checkers, least_favourite_game=uno).save()
billy_check = await Child.objects.select_related(
["favourite_game", "least_favourite_game"]
).get(name="Billy")
assert billy_check.favourite_game == uno
assert billy_check.least_favourite_game == checkers
uno_check = await Game.objects.select_related(["liked_by", "not_liked_by"]).get(
name="Uno"
)
assert uno_check.liked_by[0].name == "Billy"
assert uno_check.not_liked_by[0].name == "Kate"
@pytest.mark.asyncio
async def test_m2m_self_forwardref_relation():
checkers = await Game.objects.create(name="checkers")
uno = await Game(name="Uno").save()
jenga = await Game(name="Jenga").save()
billy = await Child(
name="Billy", favourite_game=uno, least_favourite_game=checkers
).save()
kate = await Child(
name="Kate", favourite_game=checkers, least_favourite_game=uno
).save()
steve = await Child(
name="Steve", favourite_game=jenga, least_favourite_game=uno
).save()
await billy.friends.add(kate)
await billy.friends.add(steve)
await steve.friends.add(kate)
await steve.friends.add(billy)
billy_check = await Child.objects.select_related(
[
"friends",
"favourite_game",
"least_favourite_game",
"friends__favourite_game",
"friends__least_favourite_game",
]
).get(name="Billy")
assert len(billy_check.friends) == 2

View File

@ -80,6 +80,17 @@ async def cleanup():
await Author.objects.delete(each=True)
@pytest.mark.asyncio
async def test_not_saved_raises_error(cleanup):
async with database:
guido = await Author(first_name="Guido", last_name="Van Rossum").save()
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = Category(name="News")
with pytest.raises(ModelPersistenceError):
await post.categories.add(news)
@pytest.mark.asyncio
async def test_assigning_related_objects(cleanup):
async with database:

View File

@ -6,6 +6,7 @@ import sqlalchemy as sa
from sqlalchemy import create_engine
import ormar
from ormar.exceptions import ModelPersistenceError
from tests.settings import DATABASE_URL
metadata = sa.MetaData()
@ -61,3 +62,15 @@ async def test_model_relationship():
assert ws.id == 1
assert ws.topic == "Topic 2"
assert ws.category.name == "Foo"
@pytest.mark.asyncio
async def test_model_relationship_with_not_saved():
async with db:
async with db.transaction(force_rollback=True):
cat = Category(name="Foo", code=123)
with pytest.raises(ModelPersistenceError):
await Workshop(topic="Topic 1", category=cat).save()
with pytest.raises(ModelPersistenceError):
await Workshop.objects.create(topic="Topic 1", category=cat)