fix confligs after merge from master

This commit is contained in:
collerek
2020-12-15 14:43:15 +01:00
11 changed files with 191 additions and 34 deletions

View File

@ -1,3 +1,7 @@
# 0.7.5
* Fix for wrong relation column name in many_to_many relation joins (fix [#73][#73])
# 0.7.4 # 0.7.4
* Allow multiple relations to the same related model/table. * Allow multiple relations to the same related model/table.

View File

@ -44,7 +44,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType() Undefined = UndefinedType()
__version__ = "0.7.4" __version__ = "0.7.5"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",

View File

@ -322,7 +322,6 @@ class ForeignKeyField(BaseField):
""" """
if value is None: if value is None:
return None if not cls.virtual else [] return None if not cls.virtual else []
constructors = { constructors = {
f"{cls.to.__name__}": cls._register_existing_model, f"{cls.to.__name__}": cls._register_existing_model,
"dict": cls._construct_model_from_dict, "dict": cls._construct_model_from_dict,

View File

@ -101,18 +101,28 @@ def register_reverse_model_fields(
) -> None: ) -> None:
if issubclass(model_field, ManyToManyField): if issubclass(model_field, ManyToManyField):
model.Meta.model_fields[child_model_name] = ManyToMany( model.Meta.model_fields[child_model_name] = ManyToMany(
child, through=model_field.through, name=child_model_name, virtual=True child,
through=model_field.through,
name=child_model_name,
virtual=True,
related_name=model_field.name,
) )
# register foreign keys on through model # register foreign keys on through model
adjust_through_many_to_many_model(model, child, model_field) adjust_through_many_to_many_model(model, child, model_field, child_model_name)
else: else:
model.Meta.model_fields[child_model_name] = ForeignKey( model.Meta.model_fields[child_model_name] = ForeignKey(
child, real_name=child_model_name, virtual=True child,
real_name=child_model_name,
virtual=True,
related_name=model_field.name,
) )
def adjust_through_many_to_many_model( def adjust_through_many_to_many_model(
model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] model: Type["Model"],
child: Type["Model"],
model_field: Type[ManyToManyField],
child_model_name: str,
) -> None: ) -> None:
model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( model_field.through.Meta.model_fields[model.get_name()] = ForeignKey(
model, real_name=model.get_name(), ondelete="CASCADE" model, real_name=model.get_name(), ondelete="CASCADE"

View File

@ -62,10 +62,17 @@ class ModelTableProxy:
@staticmethod @staticmethod
def get_clause_target_and_filter_column_name( def get_clause_target_and_filter_column_name(
parent_model: Type["Model"], target_model: Type["Model"], reverse: bool parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool,
related: str,
) -> Tuple[Type["Model"], str]: ) -> Tuple[Type["Model"], str]:
if reverse: if reverse:
field = target_model.resolve_relation_field(target_model, parent_model) field_name = (
parent_model.Meta.model_fields[related].related_name
or parent_model.get_name() + "s"
)
field = target_model.Meta.model_fields[field_name]
if issubclass(field, ormar.fields.ManyToManyField): if issubclass(field, ormar.fields.ManyToManyField):
sub_field = target_model.resolve_relation_field( sub_field = target_model.resolve_relation_field(
field.through, parent_model field.through, parent_model

View File

@ -262,9 +262,16 @@ class SqlJoin:
model_cls: Type["Model"], model_cls: Type["Model"],
part: str, part: str,
) -> Tuple[str, str]: ) -> Tuple[str, str]:
if join_params.prev_model.Meta.model_fields[part].virtual or is_multi: if is_multi:
to_field = model_cls.resolve_relation_name( to_field = join_params.prev_model.get_name()
model_cls, join_params.prev_model to_key = model_cls.get_column_alias(to_field)
from_key = join_params.prev_model.get_column_alias(
join_params.prev_model.Meta.pkname
)
elif join_params.prev_model.Meta.model_fields[part].virtual:
to_field = (
join_params.prev_model.Meta.model_fields[part].related_name
or join_params.prev_model.get_name() + "s"
) )
to_key = model_cls.get_column_alias(to_field) to_key = model_cls.get_column_alias(to_field)
from_key = join_params.prev_model.get_column_alias( from_key = join_params.prev_model.get_column_alias(

View File

@ -145,7 +145,11 @@ class PrefetchQuery:
) )
def _get_filter_for_prefetch( def _get_filter_for_prefetch(
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, self,
parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool,
related: str,
) -> List: ) -> List:
ids = self._extract_required_ids( ids = self._extract_required_ids(
parent_model=parent_model, target_model=target_model, reverse=reverse, parent_model=parent_model, target_model=target_model, reverse=reverse,
@ -155,7 +159,10 @@ class PrefetchQuery:
clause_target, clause_target,
filter_column, filter_column,
) = parent_model.get_clause_target_and_filter_column_name( ) = parent_model.get_clause_target_and_filter_column_name(
parent_model=parent_model, target_model=target_model, reverse=reverse parent_model=parent_model,
target_model=target_model,
reverse=reverse,
related=related,
) )
qryclause = QueryClause( qryclause = QueryClause(
model_cls=clause_target, select_related=[], filter_clauses=[], model_cls=clause_target, select_related=[], filter_clauses=[],
@ -246,7 +253,10 @@ class PrefetchQuery:
parent_model = target_model parent_model = target_model
filter_clauses = self._get_filter_for_prefetch( filter_clauses = self._get_filter_for_prefetch(
parent_model=parent_model, target_model=target_field.to, reverse=reverse, parent_model=parent_model,
target_model=target_field.to,
reverse=reverse,
related=related,
) )
if not filter_clauses: # related field is empty if not filter_clauses: # related field is empty
return return

View File

@ -25,6 +25,7 @@ class Relation:
self, self,
manager: "RelationsManager", manager: "RelationsManager",
type_: RelationType, type_: RelationType,
field_name: str,
to: Type["T"], to: Type["T"],
through: Type["T"] = None, through: Type["T"] = None,
) -> None: ) -> None:
@ -34,8 +35,9 @@ class Relation:
self._to_remove: Set = set() self._to_remove: Set = set()
self.to: Type["T"] = to self.to: Type["T"] = to
self.through: Optional[Type["T"]] = through self.through: Optional[Type["T"]] = through
self.field_name = field_name
self.related_models: Optional[Union[RelationProxy, "T"]] = ( self.related_models: Optional[Union[RelationProxy, "T"]] = (
RelationProxy(relation=self, type_=type_) RelationProxy(relation=self, type_=type_, field_name=field_name)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None else None
) )
@ -47,7 +49,10 @@ class Relation:
if i not in self._to_remove if i not in self._to_remove
] ]
self.related_models = RelationProxy( self.related_models = RelationProxy(
relation=self, type_=self._type, data_=cleaned_data relation=self,
type_=self._type,
field_name=self.field_name,
data_=cleaned_data,
) )
relation_name = self._owner.resolve_relation_name(self._owner, self.to) relation_name = self._owner.resolve_relation_name(self._owner, self.to)
self._owner.__dict__[relation_name] = cleaned_data self._owner.__dict__[relation_name] = cleaned_data
@ -69,7 +74,7 @@ class Relation:
return None return None
def add(self, child: "T") -> None: def add(self, child: "T") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child) relation_name = self.field_name
if self._type == RelationType.PRIMARY: if self._type == RelationType.PRIMARY:
self.related_models = child self.related_models = child
self._owner.__dict__[relation_name] = child self._owner.__dict__[relation_name] = child
@ -84,7 +89,7 @@ class Relation:
self._owner.__dict__[relation_name] = rel self._owner.__dict__[relation_name] = rel
def remove(self, child: Union["NewBaseModel", Type["NewBaseModel"]]) -> None: def remove(self, child: Union["NewBaseModel", Type["NewBaseModel"]]) -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child) relation_name = self.field_name
if self._type == RelationType.PRIMARY: if self._type == RelationType.PRIMARY:
if self.related_models == child: if self.related_models == child:
self.related_models = None self.related_models = None

