Merge pull request #140 from collerek/fix_queryset_alias

Bug fixes and optimization
This commit is contained in:
collerek
2021-04-06 14:25:24 +02:00
committed by GitHub
104 changed files with 1065 additions and 174 deletions

View File

@ -14,6 +14,7 @@ jobs:
tests: tests:
name: "Python ${{ matrix.python-version }}" name: "Python ${{ matrix.python-version }}"
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: github.event_name == 'push' || github.event.pull_request.head.repo.full_name != 'collerek/ormar'
strategy: strategy:
matrix: matrix:
python-version: [3.6, 3.7, 3.8, 3.9] python-version: [3.6, 3.7, 3.8, 3.9]

View File

@ -10,6 +10,13 @@ Each model instance have a set of methods to `save`, `update` or `load` itself.
Available methods are described below. Available methods are described below.
## `pydantic` methods
Note that each `ormar.Model` is also a `pydantic.BaseModel`, so all `pydantic` methods are also available on a model,
especially `dict()` and `json()` methods that can also accept `exclude`, `include` and other parameters.
To read more check [pydantic][pydantic] documentation
## load ## load
By default when you query a table without prefetching related models, the ormar will still construct By default when you query a table without prefetching related models, the ormar will still construct
@ -81,7 +88,7 @@ await track.save() # will raise integrity error as pk is populated
## update ## update
`update(**kwargs) -> self` `update(_columns: List[str] = None, **kwargs) -> self`
You can update models by using `QuerySet.update()` method or by updating your model attributes (fields) and calling `update()` method. You can update models by using `QuerySet.update()` method or by updating your model attributes (fields) and calling `update()` method.
@ -94,6 +101,42 @@ track = await Track.objects.get(name='The Bird')
await track.update(name='The Bird Strikes Again') await track.update(name='The Bird Strikes Again')
``` ```
To update only selected columns from model into the database provide a list of columns that should be updated to `_columns` argument.
In example:
```python
class Movie(ormar.Model):
class Meta:
tablename = "movies"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="title")
year: int = ormar.Integer()
profit: float = ormar.Float()
terminator = await Movie(name='Terminator', year=1984, profit=0.078).save()
terminator.name = "Terminator 2"
terminator.year = 1991
terminator.profit = 0.520
# update only name
await terminator.update(_columns=["name"])
# note that terminator instance was not reloaded so
assert terminator.year == 1991
# but once you load the data from db you see it was not updated
await terminator.load()
assert terminator.year == 1984
```
!!!warning
Note that `update()` does not refresh the instance of the Model, so if you change more columns than you pass in `_columns` list your Model instance will have different values than the database!
## upsert ## upsert
`upsert(**kwargs) -> self` `upsert(**kwargs) -> self`
@ -127,7 +170,7 @@ await track.delete() # will delete the model from database
## save_related ## save_related
`save_related(follow: bool = False) -> None` `save_related(follow: bool = False, save_all: bool = False, exclude=Optional[Union[Set, Dict]]) -> None`
Method goes through all relations of the `Model` on which the method is called, Method goes through all relations of the `Model` on which the method is called,
and calls `upsert()` method on each model that is **not** saved. and calls `upsert()` method on each model that is **not** saved.
@ -138,17 +181,28 @@ By default the `save_related` method saved only models that are directly related
But you can specify the `follow=True` parameter to traverse through nested models and save all of them in the relation tree. But you can specify the `follow=True` parameter to traverse through nested models and save all of them in the relation tree.
By default save_related saves only model that has not `saved` status, meaning that they were modified in current scope.
If you want to force saving all of the related methods use `save_all=True` flag, which will upsert all related models, regardless of their save status.
If you want to skip saving some of the relations you can pass `exclude` parameter.
`Exclude` can be a set of own model relations,
or it can be a dictionary that can also contain nested items.
!!!note
Note that `exclude` parameter in `save_related` accepts only relation fields names, so
if you pass any other fields they will be saved anyway
!!!note
To read more about the structure of possible values passed to `exclude` check `Queryset.fields` method documentation.
!!!warning !!!warning
To avoid circular updates with `follow=True` set, `save_related` keeps a set of already visited Models, To avoid circular updates with `follow=True` set, `save_related` keeps a set of already visited Models,
and won't perform nested `save_related` on Models that were already visited. and won't perform nested `save_related` on Models that were already visited.
So if you have a diamond or circular relations types you need to perform the updates in a manual way. So if you have a diamond or circular relations types you need to perform the updates in a manual way.
```python
# in example like this the second Street (coming from City) won't be save_related, so ZipCode won't be updated
Street -> District -> City -> Street -> ZipCode
```
[fields]: ../fields.md [fields]: ../fields.md
[relations]: ../relations/index.md [relations]: ../relations/index.md
[queries]: ../queries/index.md [queries]: ../queries/index.md

View File

@ -1,3 +1,40 @@
# 0.10.2
## ✨ Features
* `Model.save_related(follow=False)` now accept also two additional arguments: `Model.save_related(follow=False, save_all=False, exclude=None)`.
* `save_all:bool` -> By default (so with `save_all=False`) `ormar` only upserts models that are not saved (so new or updated ones),
with `save_all=True` all related models are saved, regardless of `saved` status, which might be useful if updated
models comes from api call, so are not changed in the backend.
* `exclude: Union[Set, Dict, None]` -> set/dict of relations to exclude from save, those relation won't be saved even with `follow=True` and `save_all=True`.
To exclude nested relations pass a nested dictionary like: `exclude={"child":{"sub_child": {"exclude_sub_child_realtion"}}}`. The allowed values follow
the `fields/exclude_fields` (from `QuerySet`) methods schema so when in doubt you can refer to docs in queries -> selecting subset of fields -> fields.
* `Model.update()` method now accepts `_columns: List[str] = None` parameter, that accepts list of column names to update. If passed only those columns will be updated in database.
Note that `update()` does not refresh the instance of the Model, so if you change more columns than you pass in `_columns` list your Model instance will have different values than the database!
* `Model.dict()` method previously included only directly related models or nested models if they were not nullable and not virtual,
now all related models not previously visited without loops are included in `dict()`. This should be not breaking
as just more data will be dumped to dict, but it should not be missing.
* `QuerySet.delete(each=False, **kwargs)` previously required that you either pass a `filter` (by `**kwargs` or as a separate `filter()` call) or set `each=True` now also accepts
`exclude()` calls that generates NOT filter. So either `each=True` needs to be set to delete whole table or at least one of `filter/exclude` clauses.
* Same thing applies to `QuerySet.update(each=False, **kwargs)` which also previously required that you either pass a `filter` (by `**kwargs` or as a separate `filter()` call) or set `each=True` now also accepts
`exclude()` calls that generates NOT filter. So either `each=True` needs to be set to update whole table or at least one of `filter/exclude` clauses.
* Same thing applies to `QuerysetProxy.update(each=False, **kwargs)` which also previously required that you either pass a `filter` (by `**kwargs` or as a separate `filter()` call) or set `each=True` now also accepts
`exclude()` calls that generates NOT filter. So either `each=True` needs to be set to update whole table or at least one of `filter/exclude` clauses.
## 🐛 Fixes
* Fix improper relation field resolution in `QuerysetProxy` if fk column has different database alias.
* Fix hitting recursion error with very complicated models structure with loops when calling `dict()`.
* Fix bug when two non-relation fields were merged (appended) in query result when they were not relation fields (i.e. JSON)
* Fix bug when during translation to dict from list the same relation name is used in chain but leads to different models
* Fix bug when bulk_create would try to save also `property_field` decorated methods and `pydantic` fields
* Fix wrong merging of deeply nested chain of reversed relations
## 💬 Other
* Performance optimizations
* Split tests into packages based on tested area
# 0.10.1 # 0.10.1
## Features ## Features

