first passing to clean and check

This commit is contained in:
collerek
2020-12-15 11:55:07 +01:00
parent eeee0409ac
commit 1b42d321b9
8 changed files with 136 additions and 27 deletions

View File

@ -164,7 +164,7 @@ class ForeignKeyField(BaseField):
) -> Optional[Union["Model", List["Model"]]]: ) -> Optional[Union["Model", List["Model"]]]:
if value is None: if value is None:
return None if not cls.virtual else [] return None if not cls.virtual else []
print('expanding', relation_name)
constructors = { constructors = {
f"{cls.to.__name__}": cls._register_existing_model, f"{cls.to.__name__}": cls._register_existing_model,
"dict": cls._construct_model_from_dict, "dict": cls._construct_model_from_dict,

View File

@ -89,13 +89,13 @@ def register_reverse_model_fields(
) -> None: ) -> None:
if issubclass(model_field, ManyToManyField): if issubclass(model_field, ManyToManyField):
model.Meta.model_fields[child_model_name] = ManyToMany( model.Meta.model_fields[child_model_name] = ManyToMany(
child, through=model_field.through, name=child_model_name, virtual=True child, through=model_field.through, name=child_model_name, virtual=True, related_name=model_field.name
) )
# register foreign keys on through model # register foreign keys on through model
adjust_through_many_to_many_model(model, child, model_field) adjust_through_many_to_many_model(model, child, model_field)
else: else:
model.Meta.model_fields[child_model_name] = ForeignKey( model.Meta.model_fields[child_model_name] = ForeignKey(
child, real_name=child_model_name, virtual=True child, real_name=child_model_name, virtual=True, related_name=model_field.name
) )

View File

@ -62,10 +62,11 @@ class ModelTableProxy:
@staticmethod @staticmethod
def get_clause_target_and_filter_column_name( def get_clause_target_and_filter_column_name(
parent_model: Type["Model"], target_model: Type["Model"], reverse: bool parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, related: str,
) -> Tuple[Type["Model"], str]: ) -> Tuple[Type["Model"], str]:
if reverse: if reverse:
field = target_model.resolve_relation_field(target_model, parent_model) field = parent_model.Meta.model_fields[related]
# field = target_model.resolve_relation_field(target_model, parent_model)
if issubclass(field, ormar.fields.ManyToManyField): if issubclass(field, ormar.fields.ManyToManyField):
sub_field = target_model.resolve_relation_field( sub_field = target_model.resolve_relation_field(
field.through, parent_model field.through, parent_model

View File

@ -145,7 +145,7 @@ class PrefetchQuery:
) )
def _get_filter_for_prefetch( def _get_filter_for_prefetch(
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, related: str,
) -> List: ) -> List:
ids = self._extract_required_ids( ids = self._extract_required_ids(
parent_model=parent_model, target_model=target_model, reverse=reverse, parent_model=parent_model, target_model=target_model, reverse=reverse,
@ -155,7 +155,7 @@ class PrefetchQuery:
clause_target, clause_target,
filter_column, filter_column,
) = parent_model.get_clause_target_and_filter_column_name( ) = parent_model.get_clause_target_and_filter_column_name(
parent_model=parent_model, target_model=target_model, reverse=reverse parent_model=parent_model, target_model=target_model, reverse=reverse, related=related
) )
qryclause = QueryClause( qryclause = QueryClause(
model_cls=clause_target, select_related=[], filter_clauses=[], model_cls=clause_target, select_related=[], filter_clauses=[],
@ -246,7 +246,7 @@ class PrefetchQuery:
parent_model = target_model parent_model = target_model
filter_clauses = self._get_filter_for_prefetch( filter_clauses = self._get_filter_for_prefetch(
parent_model=parent_model, target_model=target_field.to, reverse=reverse, parent_model=parent_model, target_model=target_field.to, reverse=reverse, related=related
) )
if not filter_clauses: # related field is empty if not filter_clauses: # related field is empty
return return

View File

@ -25,6 +25,7 @@ class Relation:
self, self,
manager: "RelationsManager", manager: "RelationsManager",
type_: RelationType, type_: RelationType,
field_name: str,
to: Type["T"], to: Type["T"],
through: Type["T"] = None, through: Type["T"] = None,
) -> None: ) -> None:
@ -34,8 +35,9 @@ class Relation:
self._to_remove: Set = set() self._to_remove: Set = set()
self.to: Type["T"] = to self.to: Type["T"] = to
self.through: Optional[Type["T"]] = through self.through: Optional[Type["T"]] = through
self.field_name = field_name
self.related_models: Optional[Union[RelationProxy, "T"]] = ( self.related_models: Optional[Union[RelationProxy, "T"]] = (
RelationProxy(relation=self, type_=type_) RelationProxy(relation=self, type_=type_, field_name=field_name)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None else None
) )
@ -47,7 +49,7 @@ class Relation:
if i not in self._to_remove if i not in self._to_remove
] ]
self.related_models = RelationProxy( self.related_models = RelationProxy(
relation=self, type_=self._type, data_=cleaned_data relation=self, type_=self._type, field_name=self.field_name, data_=cleaned_data
) )
relation_name = self._owner.resolve_relation_name(self._owner, self.to) relation_name = self._owner.resolve_relation_name(self._owner, self.to)
self._owner.__dict__[relation_name] = cleaned_data self._owner.__dict__[relation_name] = cleaned_data
@ -69,7 +71,7 @@ class Relation:
return None return None
def add(self, child: "T") -> None: def add(self, child: "T") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child) relation_name = self.field_name
if self._type == RelationType.PRIMARY: if self._type == RelationType.PRIMARY:
self.related_models = child self.related_models = child
self._owner.__dict__[relation_name] = child self._owner.__dict__[relation_name] = child
@ -84,7 +86,7 @@ class Relation:
self._owner.__dict__[relation_name] = rel self._owner.__dict__[relation_name] = rel
def remove(self, child: Union["NewBaseModel", Type["NewBaseModel"]]) -> None: def remove(self, child: Union["NewBaseModel", Type["NewBaseModel"]]) -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child) relation_name = self.field_name
if self._type == RelationType.PRIMARY: if self._type == RelationType.PRIMARY:
if self.related_models == child: if self.related_models == child:
self.related_models = None self.related_models = None