View File

@ -36,6 +36,7 @@ class RelationsManager:
self._relations[field.name] = Relation( self._relations[field.name] = Relation(
manager=self, manager=self,
type_=self._get_relation_type(field), type_=self._get_relation_type(field),
field_name=field.name,
to=field.to, to=field.to,
through=getattr(field, "through", None), through=getattr(field, "through", None),
) )
@ -64,15 +65,17 @@ class RelationsManager:
relation_name: str, relation_name: str,
) -> None: ) -> None:
to_field: Type[BaseField] = child.Meta.model_fields[relation_name] to_field: Type[BaseField] = child.Meta.model_fields[relation_name]
# print('comming', child_name, relation_name)
(parent, child, child_name, to_name,) = get_relations_sides_and_names( (parent, child, child_name, to_name,) = get_relations_sides_and_names(
to_field, parent, child, child_name, virtual to_field, parent, child, child_name, virtual
) )
# print('adding', parent.get_name(), child.get_name(), child_name)
parent_relation = parent._orm._get(child_name) parent_relation = parent._orm._get(child_name)
if parent_relation: if parent_relation:
parent_relation.add(child) # type: ignore parent_relation.add(child) # type: ignore
# print('adding', child.get_name(), parent.get_name(), child_name)
child_relation = child._orm._get(to_name) child_relation = child._orm._get(to_name)
if child_relation: if child_relation:
child_relation.add(parent) child_relation.add(parent)

