Merge pull request #61 from collerek/fk_queryset

Expose QuerysetProxy on reverse ForeignKey, add more QuerySet methods to QuerysetProxy
This commit is contained in:
collerek
2020-12-01 16:51:57 +07:00
committed by GitHub
20 changed files with 961 additions and 139 deletions

View File

@ -154,7 +154,7 @@ assert len(tracks) == 1
* `bulk_create(objects: List[Model]) -> None` * `bulk_create(objects: List[Model]) -> None`
* `bulk_update(objects: List[Model], columns: List[str] = None) -> None` * `bulk_update(objects: List[Model], columns: List[str] = None) -> None`
* `delete(each: bool = False, **kwargs) -> int` * `delete(each: bool = False, **kwargs) -> int`
* `all(self, **kwargs) -> List[Optional[Model]]` * `all(**kwargs) -> List[Optional[Model]]`
* `filter(**kwargs) -> QuerySet` * `filter(**kwargs) -> QuerySet`
* `exclude(**kwargs) -> QuerySet` * `exclude(**kwargs) -> QuerySet`
* `select_related(related: Union[List, str]) -> QuerySet` * `select_related(related: Union[List, str]) -> QuerySet`

View File

@ -154,7 +154,7 @@ assert len(tracks) == 1
* `bulk_create(objects: List[Model]) -> None` * `bulk_create(objects: List[Model]) -> None`
* `bulk_update(objects: List[Model], columns: List[str] = None) -> None` * `bulk_update(objects: List[Model], columns: List[str] = None) -> None`
* `delete(each: bool = False, **kwargs) -> int` * `delete(each: bool = False, **kwargs) -> int`
* `all(self, **kwargs) -> List[Optional[Model]]` * `all(**kwargs) -> List[Optional[Model]]`
* `filter(**kwargs) -> QuerySet` * `filter(**kwargs) -> QuerySet`
* `exclude(**kwargs) -> QuerySet` * `exclude(**kwargs) -> QuerySet`
* `select_related(related: Union[List, str]) -> QuerySet` * `select_related(related: Union[List, str]) -> QuerySet`

View File

@ -176,7 +176,7 @@ Return number of rows deleted.
### all ### all
`all(self, **kwargs) -> List[Optional["Model"]]` `all(**kwargs) -> List[Optional["Model"]]`
Returns all rows from a database for given model for set filter options. Returns all rows from a database for given model for set filter options.
@ -212,7 +212,7 @@ You can use special filter suffix to change the filter operands:
* exact - like `album__name__exact='Malibu'` (exact match) * exact - like `album__name__exact='Malibu'` (exact match)
* iexact - like `album__name__iexact='malibu'` (exact match case insensitive) * iexact - like `album__name__iexact='malibu'` (exact match case insensitive)
* contains - like `album__name__conatins='Mal'` (sql like) * contains - like `album__name__contains='Mal'` (sql like)
* icontains - like `album__name__icontains='mal'` (sql like case insensitive) * icontains - like `album__name__icontains='mal'` (sql like case insensitive)
* in - like `album__name__in=['Malibu', 'Barclay']` (sql in) * in - like `album__name__in=['Malibu', 'Barclay']` (sql in)
* gt - like `position__gt=3` (sql >) * gt - like `position__gt=3` (sql >)

View File