View File

@ -59,6 +59,7 @@ nav:
- Model Table Proxy: api/models/model-table-proxy.md - Model Table Proxy: api/models/model-table-proxy.md
- Model Metaclass: api/models/model-metaclass.md - Model Metaclass: api/models/model-metaclass.md
- Excludable Items: api/models/excludable-items.md - Excludable Items: api/models/excludable-items.md
- Traversible: api/models/traversible.md
- Fields: - Fields:
- Base Field: api/fields/base-field.md - Base Field: api/fields/base-field.md
- Model Fields: api/fields/model-fields.md - Model Fields: api/fields/model-fields.md

View File

@ -75,7 +75,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType() Undefined = UndefinedType()
__version__ = "0.10.1" __version__ = "0.10.2"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",

View File

@ -138,10 +138,8 @@ class ExcludableMixin(RelationMixin):
return columns return columns
@classmethod @classmethod
def _update_excluded_with_related_not_required( def _update_excluded_with_related(
cls, cls, exclude: Union["AbstractSetIntStr", "MappingIntStrAny", None],
exclude: Union["AbstractSetIntStr", "MappingIntStrAny", None],
nested: bool = False,
) -> Union[Set, Dict]: ) -> Union[Set, Dict]:
""" """
Used during generation of the dict(). Used during generation of the dict().
@ -159,8 +157,9 @@ class ExcludableMixin(RelationMixin):
:rtype: Union[Set, Dict] :rtype: Union[Set, Dict]
""" """
exclude = exclude or {} exclude = exclude or {}
related_set = cls._exclude_related_names_not_required(nested=nested) related_set = cls.extract_related_names()
if isinstance(exclude, set): if isinstance(exclude, set):
exclude = {s for s in exclude}
exclude.union(related_set) exclude.union(related_set)
else: else:
related_dict = translate_list_to_dict(related_set) related_dict = translate_list_to_dict(related_set)

View File

@ -1,7 +1,8 @@
from collections import OrderedDict from collections import OrderedDict
from typing import List, TYPE_CHECKING from typing import Dict, List, Optional, TYPE_CHECKING, cast
import ormar import ormar
from ormar.queryset.utils import translate_list_to_dict
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -46,13 +47,17 @@ class MergeModelMixin:
return merged_rows return merged_rows
@classmethod @classmethod
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": def merge_two_instances(
cls, one: "Model", other: "Model", relation_map: Dict = None
) -> "Model":
""" """
Merges current (other) Model and previous one (one) and returns the current Merges current (other) Model and previous one (one) and returns the current
Model instance with data merged from previous one. Model instance with data merged from previous one.
If needed it's calling itself recurrently and merges also children models. If needed it's calling itself recurrently and merges also children models.
:param relation_map: map of models relations to follow
:type relation_map: Dict
:param one: previous model instance :param one: previous model instance
:type one: Model :type one: Model
:param other: current model instance :param other: current model instance
@ -60,20 +65,80 @@ class MergeModelMixin:
:return: current Model instance with data merged from previous one. :return: current Model instance with data merged from previous one.
:rtype: Model :rtype: Model
""" """
for field in one.Meta.model_fields.keys(): relation_map = (
current_field = getattr(one, field) relation_map
if isinstance(current_field, list) and not isinstance( if relation_map is not None
current_field, ormar.Model else translate_list_to_dict(one._iterate_related_models())
): )
setattr(other, field, current_field + getattr(other, field)) for field_name in relation_map:
current_field = getattr(one, field_name)
other_value = getattr(other, field_name, [])
if isinstance(current_field, list):
value_to_set = cls._merge_items_lists(
field_name=field_name,
current_field=current_field,
other_value=other_value,
relation_map=relation_map,
)
setattr(other, field_name, value_to_set)
elif ( elif (
isinstance(current_field, ormar.Model) isinstance(current_field, ormar.Model)
and current_field.pk == getattr(other, field).pk and current_field.pk == other_value.pk
): ):
setattr( setattr(
other, other,
field, field_name,
cls.merge_two_instances(current_field, getattr(other, field)), cls.merge_two_instances(
current_field,
other_value,
relation_map=one._skip_ellipsis( # type: ignore
relation_map, field_name, default_return=dict()
),
),
) )
other.set_save_status(True) other.set_save_status(True)
return other return other
@classmethod
def _merge_items_lists(
cls,
field_name: str,
current_field: List,
other_value: List,
relation_map: Optional[Dict],
) -> List:
"""
Takes two list of nested models and process them going deeper
according with the map.
If model from one's list is in other -> they are merged with relations
to follow passed from map.
If one's model is not in other it's simply appended to the list.
:param field_name: name of the current relation field
:type field_name: str
:param current_field: list of nested models from one model
:type current_field: List[Model]
:param other_value: list of nested models from other model
:type other_value: List[Model]
:param relation_map: map of relations to follow
:type relation_map: Dict
:return: merged list of models
:rtype: List[Model]
"""
value_to_set = [x for x in other_value]
for cur_field in current_field:
if cur_field in other_value:
old_value = next((x for x in other_value if x == cur_field), None)
new_val = cls.merge_two_instances(
cur_field,
cast("Model", old_value),
relation_map=cur_field._skip_ellipsis( # type: ignore
relation_map, field_name, default_return=dict()
),
)
value_to_set = [x for x in value_to_set if x != cur_field] + [new_val]
else:
value_to_set.append(cur_field)
return value_to_set

View File

