first passing to clean and check

This commit is contained in:
collerek
2020-12-15 11:55:07 +01:00
parent eeee0409ac
commit 1b42d321b9
8 changed files with 136 additions and 27 deletions

View File

@ -164,7 +164,7 @@ class ForeignKeyField(BaseField):
) -> Optional[Union["Model", List["Model"]]]:
if value is None:
return None if not cls.virtual else []
print('expanding', relation_name)
constructors = {
f"{cls.to.__name__}": cls._register_existing_model,
"dict": cls._construct_model_from_dict,

View File

@ -89,13 +89,13 @@ def register_reverse_model_fields(
) -> None:
if issubclass(model_field, ManyToManyField):
model.Meta.model_fields[child_model_name] = ManyToMany(
child, through=model_field.through, name=child_model_name, virtual=True
child, through=model_field.through, name=child_model_name, virtual=True, related_name=model_field.name
)
# register foreign keys on through model
adjust_through_many_to_many_model(model, child, model_field)
else:
model.Meta.model_fields[child_model_name] = ForeignKey(
child, real_name=child_model_name, virtual=True
child, real_name=child_model_name, virtual=True, related_name=model_field.name
)

View File

@ -62,10 +62,11 @@ class ModelTableProxy:
@staticmethod
def get_clause_target_and_filter_column_name(
parent_model: Type["Model"], target_model: Type["Model"], reverse: bool
parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, related: str,
) -> Tuple[Type["Model"], str]:
if reverse:
field = target_model.resolve_relation_field(target_model, parent_model)
field = parent_model.Meta.model_fields[related]
# field = target_model.resolve_relation_field(target_model, parent_model)
if issubclass(field, ormar.fields.ManyToManyField):
sub_field = target_model.resolve_relation_field(
field.through, parent_model

View File

@ -145,7 +145,7 @@ class PrefetchQuery:
)
def _get_filter_for_prefetch(
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool,
self, parent_model: Type["Model"], target_model: Type["Model"], reverse: bool, related: str,
) -> List:
ids = self._extract_required_ids(
parent_model=parent_model, target_model=target_model, reverse=reverse,
@ -155,7 +155,7 @@ class PrefetchQuery:
clause_target,
filter_column,
) = parent_model.get_clause_target_and_filter_column_name(
parent_model=parent_model, target_model=target_model, reverse=reverse
parent_model=parent_model, target_model=target_model, reverse=reverse, related=related
)
qryclause = QueryClause(
model_cls=clause_target, select_related=[], filter_clauses=[],
@ -246,7 +246,7 @@ class PrefetchQuery:
parent_model = target_model
filter_clauses = self._get_filter_for_prefetch(
parent_model=parent_model, target_model=target_field.to, reverse=reverse,
parent_model=parent_model, target_model=target_field.to, reverse=reverse, related=related
)
if not filter_clauses: # related field is empty
return

View File

@ -25,6 +25,7 @@ class Relation:
self,
manager: "RelationsManager",
type_: RelationType,
field_name: str,
to: Type["T"],
through: Type["T"] = None,
) -> None:
@ -34,8 +35,9 @@ class Relation:
self._to_remove: Set = set()
self.to: Type["T"] = to
self.through: Optional[Type["T"]] = through
self.field_name = field_name
self.related_models: Optional[Union[RelationProxy, "T"]] = (
RelationProxy(relation=self, type_=type_)
RelationProxy(relation=self, type_=type_, field_name=field_name)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None
)
@ -47,7 +49,7 @@ class Relation:
if i not in self._to_remove
]
self.related_models = RelationProxy(
relation=self, type_=self._type, data_=cleaned_data
relation=self, type_=self._type, field_name=self.field_name, data_=cleaned_data
)
relation_name = self._owner.resolve_relation_name(self._owner, self.to)
self._owner.__dict__[relation_name] = cleaned_data
@ -69,7 +71,7 @@ class Relation:
return None
def add(self, child: "T") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
relation_name = self.field_name
if self._type == RelationType.PRIMARY:
self.related_models = child
self._owner.__dict__[relation_name] = child
@ -84,7 +86,7 @@ class Relation:
self._owner.__dict__[relation_name] = rel
def remove(self, child: Union["NewBaseModel", Type["NewBaseModel"]]) -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
relation_name = self.field_name
if self._type == RelationType.PRIMARY:
if self.related_models == child:
self.related_models = None

View File

@ -36,6 +36,7 @@ class RelationsManager:
self._relations[field.name] = Relation(
manager=self,
type_=self._get_relation_type(field),
field_name=field.name,
to=field.to,
through=getattr(field, "through", None),
)
@ -64,15 +65,17 @@ class RelationsManager:
relation_name: str,
) -> None:
to_field: Type[BaseField] = child.Meta.model_fields[relation_name]
print('comming', child_name, relation_name)
(parent, child, child_name, to_name,) = get_relations_sides_and_names(
to_field, parent, child, child_name, virtual
)
print('adding', parent.get_name(), child.get_name(), child_name)
parent_relation = parent._orm._get(child_name)
if parent_relation:
parent_relation.add(child) # type: ignore
print('adding', child.get_name(), parent.get_name(), child_name)
child_relation = child._orm._get(to_name)
if child_relation:
child_relation.add(parent)