@ -29,6 +29,83 @@ By default it's child (source) `Model` name + s, like courses in snippet below:
--8<-- "../docs_src/fields/docs001.py" --8<-- "../docs_src/fields/docs001.py"
``` ```
Reverse relation exposes API to manage related objects also from parent side.
##### add
Adding child model from parent side causes adding related model to currently loaded parent relation,
as well as sets child's model foreign key value and updates the model.
```python
department = await Department(name="Science").save()
course = Course(name="Math", completed=False) # note - not saved
await department.courses.add(course)
assert course.pk is not None # child model was saved
# relation on child model is set and FK column saved in db
assert courses.department == department
# relation on parent model is also set
assert department.courses[0] == course
```
!!!warning
If you want to add child model on related model the primary key value for parent model **has to exist in database**.
Otherwise ormar will raise RelationshipInstanceError as it cannot set child's ForeignKey column value
if parent model has no primary key value.
That means that in example above the department has to be saved before you can call `department.courses.add()`.
##### remove
Removal of the related model one by one.
In reverse relation calling `remove()` does not remove the child model, but instead nulls it ForeignKey value.
```python
# continuing from above
await department.courses.remove(course)
assert len(department.courses) == 0
# course still exists and was saved in remove
assert course.pk is not None
assert course.department is None
# to remove child from db
await course.delete()
```
But if you want to clear the relation and delete the child at the same time you can issue:
```python
# this will not only clear the relation
# but also delete related course from db
await department.courses.remove(course, keep_reversed=False)
```
##### clear
Removal of all related models in one call.
Like remove by default `clear()` nulls the ForeigKey column on child model (all, not matter if they are loaded or not).
```python
# nulls department column on all courses related to this department
await department.courses.clear()
```
If you want to remove the children altogether from the database, set `keep_reversed=False`
```python
# deletes from db all courses related to this department
await department.courses.clear(keep_reversed=False)
```
##### QuerysetProxy
Reverse relation exposes QuerysetProxy API that allows you to query related model like you would issue a normal Query.
To read which methods of QuerySet are available read below [querysetproxy][querysetproxy]
#### related_name #### related_name
But you can overwrite this name by providing `related_name` parameter like below: But you can overwrite this name by providing `related_name` parameter like below:
@ -94,7 +171,7 @@ Sqlalchemy column and Type are automatically taken from target `Model`.
* Sqlalchemy column: class of a target `Model` primary key column * Sqlalchemy column: class of a target `Model` primary key column
* Type (used for pydantic): type of a target `Model` * Type (used for pydantic): type of a target `Model`
####Defining `Models`: ####Defining `Models`
```Python ```Python
--8<-- "../docs_src/relations/docs002.py" --8<-- "../docs_src/relations/docs002.py"
@ -107,7 +184,7 @@ post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News") news = await Category.objects.create(name="News")
``` ```
#### Adding related models #### add
```python ```python
# Add a category to a post. # Add a category to a post.
@ -121,7 +198,110 @@ await news.posts.add(post)
Otherwise an IntegrityError will be raised by your database driver library. Otherwise an IntegrityError will be raised by your database driver library.
#### create() #### remove
Removal of the related model one by one.
Removes also the relation in the database.
```python
await news.posts.remove(post)
```
#### clear
Removal of all related models in one call.
Removes also the relation in the database.
```python
await news.posts.clear()
```
#### QuerysetProxy
Reverse relation exposes QuerysetProxy API that allows you to query related model like you would issue a normal Query.
To read which methods of QuerySet are available read below [querysetproxy][querysetproxy]
### QuerySetProxy
When access directly the related `ManyToMany` field as well as `ReverseForeignKey` returns the list of related models.
But at the same time it exposes subset of QuerySet API, so you can filter, create, select related etc related models directly from parent model.
!!!note
By default exposed QuerySet is already filtered to return only `Models` related to parent `Model`.
So if you issue `post.categories.all()` you will get all categories related to that post, not all in table.
!!!note
Note that when accessing QuerySet API methods through QuerysetProxy you don't
need to use `objects` attribute like in normal queries.
So note that it's `post.categories.all()` and **not** `post.categories.objects.all()`.
To learn more about available QuerySet methods visit [queries][queries]
!!!warning
Querying related models from ManyToMany cleans list of related models loaded on parent model:
Example: `post.categories.first()` will set post.categories to list of 1 related model -> the one returned by first()
Example 2: if post has 4 categories so `len(post.categories) == 4` calling `post.categories.limit(2).all()`
-> will load only 2 children and now `assert len(post.categories) == 2`
This happens for all QuerysetProxy methods returning data: `get`, `all` and `first` and in `get_or_create` if model already exists.
Note that value returned by `create` or created in `get_or_create` and `update_or_create`
if model does not exist will be added to relation list (not clearing it).
#### get
`get(**kwargs): -> Model`
To grab just one of related models filtered by name you can use `get(**kwargs)` method.
```python
# grab one category
assert news == await post.categories.get(name="News")
# note that method returns the category so you can grab this value
# but it also modifies list of related models in place
# so regardless of what was previously loaded on parent model
# now it has only one value -> just loaded with get() call
assert len(post.categories) == 1
assert post.categories[0] == news
```
!!!tip
Read more in queries documentation [get][get]
#### all
`all(**kwargs) -> List[Optional["Model"]]`
To get a list of related models use `all()` method.
Note that you can filter the queryset, select related, exclude fields etc. like in normal query.
```python
# with all Queryset methods - filtering, selecting columns, counting etc.
await news.posts.filter(title__contains="M2M").all()
await Category.objects.filter(posts__author=guido).get()
# columns models of many to many relation can be prefetched
news_posts = await news.posts.select_related("author").all()
assert news_posts[0].author == guido
```
!!!tip
Read more in queries documentation [all][all]
#### create
`create(**kwargs): -> Model`
Create related `Model` directly from parent `Model`. Create related `Model` directly from parent `Model`.
@ -134,64 +314,117 @@ assert len(await post.categories.all()) == 2
# newly created instance already have relation persisted in the database # newly created instance already have relation persisted in the database
``` ```
!!!note !!!tip
Note that when accessing QuerySet API methods through ManyToMany relation you don't Read more in queries documentation [create][create]
need to use objects attribute like in normal queries.
To learn more about available QuerySet methods visit [queries][queries]
#### remove() #### get_or_create
Removal of the related model one by one. `get_or_create(**kwargs) -> Model`
Removes also the relation in the database.
```python
await news.posts.remove(post)
```
#### clear()
Removal all related models in one call.
Removes also the relation in the database.
```python
await news.posts.clear()
```
#### Other queryset methods
When access directly the related `ManyToMany` field returns the list of related models.
But at the same time it exposes full QuerySet API, so you can filter, create, select related etc.
```python
# Many to many relation exposes a list of columns models
# and an API of the Queryset:
assert news == await post.categories.get(name="News")
# with all Queryset methods - filtering, selecting columns, counting etc.
await news.posts.filter(title__contains="M2M").all()
await Category.objects.filter(posts__author=guido).get()
# columns models of many to many relation can be prefetched
news_posts = await news.posts.select_related("author").all()
assert news_posts[0].author == guido
```
Currently supported methods are:
!!!tip !!!tip
To learn more about available QuerySet methods visit [queries][queries] Read more in queries documentation [get_or_create][get_or_create]
#### update_or_create
`update_or_create(**kwargs) -> Model`
!!!tip
Read more in queries documentation [update_or_create][update_or_create]
#### filter
`filter(**kwargs) -> QuerySet`
!!!tip
Read more in queries documentation [filter][filter]
#### exclude
`exclude(**kwargs) -> QuerySet`
!!!tip
Read more in queries documentation [exclude][exclude]
#### select_related
`select_related(related: Union[List, str]) -> QuerySet`
!!!tip
Read more in queries documentation [select_related][select_related]
#### prefetch_related
`prefetch_related(related: Union[List, str]) -> QuerySet`
!!!tip
Read more in queries documentation [prefetch_related][prefetch_related]
#### limit
`limit(limit_count: int) -> QuerySet`
!!!tip
Read more in queries documentation [limit][limit]
#### offset
`offset(offset: int) -> QuerySet`
!!!tip
Read more in queries documentation [offset][offset]
#### count
`count() -> int`
!!!tip
Read more in queries documentation [count][count]
#### exists
`exists() -> bool`
!!!tip
Read more in queries documentation [exists][exists]
#### fields
`fields(columns: Union[List, str, set, dict]) -> QuerySet`
!!!tip
Read more in queries documentation [fields][fields]
#### exclude_fields
`exclude_fields(columns: Union[List, str, set, dict]) -> QuerySet`
!!!tip
Read more in queries documentation [exclude_fields][exclude_fields]
#### order_by
`order_by(columns:Union[List, str]) -> QuerySet`
!!!tip
Read more in queries documentation [order_by][order_by]
##### get()
##### all()
##### filter()
##### select_related()
##### limit()
##### offset()
##### count()
##### exists()
[queries]: ./queries.md [queries]: ./queries.md
[querysetproxy]: ./relations.md#querysetproxy-methods
[get]: ./queries.md#get
[all]: ./queries.md#all
[create]: ./queries.md#create
[get_or_create]: ./queries.md#get_or_create
[update_or_create]: ./queries.md#update_or_create
[filter]: ./queries.md#filter
[exclude]: ./queries.md#exclude
[select_related]: ./queries.md#select_related
[prefetch_related]: ./queries.md#prefetch_related
[limit]: ./queries.md#limit
[offset]: ./queries.md#offset
[count]: ./queries.md#count
[exists]: ./queries.md#exists
[fields]: ./queries.md#fields
[exclude_fields]: ./queries.md#exclude_fields
[order_by]: ./queries.md#order_by