View File

@ -1,4 +1,4 @@
from typing import Any, TYPE_CHECKING from typing import Any, Optional, TYPE_CHECKING
import ormar import ormar
from ormar.exceptions import NoMatch, RelationshipInstanceError from ormar.exceptions import NoMatch, RelationshipInstanceError
@ -12,13 +12,29 @@ if TYPE_CHECKING: # pragma no cover
class RelationProxy(list): class RelationProxy(list):
def __init__( def __init__(
self, relation: "Relation", type_: "RelationType", data_: Any = None self,
relation: "Relation",
type_: "RelationType",
field_name: str,
data_: Any = None,
) -> None: ) -> None:
super().__init__(data_ or ()) super().__init__(data_ or ())
self.relation: "Relation" = relation self.relation: "Relation" = relation
self.type_: "RelationType" = type_ self.type_: "RelationType" = type_
self.field_name = field_name
self._owner: "Model" = self.relation.manager.owner self._owner: "Model" = self.relation.manager.owner
self.queryset_proxy = QuerysetProxy(relation=self.relation, type_=type_) self.queryset_proxy = QuerysetProxy(relation=self.relation, type_=type_)
self._related_field_name: Optional[str] = None
@property
def related_field_name(self) -> str:
if self._related_field_name:
return self._related_field_name
owner_field = self._owner.Meta.model_fields[self.field_name]
self._related_field_name = (
owner_field.related_name or self._owner.get_name() + "s"
)
return self._related_field_name
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:
if item in ["count", "clear"]: if item in ["count", "clear"]:
@ -48,9 +64,8 @@ class RelationProxy(list):
) )
def _set_queryset(self) -> "QuerySet": def _set_queryset(self) -> "QuerySet":
related_field = self._owner.resolve_relation_field( related_field_name = self.related_field_name
self.relation.to, self._owner related_field = self.relation.to.Meta.model_fields[related_field_name]
)
pkname = self._owner.get_column_alias(self._owner.Meta.pkname) pkname = self._owner.get_column_alias(self._owner.Meta.pkname)
self._check_if_model_saved() self._check_if_model_saved()
kwargs = {f"{related_field.get_alias()}__{pkname}": self._owner.pk} kwargs = {f"{related_field.get_alias()}__{pkname}": self._owner.pk}
@ -70,11 +85,11 @@ class RelationProxy(list):
f"{item.get_name()} with given primary key!" f"{item.get_name()} with given primary key!"
) )
super().remove(item) super().remove(item)
rel_name = item.resolve_relation_name(item, self._owner) relation_name = self.related_field_name
relation = item._orm._get(rel_name) relation = item._orm._get(relation_name)
if relation is None: # pragma nocover if relation is None: # pragma nocover
raise ValueError( raise ValueError(
f"{self._owner.get_name()} does not have relation {rel_name}" f"{self._owner.get_name()} does not have relation {relation_name}"
) )
relation.remove(self._owner) relation.remove(self._owner)
self.relation.remove(item) self.relation.remove(item)
@ -82,20 +97,17 @@ class RelationProxy(list):
await self.queryset_proxy.delete_through_instance(item) await self.queryset_proxy.delete_through_instance(item)
else: else:
if keep_reversed: if keep_reversed:
setattr(item, rel_name, None) setattr(item, relation_name, None)
await item.update() await item.update()
else: else:
await item.delete() await item.delete()
async def add(self, item: "Model") -> None: async def add(self, item: "Model") -> None:
relation_name = self.related_field_name
if self.type_ == ormar.RelationType.MULTIPLE: if self.type_ == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item) await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner) setattr(item, relation_name, self._owner)
setattr(item, rel_name, self._owner)
else: else:
self._check_if_model_saved() self._check_if_model_saved()
related_field = self._owner.resolve_relation_field( setattr(item, relation_name, self._owner)
self.relation.to, self._owner
)
setattr(item, related_field.name, self._owner)
await item.update() await item.update()