@ -4,11 +4,10 @@ from typing import (
Optional, Optional,
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
Type,
Union,
) )
from ormar import BaseField from ormar import BaseField
from ormar.models.traversible import NodeList
class RelationMixin: class RelationMixin:
@ -17,7 +16,7 @@ class RelationMixin:
""" """
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import ModelMeta, Model from ormar import ModelMeta
Meta: ModelMeta Meta: ModelMeta
_related_names: Optional[Set] _related_names: Optional[Set]
@ -112,84 +111,39 @@ class RelationMixin:
} }
return related_names return related_names
@classmethod
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
"""
Returns a set of non mandatory related models field names.
For a main model (not nested) only nullable related field names are returned,
for nested models all related models are returned.
:param nested: flag setting nested models (child of previous one, not main one)
:type nested: bool
:return: set of non mandatory related fields
:rtype: Set
"""
if nested:
return cls.extract_related_names()
related_names = cls.extract_related_names()
related_names = {
name for name in related_names if cls.Meta.model_fields[name].nullable
}
return related_names
@classmethod @classmethod
def _iterate_related_models( # noqa: CCR001 def _iterate_related_models( # noqa: CCR001
cls, cls, node_list: NodeList = None, source_relation: str = None
visited: Set[str] = None,
source_visited: Set[str] = None,
source_relation: str = None,
source_model: Union[Type["Model"], Type["RelationMixin"]] = None,
) -> List[str]: ) -> List[str]:
""" """
Iterates related models recursively to extract relation strings of Iterates related models recursively to extract relation strings of
nested not visited models. nested not visited models.
:param visited: set of already visited models
:type visited: Set[str]
:param source_relation: name of the current relation
:type source_relation: str
:param source_model: model from which relation comes in nested relations
:type source_model: Type["Model"]
:return: list of relation strings to be passed to select_related :return: list of relation strings to be passed to select_related
:rtype: List[str] :rtype: List[str]
""" """
source_visited = source_visited or cls._populate_source_model_prefixes() if not node_list:
node_list = NodeList()
current_node = node_list.add(node_class=cls)
else:
current_node = node_list[-1]
relations = cls.extract_related_names() relations = cls.extract_related_names()
processed_relations = [] processed_relations = []
for relation in relations: for relation in relations:
if not current_node.visited(relation):
target_model = cls.Meta.model_fields[relation].to target_model = cls.Meta.model_fields[relation].to
if cls._is_reverse_side_of_same_relation(source_model, target_model): node_list.add(
continue node_class=target_model,
if target_model not in source_visited or not source_model: relation_name=relation,
parent_node=current_node,
)
deep_relations = target_model._iterate_related_models( deep_relations = target_model._iterate_related_models(
visited=visited, source_relation=relation, node_list=node_list
source_visited=source_visited,
source_relation=relation,
source_model=cls,
) )
processed_relations.extend(deep_relations) processed_relations.extend(deep_relations)
else:
processed_relations.append(relation)
return cls._get_final_relations(processed_relations, source_relation) return cls._get_final_relations(processed_relations, source_relation)
@staticmethod
def _is_reverse_side_of_same_relation(
source_model: Optional[Union[Type["Model"], Type["RelationMixin"]]],
target_model: Type["Model"],
) -> bool:
"""
Alias to check if source model is the same as target
:param source_model: source model - relation comes from it
:type source_model: Type["Model"]
:param target_model: target model - relation leads to it
:type target_model: Type["Model"]
:return: result of the check
:rtype: bool
"""
return bool(source_model and target_model == source_model)
@staticmethod @staticmethod
def _get_final_relations( def _get_final_relations(
processed_relations: List, source_relation: Optional[str] processed_relations: List, source_relation: Optional[str]
@ -212,12 +166,3 @@ class RelationMixin:
else: else:
final_relations = [source_relation] if source_relation else [] final_relations = [source_relation] if source_relation else []
return final_relations return final_relations
@classmethod
def _populate_source_model_prefixes(cls) -> Set:
relations = cls.extract_related_names()
visited = {cls}
for relation in relations:
target_model = cls.Meta.model_fields[relation].to
visited.add(target_model)
return visited

View File

@ -32,11 +32,29 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
:rtype: Dict[str, str] :rtype: Dict[str, str]
""" """
new_kwargs = cls._remove_pk_from_kwargs(new_kwargs) new_kwargs = cls._remove_pk_from_kwargs(new_kwargs)
new_kwargs = cls._remove_not_ormar_fields(new_kwargs)
new_kwargs = cls.substitute_models_with_pks(new_kwargs) new_kwargs = cls.substitute_models_with_pks(new_kwargs)
new_kwargs = cls.populate_default_values(new_kwargs) new_kwargs = cls.populate_default_values(new_kwargs)
new_kwargs = cls.translate_columns_to_aliases(new_kwargs) new_kwargs = cls.translate_columns_to_aliases(new_kwargs)
return new_kwargs return new_kwargs
@classmethod
def _remove_not_ormar_fields(cls, new_kwargs: dict) -> dict:
"""
Removes primary key for if it's nullable or autoincrement pk field,
and it's set to None.
:param new_kwargs: dictionary of model that is about to be saved
:type new_kwargs: Dict[str, str]
:return: dictionary of model that is about to be saved
:rtype: Dict[str, str]
"""
ormar_fields = {
k for k, v in cls.Meta.model_fields.items() if not v.pydantic_only
}
new_kwargs = {k: v for k, v in new_kwargs.items() if k in ormar_fields}
return new_kwargs
@classmethod @classmethod
def _remove_pk_from_kwargs(cls, new_kwargs: dict) -> dict: def _remove_pk_from_kwargs(cls, new_kwargs: dict) -> dict:
""" """

View File

@ -4,7 +4,6 @@ from typing import (
List, List,
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
Tuple,
TypeVar, TypeVar,
Union, Union,
) )
@ -14,7 +13,7 @@ from ormar.exceptions import ModelPersistenceError, NoMatch
from ormar.models import NewBaseModel # noqa I100 from ormar.models import NewBaseModel # noqa I100
from ormar.models.metaclass import ModelMeta from ormar.models.metaclass import ModelMeta
from ormar.models.model_row import ModelRow from ormar.models.model_row import ModelRow
from ormar.queryset.utils import subtract_dict, translate_list_to_dict
T = TypeVar("T", bound="Model") T = TypeVar("T", bound="Model")
@ -101,8 +100,13 @@ class Model(ModelRow):
return self return self
async def save_related( # noqa: CCR001 async def save_related( # noqa: CCR001
self, follow: bool = False, visited: Set = None, update_count: int = 0 self,
) -> int: # noqa: CCR001 follow: bool = False,
save_all: bool = False,
relation_map: Dict = None,
exclude: Union[Set, Dict] = None,
update_count: int = 0,
) -> int:
""" """
Triggers a upsert method on all related models Triggers a upsert method on all related models
if the instances are not already saved. if the instances are not already saved.
@ -118,77 +122,89 @@ class Model(ModelRow):
Model A but will never follow into Model C. Model A but will never follow into Model C.
Nested relations of those kind need to be persisted manually. Nested relations of those kind need to be persisted manually.
:param exclude: items to exclude during saving of relations
:type exclude: Union[Set, Dict]
:param relation_map: map of relations to follow
:type relation_map: Dict
:param save_all: flag if all models should be saved or only not saved ones
:type save_all: bool
:param follow: flag to trigger deep save - :param follow: flag to trigger deep save -
by default only directly related models are saved by default only directly related models are saved
with follow=True also related models of related models are saved with follow=True also related models of related models are saved
:type follow: bool :type follow: bool
:param visited: internal parameter for recursive calls - already visited models
:type visited: Set
:param update_count: internal parameter for recursive calls - :param update_count: internal parameter for recursive calls -
number of updated instances number of updated instances
:type update_count: int :type update_count: int
:return: number of updated/saved models :return: number of updated/saved models
:rtype: int :rtype: int
""" """
if not visited: relation_map = (
visited = {self.__class__} relation_map
else: if relation_map is not None
visited = {x for x in visited} else translate_list_to_dict(self._iterate_related_models())
visited.add(self.__class__) )
if exclude and isinstance(exclude, Set):
exclude = translate_list_to_dict(exclude)
relation_map = subtract_dict(relation_map, exclude or {})
for related in self.extract_related_names(): for related in self.extract_related_names():
if ( if relation_map and related in relation_map:
self.Meta.model_fields[related].virtual value = getattr(self, related)
or self.Meta.model_fields[related].is_multi if value:
): update_count = await self._update_and_follow(
for rel in getattr(self, related): value=value,
update_count, visited = await self._update_and_follow(
rel=rel,
follow=follow, follow=follow,
visited=visited, save_all=save_all,
relation_map=self._skip_ellipsis( # type: ignore
relation_map, related, default_return={}
),
update_count=update_count, update_count=update_count,
) )
visited.add(self.Meta.model_fields[related].to)
else:
rel = getattr(self, related)
update_count, visited = await self._update_and_follow(
rel=rel, follow=follow, visited=visited, update_count=update_count
)
visited.add(rel.__class__)
return update_count return update_count
@staticmethod @staticmethod
async def _update_and_follow( async def _update_and_follow(
rel: "Model", follow: bool, visited: Set, update_count: int value: Union["Model", List["Model"]],
) -> Tuple[int, Set]: follow: bool,
save_all: bool,
relation_map: Dict,
update_count: int,
) -> int:
""" """
Internal method used in save_related to follow related models and update numbers Internal method used in save_related to follow related models and update numbers
of updated related instances. of updated related instances.
:param rel: Model to follow :param value: Model to follow
:type rel: Model :type value: Model
:param relation_map: map of relations to follow
:type relation_map: Dict
:param follow: flag to trigger deep save - :param follow: flag to trigger deep save -
by default only directly related models are saved by default only directly related models are saved
with follow=True also related models of related models are saved with follow=True also related models of related models are saved
:type follow: bool :type follow: bool
:param visited: internal parameter for recursive calls - already visited models
:type visited: Set
:param update_count: internal parameter for recursive calls - :param update_count: internal parameter for recursive calls -
number of updated instances number of updated instances
:type update_count: int :type update_count: int
:return: tuple of update count and visited :return: tuple of update count and visited
:rtype: Tuple[int, Set] :rtype: int
""" """
if follow and rel.__class__ not in visited: if not isinstance(value, list):
update_count = await rel.save_related( value = [value]
follow=follow, visited=visited, update_count=update_count
)
if not rel.saved:
await rel.upsert()
update_count += 1
return update_count, visited
async def update(self: T, **kwargs: Any) -> T: for val in value:
if (not val.saved or save_all) and not val.__pk_only__:
await val.upsert()
update_count += 1
if follow:
update_count = await val.save_related(
follow=follow,
save_all=save_all,
relation_map=relation_map,
update_count=update_count,
)
return update_count
async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T:
""" """
Performs update of Model instance in the database. Performs update of Model instance in the database.
Fields can be updated before or you can pass them as kwargs. Fields can be updated before or you can pass them as kwargs.
@ -197,6 +213,8 @@ class Model(ModelRow):
Sets model save status to True. Sets model save status to True.
:param _columns: list of columns to update, if None all are updated
:type _columns: List
:raises ModelPersistenceError: If the pk column is not set :raises ModelPersistenceError: If the pk column is not set
:param kwargs: list of fields to update as field=value pairs :param kwargs: list of fields to update as field=value pairs
@ -217,6 +235,8 @@ class Model(ModelRow):
) )
self_fields = self._extract_model_db_fields() self_fields = self._extract_model_db_fields()
self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname)) self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname))
if _columns:
self_fields = {k: v for k, v in self_fields.items() if k in _columns}
self_fields = self.translate_columns_to_aliases(self_fields) self_fields = self.translate_columns_to_aliases(self_fields)
expr = self.Meta.table.update().values(**self_fields) expr = self.Meta.table.update().values(**self_fields)
expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname)) expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname))