View File

@ -1,3 +1,15 @@
# 0.6.0
* **Breaking:** calling instance.load() when the instance row was deleted from db now raises ormar.NoMatch instead of ValueError
* **Breaking:** calling add and remove on ReverseForeignKey relation now updates the child model in db setting/removing fk column
* **Breaking:** ReverseForeignKey relation now exposes QuerySetProxy API like ManyToMany relation
* **Breaking:** querying related models from ManyToMany cleans list of related models loaded on parent model:
* Example: `post.categories.first()` will set post.categories to list of 1 related model -> the one returned by first()
* Example 2: if post has 4 categories so `len(post.categories) == 4` calling `post.categories.limit(2).all()` -> will load only 2 children and now `assert len(post.categories) == 2`
* Added `get_or_create`, `update_or_create`, `fields`, `exclude_fields`, `exclude`, `prefetch_related` and `order_by` to QuerySetProxy
so now you can use those methods directly from relation
* Update docs
# 0.5.5 # 0.5.5
* Fix for alembic autogenaration of migration `UUID` columns. It should just produce sqlalchemy `CHAR(32)` or `CHAR(36)` * Fix for alembic autogenaration of migration `UUID` columns. It should just produce sqlalchemy `CHAR(32)` or `CHAR(36)`

View File

@ -29,7 +29,7 @@ class Course(ormar.Model):
department: Optional[Department] = ormar.ForeignKey(Department) department: Optional[Department] = ormar.ForeignKey(Department)
department = Department(name="Science") department = await Department(name="Science").save()
course = Course(name="Math", completed=False, department=department) course = Course(name="Math", completed=False, department=department)
print(department.courses[0]) print(department.courses[0])

View File

@ -1,5 +1,5 @@
site_name: ormar site_name: ormar
site_description: An simple async ORM with fastapi in mind and pydantic validation. site_description: A simple async ORM with fastapi in mind and pydantic validation.
nav: nav:
- Overview: index.md - Overview: index.md
- Installation: install.md - Installation: install.md

View File

@ -30,7 +30,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType() Undefined = UndefinedType()
__version__ = "0.5.5" __version__ = "0.6.0"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",

View File

@ -6,6 +6,10 @@ class ModelDefinitionError(AsyncOrmException):
pass pass
class ModelError(AsyncOrmException):
pass
class ModelNotSet(AsyncOrmException): class ModelNotSet(AsyncOrmException):
pass pass

View File