View File

@ -12,13 +12,16 @@ if TYPE_CHECKING: # pragma no cover
class RelationProxy(list):
def __init__(
self, relation: "Relation", type_: "RelationType", data_: Any = None
self, relation: "Relation", type_: "RelationType", field_name: str, data_: Any = None
) -> None:
super().__init__(data_ or ())
self.relation: "Relation" = relation
self.type_: "RelationType" = type_
self.field_name = field_name
self._owner: "Model" = self.relation.manager.owner
self.queryset_proxy = QuerysetProxy(relation=self.relation, type_=type_)
owner_field = self._owner.Meta.model_fields[self.field_name]
self.related_field_name = owner_field.related_name or self._owner.get_name() + 's'
def __getattribute__(self, item: str) -> Any:
if item in ["count", "clear"]:
@ -48,9 +51,8 @@ class RelationProxy(list):
)
def _set_queryset(self) -> "QuerySet":
related_field = self._owner.resolve_relation_field(
self.relation.to, self._owner
)
related_field_name = self.related_field_name
related_field = self.relation.to.Meta.model_fields[related_field_name]
pkname = self._owner.get_column_alias(self._owner.Meta.pkname)
self._check_if_model_saved()
kwargs = {f"{related_field.get_alias()}__{pkname}": self._owner.pk}
@ -70,11 +72,11 @@ class RelationProxy(list):
f"{item.get_name()} with given primary key!"
)
super().remove(item)
rel_name = item.resolve_relation_name(item, self._owner)
relation = item._orm._get(rel_name)
relation_name = self.related_field_name
relation = item._orm._get(relation_name)
if relation is None: # pragma nocover
raise ValueError(
f"{self._owner.get_name()} does not have relation {rel_name}"
f"{self._owner.get_name()} does not have relation {relation_name}"
)
relation.remove(self._owner)
self.relation.remove(item)
@ -82,20 +84,17 @@ class RelationProxy(list):
await self.queryset_proxy.delete_through_instance(item)
else:
if keep_reversed:
setattr(item, rel_name, None)
setattr(item, relation_name, None)
await item.update()
else:
await item.delete()
async def add(self, item: "Model") -> None:
relation_name = self.related_field_name
if self.type_ == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner)
setattr(item, rel_name, self._owner)
setattr(item, relation_name, self._owner)
else:
self._check_if_model_saved()
related_field = self._owner.resolve_relation_field(
self.relation.to, self._owner
)
setattr(item, related_field.name, self._owner)
setattr(item, relation_name, self._owner)
await item.update()

View File

@ -0,0 +1,104 @@
from typing import List, Optional
import databases
import sqlalchemy
from sqlalchemy import create_engine
import ormar
import pytest
from tests.settings import DATABASE_URL
db = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
class User(ormar.Model):
class Meta:
metadata = metadata
database = db
tablename = "test_users"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=50)
class Signup(ormar.Model):
class Meta:
metadata = metadata
database = db
tablename = "test_signup"
id: int = ormar.Integer(primary_key=True)
class Session(ormar.Model):
class Meta:
metadata = metadata
database = db
tablename = "test_sessions"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=255, index=True)
some_text: str = ormar.Text()
some_other_text: Optional[str] = ormar.Text(nullable=True)
teacher: Optional[User] = ormar.ForeignKey(
User, nullable=True, related_name="teaching"
)
students: Optional[List[User]] = ormar.ManyToMany(
User, through=Signup, related_name="attending"
)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_add_students():
for user_id in [1, 2, 3, 4, 5]:
await User.objects.create(name=f"User {user_id}")
for name, some_text, some_other_text in [
("Session 1", "Some text 1", "Some other text 1"),
("Session 2", "Some text 2", "Some other text 2"),
("Session 3", "Some text 3", "Some other text 3"),
("Session 4", "Some text 4", "Some other text 4"),
("Session 5", "Some text 5", "Some other text 5"),
]:
await Session(
name=name, some_text=some_text, some_other_text=some_other_text
).save()
s1 = await Session.objects.get(pk=1)
s2 = await Session.objects.get(pk=2)
users = {}
for i in range(1, 6):
user = await User.objects.get(pk=i)
users[f"user_{i}"] = user
if i % 2 == 0:
await s1.students.add(user)
else:
await s2.students.add(user)
assert len(s1.students) > 0
assert len(s2.students) > 0
user = await User.objects.select_related("attending").get(pk=1)
assert user.attending is not None
assert len(user.attending) > 0
query = Session.objects.prefetch_related(
[
"students",
"teacher",
]
)
sessions = await query.all()
assert len(sessions) == 5