fix confligs after merge from master

This commit is contained in:
collerek
2020-12-15 14:43:15 +01:00
11 changed files with 191 additions and 34 deletions

View File

@ -1,3 +1,7 @@
# 0.7.5
* Fix for wrong relation column name in many_to_many relation joins (fix [#73][#73])
# 0.7.4
* Allow multiple relations to the same related model/table.

View File

@ -44,7 +44,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType()
__version__ = "0.7.4"
__version__ = "0.7.5"
__all__ = [
"Integer",
"BigInteger",

View File

@ -322,7 +322,6 @@ class ForeignKeyField(BaseField):
"""
if value is None:
return None if not cls.virtual else []
constructors = {
f"{cls.to.__name__}": cls._register_existing_model,
"dict": cls._construct_model_from_dict,

View File

@ -101,18 +101,28 @@ def register_reverse_model_fields(
) -> None:
if issubclass(model_field, ManyToManyField):
model.Meta.model_fields[child_model_name] = ManyToMany(
child, through=model_field.through, name=child_model_name, virtual=True
child,
through=model_field.through,
name=child_model_name,
virtual=True,
related_name=model_field.name,
)
# register foreign keys on through model
adjust_through_many_to_many_model(model, child, model_field)
adjust_through_many_to_many_model(model, child, model_field, child_model_name)
else:
model.Meta.model_fields[child_model_name] = ForeignKey(
child, real_name=child_model_name, virtual=True
child,
real_name=child_model_name,
virtual=True,
related_name=model_field.name,
)
def adjust_through_many_to_many_model(
model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField]
model: Type["Model"],
child: Type["Model"],
model_field: Type[ManyToManyField],
child_model_name: str,
) -> None:
model_field.through.Meta.model_fields[model.get_name()] = ForeignKey(
model, real_name=model.get_name(), ondelete="CASCADE"

View File

@ -62,10 +62,17 @@ class ModelTableProxy:
@staticmethod
def get_clause_target_and_filter_column_name(
parent_model: Type["Model"], target_model: Type["Model"], reverse: bool
parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool,
related: str,
) -> Tuple[Type["Model"], str]:
if reverse:
field = target_model.resolve_relation_field(target_model, parent_model)
field_name = (
parent_model.Meta.model_fields[related].related_name
or parent_model.get_name() + "s"
)
field = target_model.Meta.model_fields[field_name]
if issubclass(field, ormar.fields.ManyToManyField):
sub_field = target_model.resolve_relation_field(
field.through, parent_model

View File

@ -262,9 +262,16 @@ class SqlJoin:
model_cls: Type["Model"],
part: str,
) -> Tuple[str, str]:
if join_params.prev_model.Meta.model_fields[part].virtual or is_multi:
to_field = model_cls.resolve_relation_name(
model_cls, join_params.prev_model
if is_multi:
to_field = join_params.prev_model.get_name()
to_key = model_cls.get_column_alias(to_field)
from_key = join_params.prev_model.get_column_alias(
join_params.prev_model.Meta.pkname
)
elif join_params.prev_model.Meta.model_fields[part].virtual:
to_field = (
join_params.prev_model.Meta.model_fields[part].related_name
or join_params.prev_model.get_name() + "s"
)
to_key = model_cls.get_column_alias(to_field)
from_key = join_params.prev_model.get_column_alias(

View File

@ -145,7 +145,11 @@ class PrefetchQuery:
)
def _get_filter_for_prefetch(
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
self,
parent_model: Type["Model"],
target_model: Type["Model"],
reverse: bool,
related: str,
) -> List:
ids = self._extract_required_ids(
parent_model=parent_model, target_model=target_model, reverse=reverse,
@ -155,7 +159,10 @@ class PrefetchQuery:
clause_target,
filter_column,
) = parent_model.get_clause_target_and_filter_column_name(
parent_model=parent_model, target_model=target_model, reverse=reverse
parent_model=parent_model,
target_model=target_model,
reverse=reverse,
related=related,
)
qryclause = QueryClause(
model_cls=clause_target, select_related=[], filter_clauses=[],
@ -246,7 +253,10 @@ class PrefetchQuery:
parent_model = target_model
filter_clauses = self._get_filter_for_prefetch(
parent_model=parent_model, target_model=target_field.to, reverse=reverse,
parent_model=parent_model,
target_model=target_field.to,
reverse=reverse,
related=related,
)
if not filter_clauses: # related field is empty
return

View File

@ -25,6 +25,7 @@ class Relation:
self,
manager: "RelationsManager",
type_: RelationType,
field_name: str,
to: Type["T"],
through: Type["T"] = None,
) -> None:
@ -34,8 +35,9 @@ class Relation:
self._to_remove: Set = set()
self.to: Type["T"] = to
self.through: Optional[Type["T"]] = through
self.field_name = field_name
self.related_models: Optional[Union[RelationProxy, "T"]] = (
RelationProxy(relation=self, type_=type_)
RelationProxy(relation=self, type_=type_, field_name=field_name)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None
)
@ -47,7 +49,10 @@ class Relation:
if i not in self._to_remove
]
self.related_models = RelationProxy(
relation=self, type_=self._type, data_=cleaned_data
relation=self,
type_=self._type,
field_name=self.field_name,
data_=cleaned_data,
)
relation_name = self._owner.resolve_relation_name(self._owner, self.to)
self._owner.__dict__[relation_name] = cleaned_data
@ -69,7 +74,7 @@ class Relation:
return None
def add(self, child: "T") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
relation_name = self.field_name
if self._type == RelationType.PRIMARY:
self.related_models = child
self._owner.__dict__[relation_name] = child
@ -84,7 +89,7 @@ class Relation:
self._owner.__dict__[relation_name] = rel
def remove(self, child: Union["NewBaseModel", Type["NewBaseModel"]]) -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
relation_name = self.field_name
if self._type == RelationType.PRIMARY:
if self.related_models == child:
self.related_models = None

View File

@ -36,6 +36,7 @@ class RelationsManager:
self._relations[field.name] = Relation(
manager=self,
type_=self._get_relation_type(field),
field_name=field.name,
to=field.to,
through=getattr(field, "through", None),
)
@ -64,15 +65,17 @@ class RelationsManager:
relation_name: str,
) -> None:
to_field: Type[BaseField] = child.Meta.model_fields[relation_name]
# print('comming', child_name, relation_name)
(parent, child, child_name, to_name,) = get_relations_sides_and_names(
to_field, parent, child, child_name, virtual
)
# print('adding', parent.get_name(), child.get_name(), child_name)
parent_relation = parent._orm._get(child_name)
if parent_relation:
parent_relation.add(child) # type: ignore
# print('adding', child.get_name(), parent.get_name(), child_name)
child_relation = child._orm._get(to_name)
if child_relation:
child_relation.add(parent)

View File

@ -1,4 +1,4 @@
from typing import Any, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
import ormar
from ormar.exceptions import NoMatch, RelationshipInstanceError
@ -12,13 +12,29 @@ if TYPE_CHECKING: # pragma no cover
class RelationProxy(list):
def __init__(
self, relation: "Relation", type_: "RelationType", data_: Any = None
self,
relation: "Relation",
type_: "RelationType",
field_name: str,
data_: Any = None,
) -> None:
super().__init__(data_ or ())
self.relation: "Relation" = relation
self.type_: "RelationType" = type_
self.field_name = field_name
self._owner: "Model" = self.relation.manager.owner
self.queryset_proxy = QuerysetProxy(relation=self.relation, type_=type_)
self._related_field_name: Optional[str] = None
@property
def related_field_name(self) -> str:
if self._related_field_name:
return self._related_field_name
owner_field = self._owner.Meta.model_fields[self.field_name]
self._related_field_name = (
owner_field.related_name or self._owner.get_name() + "s"
)
return self._related_field_name
def __getattribute__(self, item: str) -> Any:
if item in ["count", "clear"]:
@ -48,9 +64,8 @@ class RelationProxy(list):
)
def _set_queryset(self) -> "QuerySet":
related_field = self._owner.resolve_relation_field(
self.relation.to, self._owner
)
related_field_name = self.related_field_name
related_field = self.relation.to.Meta.model_fields[related_field_name]
pkname = self._owner.get_column_alias(self._owner.Meta.pkname)
self._check_if_model_saved()
kwargs = {f"{related_field.get_alias()}__{pkname}": self._owner.pk}
@ -70,11 +85,11 @@ class RelationProxy(list):
f"{item.get_name()} with given primary key!"
)
super().remove(item)
rel_name = item.resolve_relation_name(item, self._owner)
relation = item._orm._get(rel_name)
relation_name = self.related_field_name
relation = item._orm._get(relation_name)
if relation is None: # pragma nocover
raise ValueError(
f"{self._owner.get_name()} does not have relation {rel_name}"
f"{self._owner.get_name()} does not have relation {relation_name}"
)
relation.remove(self._owner)
self.relation.remove(item)
@ -82,20 +97,17 @@ class RelationProxy(list):
await self.queryset_proxy.delete_through_instance(item)
else:
if keep_reversed:
setattr(item, rel_name, None)
setattr(item, relation_name, None)
await item.update()
else:
await item.delete()
async def add(self, item: "Model") -> None:
relation_name = self.related_field_name
if self.type_ == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner)
setattr(item, rel_name, self._owner)
setattr(item, relation_name, self._owner)
else:
self._check_if_model_saved()
related_field = self._owner.resolve_relation_field(
self.relation.to, self._owner
)
setattr(item, related_field.name, self._owner)
setattr(item, relation_name, self._owner)
await item.update()

View File

@ -0,0 +1,100 @@
from typing import List, Optional
import databases
import sqlalchemy
from sqlalchemy import create_engine
import ormar
import pytest
from tests.settings import DATABASE_URL
db = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
class User(ormar.Model):
class Meta:
metadata = metadata
database = db
tablename = "test_users"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=50)
class Signup(ormar.Model):
class Meta:
metadata = metadata
database = db
tablename = "test_signup"
id: int = ormar.Integer(primary_key=True)
class Session(ormar.Model):
class Meta:
metadata = metadata
database = db
tablename = "test_sessions"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=255, index=True)
some_text: str = ormar.Text()
some_other_text: Optional[str] = ormar.Text(nullable=True)
teacher: Optional[User] = ormar.ForeignKey(
User, nullable=True, related_name="teaching"
)
students: Optional[List[User]] = ormar.ManyToMany(
User, through=Signup, related_name="attending"
)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_add_students():
async with db:
for user_id in [1, 2, 3, 4, 5]:
await User.objects.create(name=f"User {user_id}")
for name, some_text, some_other_text in [
("Session 1", "Some text 1", "Some other text 1"),
("Session 2", "Some text 2", "Some other text 2"),
("Session 3", "Some text 3", "Some other text 3"),
("Session 4", "Some text 4", "Some other text 4"),
("Session 5", "Some text 5", "Some other text 5"),
]:
await Session(
name=name, some_text=some_text, some_other_text=some_other_text
).save()
s1 = await Session.objects.get(pk=1)
s2 = await Session.objects.get(pk=2)
users = {}
for i in range(1, 6):
user = await User.objects.get(pk=i)
users[f"user_{i}"] = user
if i % 2 == 0:
await s1.students.add(user)
else:
await s2.students.add(user)
assert len(s1.students) > 0
assert len(s2.students) > 0
user = await User.objects.select_related("attending").get(pk=1)
assert user.attending is not None
assert len(user.attending) > 0
query = Session.objects.prefetch_related(["students", "teacher",])
sessions = await query.all()
assert len(sessions) == 5