@ -15,7 +15,7 @@ from typing import (
import sqlalchemy import sqlalchemy
import ormar.queryset # noqa I100 import ormar.queryset # noqa I100
from ormar.exceptions import ModelPersistenceError from ormar.exceptions import ModelPersistenceError, NoMatch
from ormar.fields.many_to_many import ManyToManyField from ormar.fields.many_to_many import ManyToManyField
from ormar.models import NewBaseModel # noqa I100 from ormar.models import NewBaseModel # noqa I100
from ormar.models.metaclass import ModelMeta from ormar.models.metaclass import ModelMeta
@ -286,9 +286,7 @@ class Model(NewBaseModel):
expr = self.Meta.table.select().where(self.pk_column == self.pk) expr = self.Meta.table.select().where(self.pk_column == self.pk)
row = await self.Meta.database.fetch_one(expr) row = await self.Meta.database.fetch_one(expr)
if not row: # pragma nocover if not row: # pragma nocover
raise ValueError( raise NoMatch("Instance was deleted from database and cannot be refreshed")
"Instance was deleted from database and cannot be refreshed"
)
kwargs = dict(row) kwargs = dict(row)
kwargs = self.translate_aliases_to_columns(kwargs) kwargs = self.translate_aliases_to_columns(kwargs)
self.from_dict(kwargs) self.from_dict(kwargs)

View File

@ -7,6 +7,7 @@ from typing import (
Dict, Dict,
List, List,
Mapping, Mapping,
MutableSequence,
Optional, Optional,
Sequence, Sequence,
Set, Set,
@ -22,6 +23,7 @@ import sqlalchemy
from pydantic import BaseModel from pydantic import BaseModel
import ormar # noqa I100 import ormar # noqa I100
from ormar.exceptions import ModelError
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.excludable import Excludable from ormar.models.excludable import Excludable
@ -93,16 +95,21 @@ class NewBaseModel(
if "pk" in kwargs: if "pk" in kwargs:
kwargs[self.Meta.pkname] = kwargs.pop("pk") kwargs[self.Meta.pkname] = kwargs.pop("pk")
# build the models to set them and validate but don't register # build the models to set them and validate but don't register
new_kwargs = { try:
k: self._convert_json( new_kwargs = {
k, k: self._convert_json(
self.Meta.model_fields[k].expand_relationship( k,
v, self, to_register=False self.Meta.model_fields[k].expand_relationship(
), v, self, to_register=False
"dumps", ),
"dumps",
)
for k, v in kwargs.items()
}
except KeyError as e:
raise ModelError(
f"Unknown field '{e.args[0]}' for model {self.get_name(lower=False)}"
) )
for k, v in kwargs.items()
}
values, fields_set, validation_error = pydantic.validate_model( values, fields_set, validation_error = pydantic.validate_model(
self, new_kwargs # type: ignore self, new_kwargs # type: ignore
@ -249,7 +256,9 @@ class NewBaseModel(
@staticmethod @staticmethod
def _extract_nested_models_from_list( def _extract_nested_models_from_list(
models: List, include: Union[Set, Dict, None], exclude: Union[Set, Dict, None], models: MutableSequence,
include: Union[Set, Dict, None],
exclude: Union[Set, Dict, None],
) -> List: ) -> List:
result = [] result = []
for model in models: for model in models:
@ -282,7 +291,7 @@ class NewBaseModel(
if self.Meta.model_fields[field].virtual and nested: if self.Meta.model_fields[field].virtual and nested:
continue continue
nested_model = getattr(self, field) nested_model = getattr(self, field)
if isinstance(nested_model, list): if isinstance(nested_model, MutableSequence):
dict_instance[field] = self._extract_nested_models_from_list( dict_instance[field] = self._extract_nested_models_from_list(
models=nested_model, models=nested_model,
include=self._skip_ellipsis(include, field), include=self._skip_ellipsis(include, field),
@ -308,7 +317,7 @@ class NewBaseModel(
exclude_unset: bool = False, exclude_unset: bool = False,
exclude_defaults: bool = False, exclude_defaults: bool = False,
exclude_none: bool = False, exclude_none: bool = False,
nested: bool = False nested: bool = False,
) -> "DictStrAny": # noqa: A003' ) -> "DictStrAny": # noqa: A003'
dict_instance = super().dict( dict_instance = super().dict(
include=include, include=include,

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, Union from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union
try: try:
from typing import Protocol from typing import Protocol
@ -6,14 +6,21 @@ except ImportError: # pragma: nocover
from typing_extensions import Protocol # type: ignore from typing_extensions import Protocol # type: ignore
if TYPE_CHECKING: # noqa: C901; #pragma nocover if TYPE_CHECKING: # noqa: C901; #pragma nocover
from ormar import QuerySet, Model from ormar import Model
from ormar.relations.querysetproxy import QuerysetProxy
class QuerySetProtocol(Protocol): # pragma: nocover class QuerySetProtocol(Protocol): # pragma: nocover
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003, A001 def filter(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
... ...
def select_related(self, related: Union[List, str]) -> "QuerySet": def exclude(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
...
def select_related(self, related: Union[List, str]) -> "QuerysetProxy":
...
def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy":
... ...
async def exists(self) -> bool: async def exists(self) -> bool:
@ -25,10 +32,10 @@ class QuerySetProtocol(Protocol): # pragma: nocover
async def clear(self) -> int: async def clear(self) -> int:
... ...
def limit(self, limit_count: int) -> "QuerySet": def limit(self, limit_count: int) -> "QuerysetProxy":
... ...
def offset(self, offset: int) -> "QuerySet": def offset(self, offset: int) -> "QuerysetProxy":
... ...
async def first(self, **kwargs: Any) -> "Model": async def first(self, **kwargs: Any) -> "Model":
@ -44,3 +51,18 @@ class QuerySetProtocol(Protocol): # pragma: nocover
async def create(self, **kwargs: Any) -> "Model": async def create(self, **kwargs: Any) -> "Model":
... ...
async def get_or_create(self, **kwargs: Any) -> "Model":
...
async def update_or_create(self, **kwargs: Any) -> "Model":
...
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
...
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
...
def order_by(self, columns: Union[List, str]) -> "QuerysetProxy":
...

View File

@ -280,7 +280,9 @@ class QuerySet:
return await self.database.fetch_val(expr) return await self.database.fetch_val(expr)
async def update(self, each: bool = False, **kwargs: Any) -> int: async def update(self, each: bool = False, **kwargs: Any) -> int:
self_fields = self.model.extract_db_own_fields() self_fields = self.model.extract_db_own_fields().union(
self.model.extract_related_names()
)
updates = {k: v for k, v in kwargs.items() if k in self_fields} updates = {k: v for k, v in kwargs.items() if k in self_fields}
updates = self.model.translate_columns_to_aliases(updates) updates = self.model.translate_columns_to_aliases(updates)
if not each and not self.filter_clauses: if not each and not self.filter_clauses:

View File

@ -1,4 +1,15 @@
from typing import Any, List, Optional, Sequence, TYPE_CHECKING, TypeVar, Union from typing import (
Any,
Dict,
List,
MutableSequence,
Optional,
Sequence,
Set,
TYPE_CHECKING,
TypeVar,
Union,
)
import ormar import ormar
@ -6,6 +17,7 @@ if TYPE_CHECKING: # pragma no cover
from ormar.relations import Relation from ormar.relations import Relation
from ormar.models import Model from ormar.models import Model
from ormar.queryset import QuerySet from ormar.queryset import QuerySet
from ormar import RelationType
T = TypeVar("T", bound=Model) T = TypeVar("T", bound=Model)
@ -14,9 +26,17 @@ class QuerysetProxy(ormar.QuerySetProtocol):
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
relation: "Relation" relation: "Relation"
def __init__(self, relation: "Relation") -> None: def __init__(
self, relation: "Relation", type_: "RelationType", qryset: "QuerySet" = None
) -> None:
self.relation: Relation = relation self.relation: Relation = relation
self._queryset: Optional["QuerySet"] = None self._queryset: Optional["QuerySet"] = qryset
self.type_: "RelationType" = type_
self._owner: "Model" = self.relation.manager.owner
self.related_field = self._owner.resolve_relation_field(
self.relation.to, self._owner
)
self.owner_pk_value = self._owner.pk
@property @property
def queryset(self) -> "QuerySet": def queryset(self) -> "QuerySet":
@ -30,7 +50,7 @@ class QuerysetProxy(ormar.QuerySetProtocol):
def _assign_child_to_parent(self, child: Optional["T"]) -> None: def _assign_child_to_parent(self, child: Optional["T"]) -> None:
if child: if child:
owner = self.relation._owner owner = self._owner
rel_name = owner.resolve_relation_name(owner, child) rel_name = owner.resolve_relation_name(owner, child)
setattr(owner, rel_name, child) setattr(owner, rel_name, child)
@ -42,62 +62,122 @@ class QuerysetProxy(ormar.QuerySetProtocol):
assert isinstance(child, ormar.Model) assert isinstance(child, ormar.Model)
self._assign_child_to_parent(child) self._assign_child_to_parent(child)
def _clean_items_on_load(self) -> None:
if isinstance(self.relation.related_models, MutableSequence):
for item in self.relation.related_models[:]:
self.relation.remove(item)
async def create_through_instance(self, child: "T") -> None: async def create_through_instance(self, child: "T") -> None:
queryset = ormar.QuerySet(model_cls=self.relation.through) queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name() owner_column = self._owner.get_name()
child_column = child.get_name() child_column = child.get_name()
kwargs = {owner_column: self.relation._owner, child_column: child} kwargs = {owner_column: self._owner, child_column: child}
await queryset.create(**kwargs) await queryset.create(**kwargs)
async def delete_through_instance(self, child: "T") -> None: async def delete_through_instance(self, child: "T") -> None:
queryset = ormar.QuerySet(model_cls=self.relation.through) queryset = ormar.QuerySet(model_cls=self.relation.through)
owner_column = self.relation._owner.get_name() owner_column = self._owner.get_name()
child_column = child.get_name() child_column = child.get_name()
kwargs = {owner_column: self.relation._owner, child_column: child} kwargs = {owner_column: self._owner, child_column: child}
link_instance = await queryset.filter(**kwargs).get() # type: ignore link_instance = await queryset.filter(**kwargs).get() # type: ignore
await link_instance.delete() await link_instance.delete()
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
return self.queryset.filter(**kwargs)
def select_related(self, related: Union[List, str]) -> "QuerySet":
return self.queryset.select_related(related)
async def exists(self) -> bool: async def exists(self) -> bool:
return await self.queryset.exists() return await self.queryset.exists()
async def count(self) -> int: async def count(self) -> int:
return await self.queryset.count() return await self.queryset.count()
async def clear(self) -> int: async def clear(self, keep_reversed: bool = True) -> int:
queryset = ormar.QuerySet(model_cls=self.relation.through) if self.type_ == ormar.RelationType.MULTIPLE:
owner_column = self.relation._owner.get_name() queryset = ormar.QuerySet(model_cls=self.relation.through)
kwargs = {owner_column: self.relation._owner} owner_column = self._owner.get_name()
else:
queryset = ormar.QuerySet(model_cls=self.relation.to)
owner_column = self.related_field.name
kwargs = {owner_column: self._owner}
self._clean_items_on_load()
if keep_reversed and self.type_ == ormar.RelationType.REVERSE:
update_kwrgs = {f"{owner_column}": None}
return await queryset.filter(_exclude=False, **kwargs).update(
each=False, **update_kwrgs
)
return await queryset.delete(**kwargs) # type: ignore return await queryset.delete(**kwargs) # type: ignore
def limit(self, limit_count: int) -> "QuerySet":
return self.queryset.limit(limit_count)
def offset(self, offset: int) -> "QuerySet":
return self.queryset.offset(offset)
async def first(self, **kwargs: Any) -> "Model": async def first(self, **kwargs: Any) -> "Model":
first = await self.queryset.first(**kwargs) first = await self.queryset.first(**kwargs)
self._clean_items_on_load()
self._register_related(first) self._register_related(first)
return first return first
async def get(self, **kwargs: Any) -> "Model": async def get(self, **kwargs: Any) -> "Model":
get = await self.queryset.get(**kwargs) get = await self.queryset.get(**kwargs)
self._clean_items_on_load()
self._register_related(get) self._register_related(get)
return get return get
async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003 async def all(self, **kwargs: Any) -> Sequence[Optional["Model"]]: # noqa: A003
all_items = await self.queryset.all(**kwargs) all_items = await self.queryset.all(**kwargs)
self._clean_items_on_load()
self._register_related(all_items) self._register_related(all_items)
return all_items return all_items
async def create(self, **kwargs: Any) -> "Model": async def create(self, **kwargs: Any) -> "Model":
create = await self.queryset.create(**kwargs) if self.type_ == ormar.RelationType.REVERSE:
self._register_related(create) kwargs[self.related_field.name] = self._owner
await self.create_through_instance(create) created = await self.queryset.create(**kwargs)
return create self._register_related(created)
if self.type_ == ormar.RelationType.MULTIPLE:
await self.create_through_instance(created)
return created
async def get_or_create(self, **kwargs: Any) -> "Model":
try:
return await self.get(**kwargs)
except ormar.NoMatch:
return await self.create(**kwargs)
async def update_or_create(self, **kwargs: Any) -> "Model":
pk_name = self.queryset.model_meta.pkname
if "pk" in kwargs:
kwargs[pk_name] = kwargs.pop("pk")
if pk_name not in kwargs or kwargs.get(pk_name) is None:
return await self.create(**kwargs)
model = await self.queryset.get(pk=kwargs[pk_name])
return await model.update(**kwargs)
def filter(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
queryset = self.queryset.filter(**kwargs)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def exclude(self, **kwargs: Any) -> "QuerysetProxy": # noqa: A003, A001
queryset = self.queryset.exclude(**kwargs)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def select_related(self, related: Union[List, str]) -> "QuerysetProxy":
queryset = self.queryset.select_related(related)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def prefetch_related(self, related: Union[List, str]) -> "QuerysetProxy":
queryset = self.queryset.prefetch_related(related)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def limit(self, limit_count: int) -> "QuerysetProxy":
queryset = self.queryset.limit(limit_count)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def offset(self, offset: int) -> "QuerysetProxy":
queryset = self.queryset.offset(offset)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
queryset = self.queryset.fields(columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerysetProxy":
queryset = self.queryset.exclude_fields(columns=columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)
def order_by(self, columns: Union[List, str]) -> "QuerysetProxy":
queryset = self.queryset.order_by(columns)
return self.__class__(relation=self.relation, type_=self.type_, qryset=queryset)

View File

@ -34,7 +34,7 @@ class Relation:
self.to: Type["T"] = to self.to: Type["T"] = to
self.through: Optional[Type["T"]] = through self.through: Optional[Type["T"]] = through
self.related_models: Optional[Union[RelationProxy, "T"]] = ( self.related_models: Optional[Union[RelationProxy, "T"]] = (
RelationProxy(relation=self) RelationProxy(relation=self, type_=type_)
if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE)
else None else None
) )

View File

@ -65,8 +65,6 @@ class RelationsManager:
parent_relation = parent._orm._get(child_name) parent_relation = parent._orm._get(child_name)
if parent_relation: if parent_relation:
# print('missing', child_name)
# parent_relation = register_missing_relation(parent, child, child_name)
parent_relation.add(child) # type: ignore parent_relation.add(child) # type: ignore
child_relation = child._orm._get(to_name) child_relation = child._orm._get(to_name)

View File

@ -5,17 +5,18 @@ from ormar.exceptions import NoMatch, RelationshipInstanceError
from ormar.relations.querysetproxy import QuerysetProxy from ormar.relations.querysetproxy import QuerysetProxy
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model, RelationType
from ormar.relations import Relation from ormar.relations import Relation
from ormar.queryset import QuerySet from ormar.queryset import QuerySet
class RelationProxy(list): class RelationProxy(list):
def __init__(self, relation: "Relation") -> None: def __init__(self, relation: "Relation", type_: "RelationType") -> None:
super(RelationProxy, self).__init__() super().__init__()
self.relation: Relation = relation self.relation: "Relation" = relation
self.type_: "RelationType" = type_
self._owner: "Model" = self.relation.manager.owner self._owner: "Model" = self.relation.manager.owner
self.queryset_proxy = QuerysetProxy(relation=self.relation) self.queryset_proxy = QuerysetProxy(relation=self.relation, type_=type_)
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:
if item in ["count", "clear"]: if item in ["count", "clear"]:
@ -37,23 +38,30 @@ class RelationProxy(list):
and self.queryset_proxy.queryset is not None and self.queryset_proxy.queryset is not None
) )
def _set_queryset(self) -> "QuerySet": def _check_if_model_saved(self) -> None:
owner_table = self.relation._owner.Meta.tablename pk_value = self._owner.pk
pkname = self.relation._owner.get_column_alias(self.relation._owner.Meta.pkname)
pk_value = self.relation._owner.pk
if not pk_value: if not pk_value:
raise RelationshipInstanceError( raise RelationshipInstanceError(
"You cannot query many to many relationship on unsaved model." "You cannot query relationships from unsaved model."
) )
kwargs = {f"{owner_table}__{pkname}": pk_value}
def _set_queryset(self) -> "QuerySet":
related_field = self._owner.resolve_relation_field(
self.relation.to, self._owner
)
pkname = self._owner.get_column_alias(self._owner.Meta.pkname)
self._check_if_model_saved()
kwargs = {f"{related_field.get_alias()}__{pkname}": self._owner.pk}
queryset = ( queryset = (
ormar.QuerySet(model_cls=self.relation.to) ormar.QuerySet(model_cls=self.relation.to)
.select_related(owner_table) .select_related(related_field.name)
.filter(**kwargs) .filter(**kwargs)
) )
return queryset return queryset
async def remove(self, item: "Model") -> None: # type: ignore async def remove( # type: ignore
self, item: "Model", keep_reversed: bool = True
) -> None:
if item not in self: if item not in self:
raise NoMatch( raise NoMatch(
f"Object {self._owner.get_name()} has no " f"Object {self._owner.get_name()} has no "
@ -67,14 +75,25 @@ class RelationProxy(list):
f"{self._owner.get_name()} does not have relation {rel_name}" f"{self._owner.get_name()} does not have relation {rel_name}"
) )
relation.remove(self._owner) relation.remove(self._owner)
if self.relation._type == ormar.RelationType.MULTIPLE: self.relation.remove(item)
if self.type_ == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.delete_through_instance(item) await self.queryset_proxy.delete_through_instance(item)
else:
def append(self, item: "Model") -> None: if keep_reversed:
super().append(item) setattr(item, rel_name, None)
await item.update()
else:
await item.delete()
async def add(self, item: "Model") -> None: async def add(self, item: "Model") -> None:
if self.relation._type == ormar.RelationType.MULTIPLE: if self.type_ == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item) await self.queryset_proxy.create_through_instance(item)
rel_name = item.resolve_relation_name(item, self._owner) rel_name = item.resolve_relation_name(item, self._owner)
setattr(item, rel_name, self._owner) setattr(item, rel_name, self._owner)
else:
self._check_if_model_saved()
related_field = self._owner.resolve_relation_field(
self.relation.to, self._owner
)
setattr(item, related_field.name, self._owner)
await item.update()

View File

@ -9,7 +9,7 @@ import pytest
import sqlalchemy import sqlalchemy
import ormar import ormar
from ormar.exceptions import QueryDefinitionError, NoMatch from ormar.exceptions import QueryDefinitionError, NoMatch, ModelError
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
@ -117,6 +117,11 @@ def test_model_class():
assert isinstance(User.Meta.table, sqlalchemy.Table) assert isinstance(User.Meta.table, sqlalchemy.Table)
def test_wrong_field_name():
with pytest.raises(ModelError):
User(non_existing_pk=1)
def test_model_pk(): def test_model_pk():
user = User(pk=1) user = User(pk=1)
assert user.pk == 1 assert user.pk == 1

View File

@ -0,0 +1,182 @@
import asyncio
from typing import List, Optional, Union
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Subject(ormar.Model):
class Meta:
tablename = "subjects"
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=80)
class Author(ormar.Model):
class Meta:
tablename = "authors"
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
first_name: str = ormar.String(max_length=80)
last_name: str = ormar.String(max_length=80)
class Category(ormar.Model):
class Meta:
tablename = "categories"
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=40)
sort_order: int = ormar.Integer(nullable=True)
subject: Optional[Subject] = ormar.ForeignKey(Subject)
class PostCategory(ormar.Model):
class Meta:
tablename = "posts_categories"
database = database
metadata = metadata
class Post(ormar.Model):
class Meta:
tablename = "posts"
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=200)
categories: Optional[Union[Category, List[Category]]] = ormar.ManyToMany(
Category, through=PostCategory
)
author: Optional[Author] = ormar.ForeignKey(Author)
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_queryset_methods():
async with database:
async with database.transaction(force_rollback=True):
guido = await Author.objects.create(
first_name="Guido", last_name="Van Rossum"
)
subject = await Subject(name="Random").save()
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(
name="News", sort_order=1, subject=subject
)
breaking = await Category.objects.create(
name="Breaking", sort_order=3, subject=subject
)
# Add a category to a post.
await post.categories.add(news)
await post.categories.add(breaking)
category = await post.categories.get_or_create(name="News")
assert category == news
assert len(post.categories) == 1
category = await post.categories.get_or_create(name="Breaking News")
assert category != breaking
assert category.pk is not None
assert len(post.categories) == 2
await post.categories.update_or_create(pk=category.pk, name="Urgent News")
assert len(post.categories) == 2
cat = await post.categories.get_or_create(name="Urgent News")
assert cat.pk == category.pk
assert len(post.categories) == 1
await post.categories.remove(cat)
await cat.delete()
assert len(post.categories) == 0
category = await post.categories.update_or_create(
name="Weather News", sort_order=2, subject=subject
)
assert category.pk is not None
assert category.posts[0] == post
assert len(post.categories) == 1
categories = await post.categories.all()
assert len(categories) == 3 == len(post.categories)
assert await post.categories.exists()
assert 3 == await post.categories.count()
categories = await post.categories.limit(2).all()
assert len(categories) == 2 == len(post.categories)
categories2 = await post.categories.limit(2).offset(1).all()
assert len(categories2) == 2 == len(post.categories)
assert categories != categories2
categories = await post.categories.order_by("-sort_order").all()
assert len(categories) == 3 == len(post.categories)
assert post.categories[2].name == "News"
assert post.categories[0].name == "Breaking"
categories = await post.categories.exclude(name__icontains="news").all()
assert len(categories) == 1 == len(post.categories)
assert post.categories[0].name == "Breaking"
categories = (
await post.categories.filter(name__icontains="news")
.order_by("-name")
.all()
)
assert len(categories) == 2 == len(post.categories)
assert post.categories[0].name == "Weather News"
assert post.categories[1].name == "News"
categories = await post.categories.fields("name").all()
assert len(categories) == 3 == len(post.categories)
for cat in post.categories:
assert cat.sort_order is None
categories = await post.categories.exclude_fields("sort_order").all()
assert len(categories) == 3 == len(post.categories)
for cat in post.categories:
assert cat.sort_order is None
assert cat.subject.name is None
categories = await post.categories.select_related("subject").all()
assert len(categories) == 3 == len(post.categories)
for cat in post.categories:
assert cat.subject.name is not None
categories = await post.categories.prefetch_related("subject").all()
assert len(categories) == 3 == len(post.categories)
for cat in post.categories:
assert cat.subject.name is not None

View File

@ -0,0 +1,258 @@
from typing import Optional
import databases
import pytest
import sqlalchemy
import ormar
from ormar import NoMatch
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Album(ormar.Model):
class Meta:
tablename = "albums"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
is_best_seller: bool = ormar.Boolean(default=False)
class Writer(ormar.Model):
class Meta:
tablename = "writers"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Track(ormar.Model):
class Meta:
tablename = "tracks"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album)
title: str = ormar.String(max_length=100)
position: int = ormar.Integer()
play_count: int = ormar.Integer(nullable=True)
written_by: Optional[Writer] = ormar.ForeignKey(Writer)
async def get_sample_data():
album = await Album(name="Malibu").save()
writer1 = await Writer.objects.create(name="John")
writer2 = await Writer.objects.create(name="Sue")
track1 = await Track(
album=album, title="The Bird", position=1, play_count=30, written_by=writer1
).save()
track2 = await Track(
album=album,
title="Heart don't stand a chance",
position=2,
play_count=20,
written_by=writer2,
).save()
tracks3 = await Track(
album=album, title="The Waters", position=3, play_count=10, written_by=writer1
).save()
return album, [track1, track2, tracks3]
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_quering_by_reverse_fk():
async with database:
async with database.transaction(force_rollback=True):
sample_data = await get_sample_data()
track1 = sample_data[1][0]
album = await Album.objects.first()
assert await album.tracks.exists()
assert await album.tracks.count() == 3
track = await album.tracks.get_or_create(
title="The Bird", position=1, play_count=30
)
assert track == track1
assert len(album.tracks) == 1
track = await album.tracks.get_or_create(
title="The Bird2", position=4, play_count=5
)
assert track != track1
assert track.pk is not None
assert len(album.tracks) == 2
await album.tracks.update_or_create(pk=track.pk, play_count=50)
assert len(album.tracks) == 2
track = await album.tracks.get_or_create(title="The Bird2")
assert track.play_count == 50
assert len(album.tracks) == 1
await album.tracks.remove(track)
assert track.album is None
await track.delete()
assert len(album.tracks) == 0
track6 = await album.tracks.update_or_create(
title="The Bird3", position=4, play_count=5
)
assert track6.pk is not None
assert track6.play_count == 5
assert len(album.tracks) == 1
await album.tracks.remove(track6)
assert track6.album is None
await track6.delete()
assert len(album.tracks) == 0
@pytest.mark.asyncio
async def test_getting():
async with database:
async with database.transaction(force_rollback=True):
sample_data = await get_sample_data()
album = sample_data[0]
track1 = await album.tracks.fields(["album", "title", "position"]).get(
title="The Bird"
)
track2 = await album.tracks.exclude_fields("play_count").get(
title="The Bird"
)
for track in [track1, track2]:
assert track.title == "The Bird"
assert track.album == album
assert track.play_count is None
assert len(album.tracks) == 1
tracks = await album.tracks.all()
assert len(tracks) == 3
assert len(album.tracks) == 3
tracks = await album.tracks.order_by("play_count").all()
assert len(tracks) == 3
assert tracks[0].title == "The Waters"
assert tracks[2].title == "The Bird"
assert len(album.tracks) == 3
track = await album.tracks.create(
title="The Bird Fly Away", position=4, play_count=10
)
assert track.title == "The Bird Fly Away"
assert track.position == 4
assert track.album == album
assert len(album.tracks) == 4
tracks = await album.tracks.all()
assert len(tracks) == 4
tracks = await album.tracks.limit(2).all()
assert len(tracks) == 2
tracks2 = await album.tracks.limit(2).offset(2).all()
assert len(tracks2) == 2
assert tracks != tracks2
tracks3 = await album.tracks.filter(play_count__lt=15).all()
assert len(tracks3) == 2
tracks4 = await album.tracks.exclude(play_count__lt=15).all()
assert len(tracks4) == 2
assert tracks3 != tracks4
assert len(album.tracks) == 2
await album.tracks.clear()
tracks = await album.tracks.all()
assert len(tracks) == 0
assert len(album.tracks) == 0
still_tracks = await Track.objects.all()
assert len(still_tracks) == 4
for track in still_tracks:
assert track.album is None
@pytest.mark.asyncio
async def test_cleaning_related():
async with database:
async with database.transaction(force_rollback=True):
sample_data = await get_sample_data()
album = sample_data[0]
await album.tracks.clear(keep_reversed=False)
tracks = await album.tracks.all()
assert len(tracks) == 0
assert len(album.tracks) == 0
no_tracks = await Track.objects.all()
assert len(no_tracks) == 0
@pytest.mark.asyncio
async def test_loading_related():
async with database:
async with database.transaction(force_rollback=True):
sample_data = await get_sample_data()
album = sample_data[0]
tracks = await album.tracks.select_related("written_by").all()
assert len(tracks) == 3
assert len(album.tracks) == 3
for track in tracks:
assert track.written_by is not None
tracks = await album.tracks.prefetch_related("written_by").all()
assert len(tracks) == 3
assert len(album.tracks) == 3
for track in tracks:
assert track.written_by is not None
@pytest.mark.asyncio
async def test_adding_removing():
async with database:
async with database.transaction(force_rollback=True):
sample_data = await get_sample_data()
album = sample_data[0]
track_new = await Track(title="Rainbow", position=5, play_count=300).save()
await album.tracks.add(track_new)
assert track_new.album == album
assert len(album.tracks) == 4
track_check = await Track.objects.get(title="Rainbow")
assert track_check.album == album
await album.tracks.remove(track_new)
assert track_new.album is None
assert len(album.tracks) == 3
track1 = album.tracks[0]
await album.tracks.remove(track1, keep_reversed=False)
with pytest.raises(NoMatch):
await track1.load()
track_test = await Track.objects.get(title="Rainbow")
assert track_test.album is None