View File

@ -0,0 +1,100 @@
from typing import List, Optional
import databases
import sqlalchemy
from sqlalchemy import create_engine
import ormar
import pytest
from tests.settings import DATABASE_URL
db = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
class User(ormar.Model):
class Meta:
metadata = metadata
database = db
tablename = "test_users"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=50)
class Signup(ormar.Model):
class Meta:
metadata = metadata
database = db
tablename = "test_signup"
id: int = ormar.Integer(primary_key=True)
class Session(ormar.Model):
class Meta:
metadata = metadata
database = db
tablename = "test_sessions"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=255, index=True)
some_text: str = ormar.Text()
some_other_text: Optional[str] = ormar.Text(nullable=True)
teacher: Optional[User] = ormar.ForeignKey(
User, nullable=True, related_name="teaching"
)
students: Optional[List[User]] = ormar.ManyToMany(
User, through=Signup, related_name="attending"
)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_add_students():
async with db:
for user_id in [1, 2, 3, 4, 5]:
await User.objects.create(name=f"User {user_id}")
for name, some_text, some_other_text in [
("Session 1", "Some text 1", "Some other text 1"),
("Session 2", "Some text 2", "Some other text 2"),
("Session 3", "Some text 3", "Some other text 3"),
("Session 4", "Some text 4", "Some other text 4"),
("Session 5", "Some text 5", "Some other text 5"),
]:
await Session(
name=name, some_text=some_text, some_other_text=some_other_text
).save()
s1 = await Session.objects.get(pk=1)
s2 = await Session.objects.get(pk=2)
users = {}
for i in range(1, 6):
user = await User.objects.get(pk=i)
users[f"user_{i}"] = user
if i % 2 == 0:
await s1.students.add(user)
else:
await s2.students.add(user)
assert len(s1.students) > 0
assert len(s2.students) > 0
user = await User.objects.select_related("attending").get(pk=1)
assert user.attending is not None
assert len(user.attending) > 0
query = Session.objects.prefetch_related(["students", "teacher",])
sessions = await query.all()
assert len(sessions) == 5