View File

@ -36,6 +36,7 @@ class RelationsManager:
self._relations[field.name] = Relation( self._relations[field.name] = Relation(
manager=self, manager=self,
type_=self._get_relation_type(field), type_=self._get_relation_type(field),
field_name=field.name,
to=field.to, to=field.to,
through=getattr(field, "through", None), through=getattr(field, "through", None),
) )
@ -64,15 +65,17 @@ class RelationsManager:
relation_name: str, relation_name: str,
) -> None: ) -> None:
to_field: Type[BaseField] = child.Meta.model_fields[relation_name] to_field: Type[BaseField] = child.Meta.model_fields[relation_name]
print('comming', child_name, relation_name)
(parent, child, child_name, to_name,) = get_relations_sides_and_names( (parent, child, child_name, to_name,) = get_relations_sides_and_names(
to_field, parent, child, child_name, virtual to_field, parent, child, child_name, virtual
) )
print('adding', parent.get_name(), child.get_name(), child_name)
parent_relation = parent._orm._get(child_name) parent_relation = parent._orm._get(child_name)
if parent_relation: if parent_relation:
parent_relation.add(child) # type: ignore parent_relation.add(child) # type: ignore
print('adding', child.get_name(), parent.get_name(), child_name)
child_relation = child._orm._get(to_name) child_relation = child._orm._get(to_name)
if child_relation: if child_relation:
child_relation.add(parent) child_relation.add(parent)

View File