View File

@ -64,7 +64,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
the logic concerned with database connection and data persistance. the logic concerned with database connection and data persistance.
""" """
__slots__ = ("_orm_id", "_orm_saved", "_orm", "_pk_column") __slots__ = ("_orm_id", "_orm_saved", "_orm", "_pk_column", "__pk_only__")
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
pk: Any pk: Any
@ -134,6 +134,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
) )
pk_only = kwargs.pop("__pk_only__", False) pk_only = kwargs.pop("__pk_only__", False)
object.__setattr__(self, "__pk_only__", pk_only)
excluded: Set[str] = kwargs.pop("__excluded__", set()) excluded: Set[str] = kwargs.pop("__excluded__", set())
if "pk" in kwargs: if "pk" in kwargs:
@ -267,9 +269,13 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if item == "pk": if item == "pk":
return object.__getattribute__(self, "__dict__").get(self.Meta.pkname, None) return object.__getattribute__(self, "__dict__").get(self.Meta.pkname, None)
if item in object.__getattribute__(self, "extract_related_names")(): if item in object.__getattribute__(self, "extract_related_names")():
return self._extract_related_model_instead_of_field(item) return object.__getattribute__(
self, "_extract_related_model_instead_of_field"
)(item)
if item in object.__getattribute__(self, "extract_through_names")(): if item in object.__getattribute__(self, "extract_through_names")():
return self._extract_related_model_instead_of_field(item) return object.__getattribute__(
self, "_extract_related_model_instead_of_field"
)(item)
if item in object.__getattribute__(self, "Meta").property_fields: if item in object.__getattribute__(self, "Meta").property_fields:
value = object.__getattribute__(self, item) value = object.__getattribute__(self, item)
return value() if callable(value) else value return value() if callable(value) else value
@ -337,8 +343,19 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
return ( return (
self._orm_id == other._orm_id self._orm_id == other._orm_id
or (self.pk == other.pk and self.pk is not None) or (self.pk == other.pk and self.pk is not None)
or self.dict(exclude=self.extract_related_names()) or (
== other.dict(exclude=other.extract_related_names()) (self.pk is None and other.pk is None)
and {
k: v
for k, v in self.__dict__.items()
if k not in self.extract_related_names()
}
== {
k: v
for k, v in other.__dict__.items()
if k not in other.extract_related_names()
}
)
) )
@classmethod @classmethod
@ -489,6 +506,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
@staticmethod @staticmethod
def _extract_nested_models_from_list( def _extract_nested_models_from_list(
relation_map: Dict,
models: MutableSequence, models: MutableSequence,
include: Union[Set, Dict, None], include: Union[Set, Dict, None],
exclude: Union[Set, Dict, None], exclude: Union[Set, Dict, None],
@ -509,14 +527,16 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
for model in models: for model in models:
try: try:
result.append( result.append(
model.dict(nested=True, include=include, exclude=exclude,) model.dict(
relation_map=relation_map, include=include, exclude=exclude,
)
) )
except ReferenceError: # pragma no cover except ReferenceError: # pragma no cover
continue continue
return result return result
def _skip_ellipsis( def _skip_ellipsis(
self, items: Union[Set, Dict, None], key: str self, items: Union[Set, Dict, None], key: str, default_return: Any = None
) -> Union[Set, Dict, None]: ) -> Union[Set, Dict, None]:
""" """
Helper to traverse the include/exclude dictionaries. Helper to traverse the include/exclude dictionaries.
@ -531,11 +551,11 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
:rtype: Union[Set, Dict, None] :rtype: Union[Set, Dict, None]
""" """
result = self.get_child(items, key) result = self.get_child(items, key)
return result if result is not Ellipsis else None return result if result is not Ellipsis else default_return
def _extract_nested_models( # noqa: CCR001 def _extract_nested_models( # noqa: CCR001
self, self,
nested: bool, relation_map: Dict,
dict_instance: Dict, dict_instance: Dict,
include: Optional[Dict], include: Optional[Dict],
exclude: Optional[Dict], exclude: Optional[Dict],
@ -559,18 +579,23 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
fields = self._get_related_not_excluded_fields(include=include, exclude=exclude) fields = self._get_related_not_excluded_fields(include=include, exclude=exclude)
for field in fields: for field in fields:
if self.Meta.model_fields[field].virtual and nested: if not relation_map or field not in relation_map:
continue continue
nested_model = getattr(self, field) nested_model = getattr(self, field)
if isinstance(nested_model, MutableSequence): if isinstance(nested_model, MutableSequence):
dict_instance[field] = self._extract_nested_models_from_list( dict_instance[field] = self._extract_nested_models_from_list(
relation_map=self._skip_ellipsis( # type: ignore
relation_map, field, default_return=dict()
),
models=nested_model, models=nested_model,
include=self._skip_ellipsis(include, field), include=self._skip_ellipsis(include, field),
exclude=self._skip_ellipsis(exclude, field), exclude=self._skip_ellipsis(exclude, field),
) )
elif nested_model is not None: elif nested_model is not None:
dict_instance[field] = nested_model.dict( dict_instance[field] = nested_model.dict(
nested=True, relation_map=self._skip_ellipsis(
relation_map, field, default_return=dict()
),
include=self._skip_ellipsis(include, field), include=self._skip_ellipsis(include, field),
exclude=self._skip_ellipsis(exclude, field), exclude=self._skip_ellipsis(exclude, field),
) )
@ -588,7 +613,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
exclude_unset: bool = False, exclude_unset: bool = False,
exclude_defaults: bool = False, exclude_defaults: bool = False,
exclude_none: bool = False, exclude_none: bool = False,
nested: bool = False, relation_map: Dict = None,
) -> "DictStrAny": # noqa: A003' ) -> "DictStrAny": # noqa: A003'
""" """
@ -613,14 +638,14 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
:type exclude_defaults: bool :type exclude_defaults: bool
:param exclude_none: flag to exclude None values - passed to pydantic :param exclude_none: flag to exclude None values - passed to pydantic
:type exclude_none: bool :type exclude_none: bool
:param nested: flag if the current model is nested :param relation_map: map of the relations to follow to avoid circural deps
:type nested: bool :type relation_map: Dict
:return: :return:
:rtype: :rtype:
""" """
dict_instance = super().dict( dict_instance = super().dict(
include=include, include=include,
exclude=self._update_excluded_with_related_not_required(exclude, nested), exclude=self._update_excluded_with_related(exclude),
by_alias=by_alias, by_alias=by_alias,
skip_defaults=skip_defaults, skip_defaults=skip_defaults,
exclude_unset=exclude_unset, exclude_unset=exclude_unset,
@ -633,8 +658,15 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if exclude and isinstance(exclude, Set): if exclude and isinstance(exclude, Set):
exclude = translate_list_to_dict(exclude) exclude = translate_list_to_dict(exclude)
relation_map = (
relation_map
if relation_map is not None
else translate_list_to_dict(self._iterate_related_models())
)
pk_only = object.__getattribute__(self, "__pk_only__")
if relation_map and not pk_only:
dict_instance = self._extract_nested_models( dict_instance = self._extract_nested_models(
nested=nested, relation_map=relation_map,
dict_instance=dict_instance, dict_instance=dict_instance,
include=include, # type: ignore include=include, # type: ignore
exclude=exclude, # type: ignore exclude=exclude, # type: ignore
@ -714,7 +746,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
:rtype: Dict :rtype: Dict
""" """
related_names = self.extract_related_names() related_names = self.extract_related_names()
self_fields = self.dict(exclude=related_names) self_fields = {k: v for k, v in self.__dict__.items() if k not in related_names}
return self_fields return self_fields
def _extract_model_db_fields(self) -> Dict: def _extract_model_db_fields(self) -> Dict:

View File

@ -14,6 +14,7 @@ quick_access_set = {
"__json_encoder__", "__json_encoder__",
"__post_root_validators__", "__post_root_validators__",
"__pre_root_validators__", "__pre_root_validators__",
"__private_attributes__",
"__same__", "__same__",
"_calculate_keys", "_calculate_keys",
"_choices_fields", "_choices_fields",
@ -26,8 +27,10 @@ quick_access_set = {
"_extract_related_model_instead_of_field", "_extract_related_model_instead_of_field",
"_get_related_not_excluded_fields", "_get_related_not_excluded_fields",
"_get_value", "_get_value",
"_init_private_attributes",
"_is_conversion_to_json_needed", "_is_conversion_to_json_needed",
"_iter", "_iter",
"_iterate_related_models",
"_orm", "_orm",
"_orm_id", "_orm_id",
"_orm_saved", "_orm_saved",
@ -40,8 +43,10 @@ quick_access_set = {
"delete", "delete",
"dict", "dict",
"extract_related_names", "extract_related_names",
"extract_related_fields",
"extract_through_names", "extract_through_names",
"update_from_dict", "update_from_dict",
"get_child",
"get_column_alias", "get_column_alias",
"get_column_name_from_alias", "get_column_name_from_alias",
"get_filtered_names_to_extract", "get_filtered_names_to_extract",
@ -52,9 +57,11 @@ quick_access_set = {
"json", "json",
"keys", "keys",
"load", "load",
"load_all",
"pk_column", "pk_column",
"pk_type", "pk_type",
"populate_default_values", "populate_default_values",
"prepare_model_to_save",
"remove", "remove",
"resolve_relation_field", "resolve_relation_field",
"resolve_relation_name", "resolve_relation_name",
@ -62,6 +69,7 @@ quick_access_set = {
"save_related", "save_related",
"saved", "saved",
"set_save_status", "set_save_status",
"signals",
"translate_aliases_to_columns", "translate_aliases_to_columns",
"translate_columns_to_aliases", "translate_columns_to_aliases",
"update", "update",

118
ormar/models/traversible.py Normal file
View File

@ -0,0 +1,118 @@
from typing import Any, List, Optional, TYPE_CHECKING, Type
if TYPE_CHECKING: # pragma no cover
from ormar.models.mixins.relation_mixin import RelationMixin
class NodeList:
"""
Helper class that helps with iterating nested models
"""
def __init__(self) -> None:
self.node_list: List["Node"] = []
def __getitem__(self, item: Any) -> Any:
return self.node_list.__getitem__(item)
def add(
self,
node_class: Type["RelationMixin"],
relation_name: str = None,
parent_node: "Node" = None,
) -> "Node":
"""
Adds new Node or returns the existing one
:param node_class: Model in current node
:type node_class: ormar.models.metaclass.ModelMetaclass
:param relation_name: name of the current relation
:type relation_name: str
:param parent_node: parent node
:type parent_node: Optional[Node]
:return: returns new or already existing node
:rtype: Node
"""
existing_node = self.find(
relation_name=relation_name, node_class=node_class, parent_node=parent_node
)
if not existing_node:
current_node = Node(
node_class=node_class,
relation_name=relation_name,
parent_node=parent_node,
)
self.node_list.append(current_node)
return current_node
return existing_node # pragma: no cover
def find(
self,
node_class: Type["RelationMixin"],
relation_name: Optional[str] = None,
parent_node: "Node" = None,
) -> Optional["Node"]:
"""
Searches for existing node with given parameters
:param node_class: Model in current node
:type node_class: ormar.models.metaclass.ModelMetaclass
:param relation_name: name of the current relation
:type relation_name: str
:param parent_node: parent node
:type parent_node: Optional[Node]
:return: returns already existing node or None
:rtype: Optional[Node]
"""
for node in self.node_list:
if (
node.node_class == node_class
and node.parent_node == parent_node
and node.relation_name == relation_name
):
return node # pragma: no cover
return None
class Node:
def __init__(
self,
node_class: Type["RelationMixin"],
relation_name: str = None,
parent_node: "Node" = None,
) -> None:
self.relation_name = relation_name
self.node_class = node_class
self.parent_node = parent_node
self.visited_children: List["Node"] = []
if self.parent_node:
self.parent_node.visited_children.append(self)
def __repr__(self) -> str: # pragma: no cover
return (
f"{self.node_class.get_name(lower=False)}, "
f"relation:{self.relation_name}, "
f"parent: {self.parent_node}"
)
def visited(self, relation_name: str) -> bool:
"""
Checks if given relation was already visited.
Relation was visited if it's name is in current node children.
Relation was visited if one of the parent node had the same Model class
:param relation_name: name of relation
:type relation_name: str
:return: result of the check
:rtype: bool
"""
target_model = self.node_class.Meta.model_fields[relation_name].to
if self.parent_node:
node = self
while node.parent_node:
node = node.parent_node
if node.node_class == target_model:
return True
return False

View File

@ -650,7 +650,7 @@ class QuerySet(Generic[T]):
:return: number of updated rows :return: number of updated rows
:rtype: int :rtype: int
""" """
if not each and not self.filter_clauses: if not each and not (self.filter_clauses or self.exclude_clauses):
raise QueryDefinitionError( raise QueryDefinitionError(
"You cannot update without filtering the queryset first. " "You cannot update without filtering the queryset first. "
"If you want to update all rows use update(each=True, **kwargs)" "If you want to update all rows use update(each=True, **kwargs)"
@ -666,6 +666,9 @@ class QuerySet(Generic[T]):
expr = FilterQuery(filter_clauses=self.filter_clauses).apply( expr = FilterQuery(filter_clauses=self.filter_clauses).apply(
self.table.update().values(**updates) self.table.update().values(**updates)
) )
expr = FilterQuery(filter_clauses=self.exclude_clauses, exclude=True).apply(
expr
)
return await self.database.execute(expr) return await self.database.execute(expr)
async def delete(self, each: bool = False, **kwargs: Any) -> int: async def delete(self, each: bool = False, **kwargs: Any) -> int:
@ -684,7 +687,7 @@ class QuerySet(Generic[T]):
""" """
if kwargs: if kwargs:
return await self.filter(**kwargs).delete() return await self.filter(**kwargs).delete()
if not each and not self.filter_clauses: if not each and not (self.filter_clauses or self.exclude_clauses):
raise QueryDefinitionError( raise QueryDefinitionError(
"You cannot delete without filtering the queryset first. " "You cannot delete without filtering the queryset first. "
"If you want to delete all rows use delete(each=True)" "If you want to delete all rows use delete(each=True)"
@ -692,6 +695,9 @@ class QuerySet(Generic[T]):
expr = FilterQuery(filter_clauses=self.filter_clauses).apply( expr = FilterQuery(filter_clauses=self.filter_clauses).apply(
self.table.delete() self.table.delete()
) )
expr = FilterQuery(filter_clauses=self.exclude_clauses, exclude=True).apply(
expr
)
return await self.database.execute(expr) return await self.database.execute(expr)
def paginate(self, page: int, page_size: int = 20) -> "QuerySet[T]": def paginate(self, page: int, page_size: int = 20) -> "QuerySet[T]":

View File

@ -18,7 +18,7 @@ if TYPE_CHECKING: # pragma no cover
def check_node_not_dict_or_not_last_node( def check_node_not_dict_or_not_last_node(
part: str, parts: List, current_level: Any part: str, is_last: bool, current_level: Any
) -> bool: ) -> bool:
""" """
Checks if given name is not present in the current level of the structure. Checks if given name is not present in the current level of the structure.
@ -36,7 +36,7 @@ def check_node_not_dict_or_not_last_node(
:return: result of the check :return: result of the check
:rtype: bool :rtype: bool
""" """
return (part not in current_level and part != parts[-1]) or ( return (part not in current_level and not is_last) or (
part in current_level and not isinstance(current_level[part], dict) part in current_level and not isinstance(current_level[part], dict)
) )
@ -71,9 +71,10 @@ def translate_list_to_dict( # noqa: CCR001
else: else:
def_val = "asc" def_val = "asc"
for part in parts: for ind, part in enumerate(parts):
is_last = ind == len(parts) - 1
if check_node_not_dict_or_not_last_node( if check_node_not_dict_or_not_last_node(
part=part, parts=parts, current_level=current_level part=part, is_last=is_last, current_level=current_level
): ):
current_level[part] = dict() current_level[part] = dict()
elif part not in current_level: elif part not in current_level:
@ -127,6 +128,49 @@ def update(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001
return current_dict return current_dict
def subtract_dict(current_dict: Any, updating_dict: Any) -> Dict: # noqa: CCR001
"""
Update one dict with another but with regard for nested keys.
That way nested sets are unionised, dicts updated and
only other values are overwritten.
:param current_dict: dict to update
:type current_dict: Dict[str, ellipsis]
:param updating_dict: dict with values to update
:type updating_dict: Dict
:return: combination of both dicts
:rtype: Dict
"""
for key, value in updating_dict.items():
old_key = current_dict.get(key, {})
new_value: Optional[Union[Dict, Set]] = None
if not old_key:
continue
if isinstance(value, set) and isinstance(old_key, set):
new_value = old_key.difference(value)
elif isinstance(value, (set, collections.abc.Mapping)) and isinstance(
old_key, (set, collections.abc.Mapping)
):
value = (
convert_set_to_required_dict(value)
if not isinstance(value, collections.abc.Mapping)
else value
)
old_key = (
convert_set_to_required_dict(old_key)
if not isinstance(old_key, collections.abc.Mapping)
else old_key
)
new_value = subtract_dict(old_key, value)
if new_value:
current_dict[key] = new_value
else:
current_dict.pop(key, None)
return current_dict
def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> Dict: def update_dict_from_list(curr_dict: Dict, list_to_update: Union[List, Set]) -> Dict:
""" """
Converts the list into dictionary and later performs special update, where Converts the list into dictionary and later performs special update, where

View File

@ -389,7 +389,11 @@ class QuerysetProxy(Generic[T]):
:rtype: int :rtype: int
""" """
# queryset proxy always have one filter for pk of parent model # queryset proxy always have one filter for pk of parent model
if not each and len(self.queryset.filter_clauses) == 1: if (
not each
and (len(self.queryset.filter_clauses) + len(self.queryset.exclude_clauses))
== 1
):
raise QueryDefinitionError( raise QueryDefinitionError(
"You cannot update without filtering the queryset first. " "You cannot update without filtering the queryset first. "
"If you want to update all rows use update(each=True, **kwargs)" "If you want to update all rows use update(each=True, **kwargs)"

View File

@ -127,7 +127,7 @@ class RelationProxy(Generic[T], list):
related_field = self.relation.to.Meta.model_fields[related_field_name] related_field = self.relation.to.Meta.model_fields[related_field_name]
pkname = self._owner.get_column_alias(self._owner.Meta.pkname) pkname = self._owner.get_column_alias(self._owner.Meta.pkname)
self._check_if_model_saved() self._check_if_model_saved()
kwargs = {f"{related_field.get_alias()}__{pkname}": self._owner.pk} kwargs = {f"{related_field.name}__{pkname}": self._owner.pk}
queryset = ( queryset = (
ormar.QuerySet( ormar.QuerySet(
model_cls=self.relation.to, proxy_source_model=self._owner.__class__ model_cls=self.relation.to, proxy_source_model=self._owner.__class__

View File

@ -30,6 +30,9 @@ renderer:
- title: Excludable Items - title: Excludable Items
contents: contents:
- models.excludable.* - models.excludable.*
- title: Traversible
contents:
- models.traversible.*
- title: Model Table Proxy - title: Model Table Proxy
contents: contents:
- models.modelproxy.* - models.modelproxy.*

View File

View File

View File

View File

@ -48,7 +48,11 @@ def test_read_main():
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == { assert response.json() == {
"category": {"id": None, "name": "test cat"}, "category": {
"id": None,
"items": [{"id": 1, "name": "test"}],
"name": "test cat",
},
"id": 1, "id": 1,
"name": "test", "name": "test",
} }

View File

@ -6,7 +6,7 @@ from fastapi import FastAPI
from starlette.testclient import TestClient from starlette.testclient import TestClient
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
from tests.test_inheritance_concrete import ( # type: ignore from tests.test_inheritance.test_inheritance_concrete import ( # type: ignore
Category, Category,
Subject, Subject,
Person, Person,

View File

@ -6,7 +6,7 @@ from fastapi import FastAPI
from starlette.testclient import TestClient from starlette.testclient import TestClient
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
from tests.test_inheritance_mixins import Category, Subject, metadata, db as database # type: ignore from tests.test_inheritance.test_inheritance_mixins import Category, Subject, metadata, db as database # type: ignore
app = FastAPI() app = FastAPI()
app.state.database = database app.state.database = database

View File

View File

View File

View File

View File

@ -91,6 +91,14 @@ async def test_saving_related_fk_rel():
assert count == 1 assert count == 1
assert comp.hq.saved assert comp.hq.saved
comp.hq.name = "Suburbs 2"
assert not comp.hq.saved
assert comp.saved
count = await comp.save_related(exclude={"hq"})
assert count == 0
assert not comp.hq.saved
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_saving_many_to_many(): async def test_saving_many_to_many():
@ -110,6 +118,9 @@ async def test_saving_many_to_many():
count = await hq.save_related() count = await hq.save_related()
assert count == 0 assert count == 0
count = await hq.save_related(save_all=True)
assert count == 2
hq.nicks[0].name = "Kabucha" hq.nicks[0].name = "Kabucha"
hq.nicks[1].name = "Kabucha2" hq.nicks[1].name = "Kabucha2"
assert not hq.nicks[0].saved assert not hq.nicks[0].saved
@ -120,6 +131,16 @@ async def test_saving_many_to_many():
assert hq.nicks[0].saved assert hq.nicks[0].saved
assert hq.nicks[1].saved assert hq.nicks[1].saved
hq.nicks[0].name = "Kabucha a"
hq.nicks[1].name = "Kabucha2 a"
assert not hq.nicks[0].saved
assert not hq.nicks[1].saved
count = await hq.save_related(exclude={"nicks": ...})
assert count == 0
assert not hq.nicks[0].saved
assert not hq.nicks[1].saved
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_saving_reversed_relation(): async def test_saving_reversed_relation():
@ -208,3 +229,16 @@ async def test_saving_nested():
assert hq.nicks[0].level.saved assert hq.nicks[0].level.saved
assert hq.nicks[1].saved assert hq.nicks[1].saved
assert hq.nicks[1].level.saved assert hq.nicks[1].level.saved
hq.nicks[0].level.name = "Low 2"
hq.nicks[1].level.name = "Medium 2"
assert not hq.nicks[0].level.saved
assert not hq.nicks[1].level.saved
assert hq.nicks[0].saved
assert hq.nicks[1].saved
count = await hq.save_related(follow=True, exclude={"nicks": {"level"}})
assert count == 0
assert hq.nicks[0].saved
assert not hq.nicks[0].level.saved
assert hq.nicks[1].saved
assert not hq.nicks[1].level.saved

View File

@ -0,0 +1,111 @@
from typing import Optional
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Director(ormar.Model):
class Meta:
tablename = "directors"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="first_name")
last_name: str = ormar.String(max_length=100, nullable=False, name="last_name")
class Movie(ormar.Model):
class Meta:
tablename = "movies"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="title")
year: int = ormar.Integer()
profit: float = ormar.Float()
director: Optional[Director] = ormar.ForeignKey(Director)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_updating_selected_columns():
async with database:
director1 = await Director(name="Peter", last_name="Jackson").save()
director2 = await Director(name="James", last_name="Cameron").save()
lotr = await Movie(
name="LOTR", year=2001, director=director1, profit=1.140
).save()
lotr.name = "Lord of The Rings"
lotr.year = 2003
lotr.profit = 1.212
await lotr.update(_columns=["name"])
# before reload the field has current value even if not saved
assert lotr.year == 2003
lotr = await Movie.objects.get()
assert lotr.name == "Lord of The Rings"
assert lotr.year == 2001
assert round(lotr.profit, 3) == 1.140
assert lotr.director.pk == director1.pk
lotr.year = 2003
lotr.profit = 1.212
lotr.director = director2
await lotr.update(_columns=["year", "profit"])
lotr = await Movie.objects.get()
assert lotr.year == 2003
assert round(lotr.profit, 3) == 1.212
assert lotr.director.pk == director1.pk
@pytest.mark.asyncio
async def test_not_passing_columns_or_empty_list_saves_all():
async with database:
director = await Director(name="James", last_name="Cameron").save()
terminator = await Movie(
name="Terminator", year=1984, director=director, profit=0.078
).save()
terminator.name = "Terminator 2"
terminator.year = 1991
terminator.profit = 0.520
await terminator.update(_columns=[])
terminator = await Movie.objects.get()
assert terminator.name == "Terminator 2"
assert terminator.year == 1991
assert round(terminator.profit, 3) == 0.520
terminator.name = "Terminator 3"
terminator.year = 2003
terminator.profit = 0.433
await terminator.update()
terminator = await terminator.load()
assert terminator.name == "Terminator 3"
assert terminator.year == 2003
assert round(terminator.profit, 3) == 0.433

View File

View File

View File

@ -0,0 +1,158 @@
import databases
import pytest
from sqlalchemy import func
import ormar
import sqlalchemy
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Chart(ormar.Model):
class Meta(ormar.ModelMeta):
tablename = "charts"
database = database
metadata = metadata
chart_id = ormar.Integer(primary_key=True, autoincrement=True)
name = ormar.String(max_length=200, unique=True, index=True)
query_text = ormar.Text()
datasets = ormar.JSON()
layout = ormar.JSON()
data_config = ormar.JSON()
created_date = ormar.DateTime(server_default=func.now())
library = ormar.String(max_length=200, default="plotly")
used_filters = ormar.JSON()
class Report(ormar.Model):
class Meta(ormar.ModelMeta):
tablename = "reports"
database = database
metadata = metadata
report_id = ormar.Integer(primary_key=True, autoincrement=True)
name = ormar.String(max_length=200, unique=True, index=True)
filters_position = ormar.String(max_length=200)
created_date = ormar.DateTime(server_default=func.now())
class Language(ormar.Model):
class Meta(ormar.ModelMeta):
tablename = "languages"
database = database
metadata = metadata
language_id = ormar.Integer(primary_key=True, autoincrement=True)
code = ormar.String(max_length=5)
name = ormar.String(max_length=200)
class TranslationNode(ormar.Model):
class Meta(ormar.ModelMeta):
tablename = "translation_nodes"
database = database
metadata = metadata
node_id = ormar.Integer(primary_key=True, autoincrement=True)
node_type = ormar.String(max_length=200)
class Translation(ormar.Model):
class Meta(ormar.ModelMeta):
tablename = "translations"
database = database
metadata = metadata
translation_id = ormar.Integer(primary_key=True, autoincrement=True)
node_id = ormar.ForeignKey(TranslationNode, related_name="translations")
language = ormar.ForeignKey(Language, name="language_id")
value = ormar.String(max_length=500)
class Filter(ormar.Model):
class Meta(ormar.ModelMeta):
tablename = "filters"
database = database
metadata = metadata
filter_id = ormar.Integer(primary_key=True, autoincrement=True)
name = ormar.String(max_length=200, unique=True, index=True)
label = ormar.String(max_length=200)
query_text = ormar.Text()
allow_multiselect = ormar.Boolean(default=True)
created_date = ormar.DateTime(server_default=func.now())
is_dynamic = ormar.Boolean(default=True)
is_date = ormar.Boolean(default=False)
translation = ormar.ForeignKey(TranslationNode, name="translation_node_id")
class FilterValue(ormar.Model):
class Meta(ormar.ModelMeta):
tablename = "filter_values"
database = database
metadata = metadata
value_id = ormar.Integer(primary_key=True, autoincrement=True)
value = ormar.String(max_length=300)
label = ormar.String(max_length=300)
filter = ormar.ForeignKey(Filter, name="filter_id", related_name="values")
translation = ormar.ForeignKey(TranslationNode, name="translation_node_id")
class FilterXReport(ormar.Model):
class Meta(ormar.ModelMeta):
tablename = "filters_x_reports"
database = database
metadata = metadata
filter_x_report_id = ormar.Integer(primary_key=True)
filter = ormar.ForeignKey(Filter, name="filter_id", related_name="reports")
report = ormar.ForeignKey(Report, name="report_id", related_name="filters")
sort_order = ormar.Integer()
default_value = ormar.Text()
is_visible = ormar.Boolean()
class ChartXReport(ormar.Model):
class Meta(ormar.ModelMeta):
tablename = "charts_x_reports"
database = database
metadata = metadata
chart_x_report_id = ormar.Integer(primary_key=True)
chart = ormar.ForeignKey(Chart, name="chart_id", related_name="reports")
report = ormar.ForeignKey(Report, name="report_id", related_name="charts")
sort_order = ormar.Integer()
width = ormar.Integer()
class ChartColumn(ormar.Model):
class Meta(ormar.ModelMeta):
tablename = "charts_columns"
database = database
metadata = metadata
column_id = ormar.Integer(primary_key=True, autoincrement=True)
chart = ormar.ForeignKey(Chart, name="chart_id", related_name="columns")
column_name = ormar.String(max_length=200)
column_type = ormar.String(max_length=200)
translation = ormar.ForeignKey(TranslationNode, name="translation_node_id")
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_saving_related_fk_rel():
async with database:
async with database.transaction(force_rollback=True):
await Report.objects.select_all(follow=True).all()

View File

@ -0,0 +1,101 @@
from typing import Optional
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
class BaseMeta(ormar.ModelMeta):
metadata = metadata
database = database
class DataSource(ormar.Model):
class Meta(BaseMeta):
tablename = "datasources"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=200, unique=True, index=True)
class DataSourceTable(ormar.Model):
class Meta(BaseMeta):
tablename = "source_tables"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=200, index=True)
source: Optional[DataSource] = ormar.ForeignKey(
DataSource, name="source_id", related_name="tables", ondelete="CASCADE",
)
class DataSourceTableColumn(ormar.Model):
class Meta(BaseMeta):
tablename = "source_columns"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=200, index=True)
data_type: str = ormar.String(max_length=200)
table: Optional[DataSourceTable] = ormar.ForeignKey(
DataSourceTable, name="table_id", related_name="columns", ondelete="CASCADE",
)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_double_nested_reverse_relation():
async with database:
data_source = await DataSource(name="local").save()
test_tables = [
{
"name": "test1",
"columns": [
{"name": "col1", "data_type": "test"},
{"name": "col2", "data_type": "test2"},
{"name": "col3", "data_type": "test3"},
],
},
{
"name": "test2",
"columns": [
{"name": "col4", "data_type": "test"},
{"name": "col5", "data_type": "test2"},
{"name": "col6", "data_type": "test3"},
],
},
]
data_source.tables = test_tables
await data_source.save_related(save_all=True, follow=True)
tables = await DataSourceTable.objects.all()
assert len(tables) == 2
columns = await DataSourceTableColumn.objects.all()
assert len(columns) == 6
data_source = (
await DataSource.objects.select_related("tables__columns")
.filter(tables__name__in=["test1", "test2"], name="local")
.get()
)
assert len(data_source.tables) == 2
assert len(data_source.tables[0].columns) == 3
assert data_source.tables[0].columns[0].name == "col1"
assert data_source.tables[0].columns[2].name == "col3"
assert len(data_source.tables[1].columns) == 3
assert data_source.tables[1].columns[0].name == "col4"
assert data_source.tables[1].columns[2].name == "col6"

View File

@ -0,0 +1,53 @@
from typing import Dict, List, Optional
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
class BaseMeta(ormar.ModelMeta):
metadata = metadata
database = database
class Chart(ormar.Model):
class Meta(BaseMeta):
tablename = "authors"
id: int = ormar.Integer(primary_key=True)
datasets = ormar.JSON()
class Config(ormar.Model):
class Meta(BaseMeta):
tablename = "books"
id: int = ormar.Integer(primary_key=True)
chart: Optional[Chart] = ormar.ForeignKey(Chart)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_list_field_that_is_not_relation_is_not_merged():
async with database:
chart = await Chart.objects.create(datasets=[{"test": "ok"}])
await Config.objects.create(chart=chart)
await Config.objects.create(chart=chart)
chart2 = await Chart.objects.select_related("configs").get()
assert len(chart2.datasets) == 1
assert chart2.datasets == [{"test": "ok"}]

View File

@ -18,7 +18,7 @@ class Album(ormar.Model):
metadata = metadata metadata = metadata
database = database database = database
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True, name="album_id")
name: str = ormar.String(max_length=100) name: str = ormar.String(max_length=100)
is_best_seller: bool = ormar.Boolean(default=False) is_best_seller: bool = ormar.Boolean(default=False)
@ -29,7 +29,7 @@ class Writer(ormar.Model):
metadata = metadata metadata = metadata
database = database database = database
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True, name="writer_id")
name: str = ormar.String(max_length=100) name: str = ormar.String(max_length=100)
@ -40,11 +40,11 @@ class Track(ormar.Model):
database = database database = database
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
album: Optional[Album] = ormar.ForeignKey(Album) album: Optional[Album] = ormar.ForeignKey(Album, name="album_id")
title: str = ormar.String(max_length=100) title: str = ormar.String(max_length=100)
position: int = ormar.Integer() position: int = ormar.Integer()
play_count: int = ormar.Integer(nullable=True) play_count: int = ormar.Integer(nullable=True)
written_by: Optional[Writer] = ormar.ForeignKey(Writer) written_by: Optional[Writer] = ormar.ForeignKey(Writer, name="writer_id")
async def get_sample_data(): async def get_sample_data():

View File

View File

Some files were not shown because too many files have changed in this diff Show More