@ -12,13 +12,16 @@ if TYPE_CHECKING: # pragma no cover
class RelationProxy(list): class RelationProxy(list):
def __init__( def __init__(
self, relation: "Relation", type_: "RelationType", data_: Any = None self, relation: "Relation", type_: "RelationType", field_name: str, data_: Any = None
) -> None: ) -> None:
super().__init__(data_ or ()) super().__init__(data_ or ())
self.relation: "Relation" = relation self.relation: "Relation" = relation
self.type_: "RelationType" = type_ self.type_: "RelationType" = type_
self.field_name = field_name
self._owner: "Model" = self.relation.manager.owner self._owner: "Model" = self.relation.manager.owner
self.queryset_proxy = QuerysetProxy(relation=self.relation, type_=type_) self.queryset_proxy = QuerysetProxy(relation=self.relation, type_=type_)
owner_field = self._owner.Meta.model_fields[self.field_name]
self.related_field_name = owner_field.related_name or self._owner.get_name() + 's'
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:
if item in ["count", "clear"]: if item in ["count", "clear"]:
@ -48,9 +51,8 @@ class RelationProxy(list):
) )
def _set_queryset(self) -> "QuerySet": def _set_queryset(self) -> "QuerySet":
related_field = self._owner.resolve_relation_field( related_field_name = self.related_field_name
self.relation.to, self._owner related_field = self.relation.to.Meta.model_fields[related_field_name]
)
pkname = self._owner.get_column_alias(self._owner.Meta.pkname) pkname = self._owner.get_column_alias(self._owner.Meta.pkname)
self._check_if_model_saved() self._check_if_model_saved()
kwargs = {f"{related_field.get_alias()}__{pkname}": self._owner.pk} kwargs = {f"{related_field.get_alias()}__{pkname}": self._owner.pk}
@ -70,11 +72,11 @@ class RelationProxy(list):
f"{item.get_name()} with given primary key!" f"{item.get_name()} with given primary key!"
) )
super().remove(item) super().remove(item)
rel_name = item.resolve_relation_name(item, self._owner) relation_name = self.related_field_name
relation = item._orm._get(rel_name) relation = item._orm._get(relation_name)
if relation is None: # pragma nocover if relation is None: # pragma nocover
raise ValueError( raise ValueError(
f"{self._owner.get_name()} does not have relation {rel_name}" f"{self._owner.get_name()} does not have relation {relation_name}"
) )
relation.remove(self._owner) relation.remove(self._owner)
self.relation.remove(item) self.relation.remove(item)
@ -82,20 +84,17 @@ class RelationProxy(list):
await self.queryset_proxy.delete_through_instance(item) await self.queryset_proxy.delete_through_instance(item)
else: else:
if keep_reversed: if keep_reversed:
setattr(item, rel_name, None) setattr(item, relation_name, None)
await item.update() await item.update()
else: else:
await item.delete() await item.delete()
async def add(self, item: "Model") -> None: async def add(self, item: "Model") -> None:
relation_name = self.related_field_name
if self.type_ == ormar.RelationType.MULTIPLE: if self.type_ == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item) await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner) setattr(item, relation_name, self._owner)
setattr(item, rel_name, self._owner)
else: else:
self._check_if_model_saved() self._check_if_model_saved()
related_field = self._owner.resolve_relation_field( setattr(item, relation_name, self._owner)
self.relation.to, self._owner
)
setattr(item, related_field.name, self._owner)
await item.update() await item.update()

View File

@ -0,0 +1,104 @@
from typing import List, Optional
import databases
import sqlalchemy
from sqlalchemy import create_engine
import ormar
import pytest
from tests.settings import DATABASE_URL
db = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
class User(ormar.Model):
class Meta:
metadata = metadata
database = db
tablename = "test_users"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=50)
class Signup(ormar.Model):
class Meta:
metadata = metadata
database = db
tablename = "test_signup"
id: int = ormar.Integer(primary_key=True)
class Session(ormar.Model):
class Meta:
metadata = metadata
database = db
tablename = "test_sessions"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=255, index=True)
some_text: str = ormar.Text()
some_other_text: Optional[str] = ormar.Text(nullable=True)
teacher: Optional[User] = ormar.ForeignKey(
User, nullable=True, related_name="teaching"
)
students: Optional[List[User]] = ormar.ManyToMany(
User, through=Signup, related_name="attending"
)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_add_students():
for user_id in [1, 2, 3, 4, 5]:
await User.objects.create(name=f"User {user_id}")
for name, some_text, some_other_text in [
("Session 1", "Some text 1", "Some other text 1"),
("Session 2", "Some text 2", "Some other text 2"),
("Session 3", "Some text 3", "Some other text 3"),
("Session 4", "Some text 4", "Some other text 4"),
("Session 5", "Some text 5", "Some other text 5"),
]:
await Session(
name=name, some_text=some_text, some_other_text=some_other_text
).save()
s1 = await Session.objects.get(pk=1)
s2 = await Session.objects.get(pk=2)
users = {}
for i in range(1, 6):
user = await User.objects.get(pk=i)
users[f"user_{i}"] = user
if i % 2 == 0:
await s1.students.add(user)
else:
await s2.students.add(user)
assert len(s1.students) > 0
assert len(s2.students) > 0
user = await User.objects.select_related("attending").get(pk=1)
assert user.attending is not None
assert len(user.attending) > 0
query = Session.objects.prefetch_related(
[
"students",
"teacher",
]
)
sessions = await query.all()
assert len(sessions) == 5