From b3b1c156b528db997214dfc80881a85e009f9ce2 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 11 Apr 2021 18:43:23 +0200 Subject: [PATCH 1/6] add skip_reverse parameter, add links to related libs, fix weakref error, fix through error with extra=forbid --- README.md | 22 +- docs/index.md | 22 +- docs/releases.md | 26 ++ ormar/__init__.py | 2 +- ormar/fields/base.py | 2 + ormar/fields/foreign_key.py | 30 +++ ormar/fields/many_to_many.py | 25 +- ormar/models/helpers/relations.py | 4 + ormar/models/metaclass.py | 2 + ormar/models/mixins/relation_mixin.py | 20 +- ormar/models/model.py | 6 +- ormar/models/newbasemodel.py | 26 +- ormar/relations/querysetproxy.py | 6 +- ormar/relations/relation_proxy.py | 18 +- .../test_inheritance_concrete_fastapi.py | 21 ++ tests/test_fastapi/test_nested_saving.py | 106 +++++++++ .../test_fastapi/test_skip_reverse_models.py | 148 ++++++++++++ tests/test_fastapi/test_wekref_exclusion.py | 14 ++ tests/test_relations/test_skipping_reverse.py | 223 ++++++++++++++++++ 19 files changed, 675 insertions(+), 48 deletions(-) create mode 100644 tests/test_fastapi/test_nested_saving.py create mode 100644 tests/test_fastapi/test_skip_reverse_models.py create mode 100644 tests/test_relations/test_skipping_reverse.py diff --git a/README.md b/README.md index 4d2e6ea..13a7823 100644 --- a/README.md +++ b/README.md @@ -47,15 +47,35 @@ since they actually have to create and connect to database in most of the tests. Yet remember that those are - well - tests and not all solutions are suitable to be used in real life applications. +### Part of the `fastapi` ecosystem + +As part of the fastapi ecosystem `ormar` is supported in libraries that somehow work with databases. + +As of now `ormar` is supported by: + +* [`fastapi-users`](https://github.com/frankie567/fastapi-users) +* [`fastapi-crudrouter`](https://github.com/awtkns/fastapi-crudrouter) +* [`fastapi-pagination`](https://github.com/uriyyo/fastapi-pagination) + +If you maintain or use different library and would like it to support `ormar` let us know how we can help. + ### Dependencies Ormar is built with: - * [`SQLAlchemy core`][sqlalchemy-core] for query building. + * [`sqlalchemy core`][sqlalchemy-core] for query building. * [`databases`][databases] for cross-database async support. * [`pydantic`][pydantic] for data validation. * `typing_extensions` for python 3.6 - 3.7 +### Migrating from `sqlalchemy` + +If you currently use `sqlalchemy` and would like to switch to `ormar` check out the auto-translation +tool that can help you with translating existing sqlalchemy orm models so you do not have to do it manually. + +**Beta** versions available at github: [`sqlalchemy-to-ormar`](https://github.com/collerek/sqlalchemy-to-ormar) +or simply `pip install sqlalchemy-to-ormar` + ### Migrations & Database creation Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide diff --git a/docs/index.md b/docs/index.md index 4d2e6ea..13a7823 100644 --- a/docs/index.md +++ b/docs/index.md @@ -47,15 +47,35 @@ since they actually have to create and connect to database in most of the tests. Yet remember that those are - well - tests and not all solutions are suitable to be used in real life applications. +### Part of the `fastapi` ecosystem + +As part of the fastapi ecosystem `ormar` is supported in libraries that somehow work with databases. + +As of now `ormar` is supported by: + +* [`fastapi-users`](https://github.com/frankie567/fastapi-users) +* [`fastapi-crudrouter`](https://github.com/awtkns/fastapi-crudrouter) +* [`fastapi-pagination`](https://github.com/uriyyo/fastapi-pagination) + +If you maintain or use different library and would like it to support `ormar` let us know how we can help. + ### Dependencies Ormar is built with: - * [`SQLAlchemy core`][sqlalchemy-core] for query building. + * [`sqlalchemy core`][sqlalchemy-core] for query building. * [`databases`][databases] for cross-database async support. * [`pydantic`][pydantic] for data validation. * `typing_extensions` for python 3.6 - 3.7 +### Migrating from `sqlalchemy` + +If you currently use `sqlalchemy` and would like to switch to `ormar` check out the auto-translation +tool that can help you with translating existing sqlalchemy orm models so you do not have to do it manually. + +**Beta** versions available at github: [`sqlalchemy-to-ormar`](https://github.com/collerek/sqlalchemy-to-ormar) +or simply `pip install sqlalchemy-to-ormar` + ### Migrations & Database creation Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide diff --git a/docs/releases.md b/docs/releases.md index e22082c..18d5a8d 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,29 @@ +# 0.10.3 + +## ✨ Features + +* `ForeignKey` and `ManyToMany` now support `skip_reverse: bool = False` flag [#118](https://github.com/collerek/ormar/issues/118). + If you set `skip_reverse` flag internally the field is still registered on the other + side of the relationship so you can: + * `filter` by related models fields from reverse model + * `order_by` by related models fields from reverse model + + But you cannot: + * access the related field from reverse model with `related_name` + * even if you `select_related` from reverse side of the model the returned models won't be populated in reversed instance (the join is not prevented so you still can `filter` and `order_by`) + * the relation won't be populated in `dict()` and `json()` + * you cannot pass the nested related objects when populating from `dict()` or `json()` (also through `fastapi`). It will be either ignored or raise error depending on `extra` setting in pydantic `Config`. + +## 🐛 Fixes + +* Fix weakref `ReferenceError` error [#118](https://github.com/collerek/ormar/issues/118) +* Fix error raised by Through fields when pydantic `Config.extra="forbid"` is set + +## 💬 Other +* Introduce link to `sqlalchemy-to-ormar` auto-translator for models +* Provide links to fastapi ecosystem libraries that support `ormar` + + # 0.10.2 ## ✨ Features diff --git a/ormar/__init__.py b/ormar/__init__.py index e05928b..7436730 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -75,7 +75,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.10.2" +__version__ = "0.10.3" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/fields/base.py b/ormar/fields/base.py index f96a064..a86a500 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -53,6 +53,8 @@ class BaseField(FieldInfo): "is_relation", None ) # ForeignKeyField + subclasses self.is_through: bool = kwargs.pop("is_through", False) # ThroughFields + self.skip_reverse: bool = kwargs.pop("skip_reverse", False) + self.skip_field: bool = kwargs.pop("skip_field", False) self.owner: Type["Model"] = kwargs.pop("owner", None) self.to: Type["Model"] = kwargs.pop("to", None) diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index b926b83..fe7c812 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -233,9 +233,13 @@ def ForeignKey( # noqa CFQ002 owner = kwargs.pop("owner", None) self_reference = kwargs.pop("self_reference", False) + orders_by = kwargs.pop("orders_by", None) related_orders_by = kwargs.pop("related_orders_by", None) + skip_reverse = kwargs.pop("skip_reverse", False) + skip_field = kwargs.pop("skip_field", False) + validate_not_allowed_fields(kwargs) if to.__class__ == ForwardRef: @@ -274,6 +278,8 @@ def ForeignKey( # noqa CFQ002 is_relation=True, orders_by=orders_by, related_orders_by=related_orders_by, + skip_reverse=skip_reverse, + skip_field=skip_field, ) Field = type("ForeignKey", (ForeignKeyField, BaseField), {}) @@ -312,6 +318,30 @@ class ForeignKeyField(BaseField): """ return self.related_name or self.owner.get_name() + "s" + def default_target_field_name(self, reverse: bool = False) -> str: + """ + Returns default target model name on through model. + :param reverse: flag to grab name without accessing related field + :type reverse: bool + :return: name of the field + :rtype: str + """ + self_rel_prefix = "from_" if not reverse else "to_" + prefix = self_rel_prefix if self.self_reference else "" + return f"{prefix}{self.to.get_name()}" + + def default_source_field_name(self, reverse: bool = False) -> str: + """ + Returns default target model name on through model. + :param reverse: flag to grab name without accessing related field + :type reverse: bool + :return: name of the field + :rtype: str + """ + self_rel_prefix = "to_" if not reverse else "from_" + prefix = self_rel_prefix if self.self_reference else "" + return f"{prefix}{self.owner.get_name()}" + def evaluate_forward_ref(self, globalns: Any, localns: Any) -> None: """ Evaluates the ForwardRef to actual Field based on global and local namespaces diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py index 829d231..a70f623 100644 --- a/ormar/fields/many_to_many.py +++ b/ormar/fields/many_to_many.py @@ -112,11 +112,16 @@ def ManyToMany( """ related_name = kwargs.pop("related_name", None) nullable = kwargs.pop("nullable", True) + owner = kwargs.pop("owner", None) self_reference = kwargs.pop("self_reference", False) + orders_by = kwargs.pop("orders_by", None) related_orders_by = kwargs.pop("related_orders_by", None) + skip_reverse = kwargs.pop("skip_reverse", False) + skip_field = kwargs.pop("skip_field", False) + if through is not None and through.__class__ != ForwardRef: forbid_through_relations(cast(Type["Model"], through)) @@ -151,6 +156,8 @@ def ManyToMany( is_multi=True, orders_by=orders_by, related_orders_by=related_orders_by, + skip_reverse=skip_reverse, + skip_field=skip_field, ) Field = type("ManyToMany", (ManyToManyField, BaseField), {}) @@ -184,24 +191,6 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro or self.name ) - def default_target_field_name(self) -> str: - """ - Returns default target model name on through model. - :return: name of the field - :rtype: str - """ - prefix = "from_" if self.self_reference else "" - return f"{prefix}{self.to.get_name()}" - - def default_source_field_name(self) -> str: - """ - Returns default target model name on through model. - :return: name of the field - :rtype: str - """ - prefix = "to_" if self.self_reference else "" - return f"{prefix}{self.owner.get_name()}" - def has_unresolved_forward_refs(self) -> bool: """ Verifies if the filed has any ForwardRefs that require updating before the diff --git a/ormar/models/helpers/relations.py b/ormar/models/helpers/relations.py index 29cebe6..39e74e3 100644 --- a/ormar/models/helpers/relations.py +++ b/ormar/models/helpers/relations.py @@ -111,6 +111,7 @@ def register_reverse_model_fields(model_field: "ForeignKeyField") -> None: self_reference=model_field.self_reference, self_reference_primary=model_field.self_reference_primary, orders_by=model_field.related_orders_by, + skip_field=model_field.skip_reverse, ) # register foreign keys on through model model_field = cast("ManyToManyField", model_field) @@ -125,6 +126,7 @@ def register_reverse_model_fields(model_field: "ForeignKeyField") -> None: owner=model_field.to, self_reference=model_field.self_reference, orders_by=model_field.related_orders_by, + skip_field=model_field.skip_reverse, ) @@ -145,6 +147,7 @@ def register_through_shortcut_fields(model_field: "ManyToManyField") -> None: virtual=True, related_name=model_field.name, owner=model_field.owner, + nullable=True, ) model_field.to.Meta.model_fields[through_name] = Through( @@ -153,6 +156,7 @@ def register_through_shortcut_fields(model_field: "ManyToManyField") -> None: virtual=True, related_name=related_name, owner=model_field.to, + nullable=True, ) diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 34fd454..bc80332 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -90,6 +90,7 @@ def add_cached_properties(new_model: Type["Model"]) -> None: """ new_model._quick_access_fields = quick_access_set new_model._related_names = None + new_model._through_names = None new_model._related_fields = None new_model._pydantic_fields = {name for name in new_model.__fields__} new_model._choices_fields = set() @@ -536,6 +537,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): new_model = populate_meta_tablename_columns_and_pk(name, new_model) populate_meta_sqlalchemy_table_if_required(new_model.Meta) expand_reverse_relationships(new_model) + # TODO: iterate only related fields for field in new_model.Meta.model_fields.values(): register_relation_in_alias_manager(field=field) diff --git a/ormar/models/mixins/relation_mixin.py b/ormar/models/mixins/relation_mixin.py index 151725a..43de0be 100644 --- a/ormar/models/mixins/relation_mixin.py +++ b/ormar/models/mixins/relation_mixin.py @@ -20,6 +20,7 @@ class RelationMixin: Meta: ModelMeta _related_names: Optional[Set] + _through_names: Optional[Set] _related_fields: Optional[List] get_name: Callable @@ -57,19 +58,23 @@ class RelationMixin: return related_fields @classmethod - def extract_through_names(cls) -> Set: + def extract_through_names(cls) -> Set[str]: """ Extracts related fields through names which are shortcuts to through models. :return: set of related through fields names :rtype: Set """ - related_fields = set() - for name in cls.extract_related_names(): - field = cls.Meta.model_fields[name] - if field.is_multi: - related_fields.add(field.through.get_name(lower=True)) - return related_fields + if isinstance(cls._through_names, Set): + return cls._through_names + + related_names = set() + for name, field in cls.Meta.model_fields.items(): + if isinstance(field, BaseField) and field.is_through: + related_names.add(name) + + cls._through_names = related_names + return related_names @classmethod def extract_related_names(cls) -> Set[str]: @@ -89,6 +94,7 @@ class RelationMixin: isinstance(field, BaseField) and field.is_relation and not field.is_through + and not field.skip_field ): related_names.add(name) cls._related_names = related_names diff --git a/ormar/models/model.py b/ormar/models/model.py index 568a787..5e27c69 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -24,7 +24,11 @@ class Model(ModelRow): Meta: ModelMeta def __repr__(self) -> str: # pragma nocover - _repr = {k: getattr(self, k) for k, v in self.Meta.model_fields.items()} + _repr = { + k: getattr(self, k) + for k, v in self.Meta.model_fields.items() + if not v.skip_field + } return f"{self.__class__.__name__}({str(_repr)})" async def upsert(self: T, **kwargs: Any) -> T: diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index a2a1f1f..356bbbe 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -81,6 +81,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass _orm_id: int _orm_saved: bool _related_names: Optional[Set] + _through_names: Optional[Set] _related_names_hash: str _choices_fields: Optional[Set] _pydantic_fields: Set @@ -165,6 +166,11 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass for field_to_nullify in excluded: new_kwargs[field_to_nullify] = None + # extract through fields + through_tmp_dict = dict() + for field_name in self.extract_through_names(): + through_tmp_dict[field_name] = new_kwargs.pop(field_name, None) + values, fields_set, validation_error = pydantic.validate_model( self, new_kwargs # type: ignore ) @@ -174,6 +180,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass object.__setattr__(self, "__dict__", values) object.__setattr__(self, "__fields_set__", fields_set) + # add back through fields + new_kwargs.update(through_tmp_dict) + # register the columns models after initialization for related in self.extract_related_names().union(self.extract_through_names()): self.Meta.model_fields[related].expand_relationship( @@ -592,13 +601,16 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass exclude=self._skip_ellipsis(exclude, field), ) elif nested_model is not None: - dict_instance[field] = nested_model.dict( - relation_map=self._skip_ellipsis( - relation_map, field, default_return=dict() - ), - include=self._skip_ellipsis(include, field), - exclude=self._skip_ellipsis(exclude, field), - ) + try: + dict_instance[field] = nested_model.dict( + relation_map=self._skip_ellipsis( + relation_map, field, default_return=dict() + ), + include=self._skip_ellipsis(include, field), + exclude=self._skip_ellipsis(exclude, field), + ) + except ReferenceError: + dict_instance[field] = None else: dict_instance[field] = None return dict_instance diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index 953b43d..c78b6d6 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: # pragma no cover from ormar.relations import Relation from ormar.models import Model, T from ormar.queryset import QuerySet - from ormar import RelationType + from ormar import RelationType, ForeignKeyField else: T = TypeVar("T", bound="Model") @@ -251,7 +251,7 @@ class QuerysetProxy(Generic[T]): owner_column = self._owner.get_name() else: queryset = ormar.QuerySet(model_cls=self.relation.to) # type: ignore - owner_column = self.related_field.name + owner_column = self.related_field_name kwargs = {owner_column: self._owner} self._clean_items_on_load() if keep_reversed and self.type_ == ormar.RelationType.REVERSE: @@ -367,7 +367,7 @@ class QuerysetProxy(Generic[T]): """ through_kwargs = kwargs.pop(self.through_model_name, {}) if self.type_ == ormar.RelationType.REVERSE: - kwargs[self.related_field.name] = self._owner + kwargs[self.related_field_name] = self._owner created = await self.queryset.create(**kwargs) self._register_related(created) if self.type_ == ormar.RelationType.MULTIPLE: diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index 900f8f3..be3258c 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -124,15 +124,14 @@ class RelationProxy(Generic[T], list): :rtype: QuerySet """ related_field_name = self.related_field_name - related_field = self.relation.to.Meta.model_fields[related_field_name] pkname = self._owner.get_column_alias(self._owner.Meta.pkname) self._check_if_model_saved() - kwargs = {f"{related_field.name}__{pkname}": self._owner.pk} + kwargs = {f"{related_field_name}__{pkname}": self._owner.pk} queryset = ( ormar.QuerySet( model_cls=self.relation.to, proxy_source_model=self._owner.__class__ ) - .select_related(related_field.name) + .select_related(related_field_name) .filter(**kwargs) ) return queryset @@ -168,11 +167,12 @@ class RelationProxy(Generic[T], list): super().remove(item) relation_name = self.related_field_name relation = item._orm._get(relation_name) - if relation is None: # pragma nocover - raise ValueError( - f"{self._owner.get_name()} does not have relation {relation_name}" - ) - relation.remove(self._owner) + # if relation is None: # pragma nocover + # raise ValueError( + # f"{self._owner.get_name()} does not have relation {relation_name}" + # ) + if relation: + relation.remove(self._owner) self.relation.remove(item) if self.type_ == ormar.RelationType.MULTIPLE: await self.queryset_proxy.delete_through_instance(item) @@ -211,7 +211,7 @@ class RelationProxy(Generic[T], list): self._check_if_model_saved() if self.type_ == ormar.RelationType.MULTIPLE: await self.queryset_proxy.create_through_instance(item, **kwargs) - setattr(item, relation_name, self._owner) + setattr(self._owner, self.field_name, item) else: setattr(item, relation_name, self._owner) await item.update() diff --git a/tests/test_fastapi/test_inheritance_concrete_fastapi.py b/tests/test_fastapi/test_inheritance_concrete_fastapi.py index a4a8310..e547ad2 100644 --- a/tests/test_fastapi/test_inheritance_concrete_fastapi.py +++ b/tests/test_fastapi/test_inheritance_concrete_fastapi.py @@ -1,4 +1,5 @@ import datetime +from typing import List import pytest import sqlalchemy @@ -59,6 +60,12 @@ async def get_bus(item_id: int): return bus +@app.get("/buses/", response_model=List[Bus]) +async def get_buses(): + buses = await Bus.objects.select_related(["owner", "co_owner"]).all() + return buses + + @app.post("/trucks/", response_model=Truck) async def create_truck(truck: Truck): await truck.save() @@ -84,6 +91,12 @@ async def add_bus_coowner(item_id: int, person: Person): return bus +@app.get("/buses2/", response_model=List[Bus2]) +async def get_buses2(): + buses = await Bus2.objects.select_related(["owner", "co_owners"]).all() + return buses + + @app.post("/trucks2/", response_model=Truck2) async def create_truck2(truck: Truck2): await truck.save() @@ -172,6 +185,10 @@ def test_inheritance_with_relation(): assert unicorn2.co_owner.name == "Joe" assert unicorn2.max_persons == 50 + buses = [Bus(**x) for x in client.get("/buses/").json()] + assert len(buses) == 1 + assert buses[0].name == "Unicorn" + def test_inheritance_with_m2m_relation(): client = TestClient(app) @@ -217,3 +234,7 @@ def test_inheritance_with_m2m_relation(): assert shelby.co_owners[0] == alex assert shelby.co_owners[1] == joe assert shelby.max_capacity == 2000 + + buses = [Bus2(**x) for x in client.get("/buses2/").json()] + assert len(buses) == 1 + assert buses[0].name == "Unicorn" diff --git a/tests/test_fastapi/test_nested_saving.py b/tests/test_fastapi/test_nested_saving.py new file mode 100644 index 0000000..bb388e7 --- /dev/null +++ b/tests/test_fastapi/test_nested_saving.py @@ -0,0 +1,106 @@ +import json +from typing import List, Optional + +import databases +import pytest +import sqlalchemy +from fastapi import FastAPI +from starlette.testclient import TestClient + +import ormar +from tests.settings import DATABASE_URL + +app = FastAPI() +metadata = sqlalchemy.MetaData() +database = databases.Database(DATABASE_URL, force_rollback=True) +app.state.database = database + + +@app.on_event("startup") +async def startup() -> None: + database_ = app.state.database + if not database_.is_connected: + await database_.connect() + + +@app.on_event("shutdown") +async def shutdown() -> None: + database_ = app.state.database + if database_.is_connected: + await database_.disconnect() + + +class Department(ormar.Model): + class Meta: + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + department_name: str = ormar.String(max_length=100) + + +class Course(ormar.Model): + class Meta: + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + course_name: str = ormar.String(max_length=100) + completed: bool = ormar.Boolean() + department: Optional[Department] = ormar.ForeignKey(Department) + + +# create db and tables +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@app.post("/DepartmentWithCourses/", response_model=Department) +async def create_department(department: Department): + # there is no save all - you need to split into save and save_related + await department.save() + await department.save_related(follow=True, save_all=True) + return department + + +@app.get("/DepartmentsAll/", response_model=List[Department]) +async def get_Courses(): + # if you don't provide default name it related model name + s so courses not course + departmentall = await Department.objects.select_related("courses").all() + return departmentall + + +def test_saving_related_in_fastapi(): + client = TestClient(app) + with client as client: + payload = { + "department_name": "Ormar", + "courses": [ + {"course_name": "basic1", "completed": True}, + {"course_name": "basic2", "completed": True}, + ], + } + response = client.post("/DepartmentWithCourses/", data=json.dumps(payload)) + department = Department(**response.json()) + + assert department.id is not None + assert len(department.courses) == 2 + assert department.department_name == "Ormar" + assert department.courses[0].course_name == "basic1" + assert department.courses[0].completed + assert department.courses[1].course_name == "basic2" + assert department.courses[1].completed + + response = client.get("/DepartmentsAll/") + departments = [Department(**x) for x in response.json()] + assert departments[0].id is not None + assert len(departments[0].courses) == 2 + assert departments[0].department_name == "Ormar" + assert departments[0].courses[0].course_name == "basic1" + assert departments[0].courses[0].completed + assert departments[0].courses[1].course_name == "basic2" + assert departments[0].courses[1].completed diff --git a/tests/test_fastapi/test_skip_reverse_models.py b/tests/test_fastapi/test_skip_reverse_models.py new file mode 100644 index 0000000..2c767d1 --- /dev/null +++ b/tests/test_fastapi/test_skip_reverse_models.py @@ -0,0 +1,148 @@ +import json +from typing import List, Optional + +import databases +import pytest +import sqlalchemy +from fastapi import FastAPI +from starlette.testclient import TestClient + +import ormar +from tests.settings import DATABASE_URL + +app = FastAPI() +metadata = sqlalchemy.MetaData() +database = databases.Database(DATABASE_URL, force_rollback=True) +app.state.database = database + + +@app.on_event("startup") +async def startup() -> None: + database_ = app.state.database + if not database_.is_connected: + await database_.connect() + + +@app.on_event("shutdown") +async def shutdown() -> None: + database_ = app.state.database + if database_.is_connected: + await database_.disconnect() + + +class BaseMeta(ormar.ModelMeta): + database = database + metadata = metadata + + +class Author(ormar.Model): + class Meta(BaseMeta): + pass + + id: int = ormar.Integer(primary_key=True) + first_name: str = ormar.String(max_length=80) + last_name: str = ormar.String(max_length=80) + + +class Category(ormar.Model): + class Meta(BaseMeta): + tablename = "categories" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=40) + + +class Post(ormar.Model): + class Meta(BaseMeta): + pass + + id: int = ormar.Integer(primary_key=True) + title: str = ormar.String(max_length=200) + categories = ormar.ManyToMany(Category, skip_reverse=True) + author: Optional[Author] = ormar.ForeignKey(Author, skip_reverse=True) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@app.post("/categories/", response_model=Category) +async def create_category(category: Category): + await category.save() + await category.save_related(follow=True, save_all=True) + return category + + +@app.post("/posts/", response_model=Post) +async def create_post(post: Post): + if post.author: + await post.author.save() + await post.save() + await post.save_related(follow=True, save_all=True) + for category in [cat for cat in post.categories]: + await post.categories.add(category) + return post + + +@app.get("/categories/", response_model=List[Category]) +async def get_categories(): + return await Category.objects.select_related("posts").all() + + +@app.get("/posts/", response_model=List[Post]) +async def get_posts(): + posts = await Post.objects.select_related(["categories", "author"]).all() + return posts + + +def test_queries(): + client = TestClient(app) + with client as client: + right_category = {"name": "Test category"} + wrong_category = {"name": "Test category2", "posts": [{"title": "Test Post"}]} + + # cannot add posts if skipped, will be ignored (with extra=ignore by default) + response = client.post("/categories/", data=json.dumps(wrong_category)) + assert response.status_code == 200 + response = client.get("/categories/") + assert response.status_code == 200 + assert not "posts" in response.json() + categories = [Category(**x) for x in response.json()] + assert categories[0] is not None + assert categories[0].name == "Test category2" + + response = client.post("/categories/", data=json.dumps(right_category)) + assert response.status_code == 200 + + response = client.get("/categories/") + assert response.status_code == 200 + categories = [Category(**x) for x in response.json()] + assert categories[1] is not None + assert categories[1].name == "Test category" + + right_post = { + "title": "ok post", + "author": {"first_name": "John", "last_name": "Smith"}, + "categories": [{"name": "New cat"}], + } + response = client.post("/posts/", data=json.dumps(right_post)) + assert response.status_code == 200 + + Category.__config__.extra = "allow" + response = client.get("/posts/") + assert response.status_code == 200 + posts = [Post(**x) for x in response.json()] + assert posts[0].title == "ok post" + assert posts[0].author.first_name == "John" + assert posts[0].categories[0].name == "New cat" + + wrong_category = {"name": "Test category3", "posts": [{"title": "Test Post"}]} + + # cannot add posts if skipped, will be error with extra forbid + Category.__config__.extra = "forbid" + response = client.post("/categories/", data=json.dumps(wrong_category)) + assert response.status_code == 422 diff --git a/tests/test_fastapi/test_wekref_exclusion.py b/tests/test_fastapi/test_wekref_exclusion.py index a1140f7..f3b6408 100644 --- a/tests/test_fastapi/test_wekref_exclusion.py +++ b/tests/test_fastapi/test_wekref_exclusion.py @@ -123,6 +123,16 @@ async def get_test_5(thing_id: UUID): return await Thing.objects.all(other_thing__id=thing_id) +@app.get( + "/test/error", response_model=List[Thing], response_model_exclude={"other_thing"} +) +async def get_weakref(): + ots = await OtherThing.objects.all() + ot = ots[0] + ts = await ot.things.all() + return ts + + def test_endpoints(): client = TestClient(app) with client: @@ -145,3 +155,7 @@ def test_endpoints(): resp5 = client.get(f"/test/5/{ot.id}") assert resp5.status_code == 200 assert len(resp5.json()) == 3 + + resp6 = client.get("/test/error") + assert resp6.status_code == 200 + assert len(resp6.json()) == 3 diff --git a/tests/test_relations/test_skipping_reverse.py b/tests/test_relations/test_skipping_reverse.py new file mode 100644 index 0000000..da1939d --- /dev/null +++ b/tests/test_relations/test_skipping_reverse.py @@ -0,0 +1,223 @@ +from typing import List, Optional + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL) +metadata = sqlalchemy.MetaData() + + +class BaseMeta(ormar.ModelMeta): + database = database + metadata = metadata + + +class Author(ormar.Model): + class Meta(BaseMeta): + pass + + id: int = ormar.Integer(primary_key=True) + first_name: str = ormar.String(max_length=80) + last_name: str = ormar.String(max_length=80) + + +class Category(ormar.Model): + class Meta(BaseMeta): + tablename = "categories" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=40) + + +class Post(ormar.Model): + class Meta(BaseMeta): + pass + + id: int = ormar.Integer(primary_key=True) + title: str = ormar.String(max_length=200) + categories: Optional[List[Category]] = ormar.ManyToMany(Category, skip_reverse=True) + author: Optional[Author] = ormar.ForeignKey(Author, skip_reverse=True) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.fixture(scope="function") +async def cleanup(): + yield + async with database: + PostCategory = Post.Meta.model_fields["categories"].through + await PostCategory.objects.delete(each=True) + await Post.objects.delete(each=True) + await Category.objects.delete(each=True) + await Author.objects.delete(each=True) + + +def test_model_definition(): + category = Category(name="Test") + author = Author(first_name="Test", last_name="Author") + post = Post(title="Test Post", author=author) + post.categories = category + + assert post.categories[0] == category + assert post.author == author + + with pytest.raises(AttributeError): + assert author.posts + + with pytest.raises(AttributeError): + assert category.posts + + assert "posts" not in category._orm + + +@pytest.mark.asyncio +async def test_assigning_related_objects(cleanup): + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") + + # Add a category to a post. + await post.categories.add(news) + # other way is disabled + with pytest.raises(AttributeError): + await news.posts.add(post) + + assert await post.categories.get_or_none(name="no exist") is None + assert await post.categories.get_or_none(name="News") == news + + # Creating columns object from instance: + await post.categories.create(name="Tips") + assert len(post.categories) == 2 + + post_categories = await post.categories.all() + assert len(post_categories) == 2 + + category = await Category.objects.select_related("posts").get(name="News") + with pytest.raises(AttributeError): + assert category.posts + + +@pytest.mark.asyncio +async def test_quering_of_related_model_works_but_no_result(cleanup): + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") + + await post.categories.add(news) + + post_categories = await post.categories.all() + assert len(post_categories) == 1 + + assert "posts" not in post.dict().get("categories", [])[0] + + assert news == await post.categories.get(name="News") + + posts_about_python = await Post.objects.filter(categories__name="python").all() + assert len(posts_about_python) == 0 + + # relation not in dict + category = ( + await Category.objects.select_related("posts") + .filter(posts__author=guido) + .get() + ) + assert category == news + assert "posts" not in category.dict() + + # relation not in json + category2 = ( + await Category.objects.select_related("posts") + .filter(posts__author__first_name="Guido") + .get() + ) + assert category2 == news + assert "posts" not in category2.json() + + assert "posts" not in Category.schema().get("properties") + + +@pytest.mark.asyncio +async def test_removal_of_the_relations(cleanup): + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + post = await Post.objects.create(title="Hello, M2M", author=guido) + news = await Category.objects.create(name="News") + await post.categories.add(news) + assert len(await post.categories.all()) == 1 + await post.categories.remove(news) + assert len(await post.categories.all()) == 0 + + with pytest.raises(AttributeError): + await news.posts.add(post) + with pytest.raises(AttributeError): + await news.posts.remove(post) + + await post.categories.add(news) + await post.categories.clear() + assert len(await post.categories.all()) == 0 + + await post.categories.add(news) + await news.delete() + assert len(await post.categories.all()) == 0 + + +@pytest.mark.asyncio +async def test_selecting_related(cleanup): + async with database: + guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum") + guido2 = await Author.objects.create( + first_name="Guido2", last_name="Van Rossum" + ) + + post = await Post.objects.create(title="Hello, M2M", author=guido) + post2 = await Post.objects.create(title="Bye, M2M", author=guido2) + + news = await Category.objects.create(name="News") + recent = await Category.objects.create(name="Recent") + + await post.categories.add(news) + await post.categories.add(recent) + await post2.categories.add(recent) + + assert len(await post.categories.all()) == 2 + assert (await post.categories.limit(1).all())[0] == news + assert (await post.categories.offset(1).limit(1).all())[0] == recent + assert await post.categories.first() == news + assert await post.categories.exists() + + # still can order + categories = ( + await Category.objects.select_related("posts") + .order_by("posts__title") + .all() + ) + assert categories[0].name == "Recent" + assert categories[1].name == "News" + + # still can filter + categories = await Category.objects.filter(posts__title="Bye, M2M").all() + assert categories[0].name == "Recent" + assert len(categories) == 1 + + # same for reverse fk + authors = ( + await Author.objects.select_related("posts").order_by("posts__title").all() + ) + assert authors[0].first_name == "Guido2" + assert authors[1].first_name == "Guido" + + authors = await Author.objects.filter(posts__title="Bye, M2M").all() + assert authors[0].first_name == "Guido2" + assert len(authors) == 1 From 6780c9de8a5cc82d809b0a91456d9828bba4d4d6 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 12 Apr 2021 10:40:29 +0200 Subject: [PATCH 2/6] fix private attributes initialization --- docs/releases.md | 1 + ormar/models/newbasemodel.py | 34 +++++++++++-------- ormar/relations/querysetproxy.py | 2 +- .../test_pydantic_private_attributes.py | 34 +++++++++++++++++++ 4 files changed, 56 insertions(+), 15 deletions(-) create mode 100644 tests/test_model_definition/test_pydantic_private_attributes.py diff --git a/docs/releases.md b/docs/releases.md index 18d5a8d..69e725b 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -18,6 +18,7 @@ * Fix weakref `ReferenceError` error [#118](https://github.com/collerek/ormar/issues/118) * Fix error raised by Through fields when pydantic `Config.extra="forbid"` is set +* Fix bug with `pydantic.PrivateAttr` not being initialized at `__init__` [#149](https://github.com/collerek/ormar/issues/149) ## 💬 Other * Introduce link to `sqlalchemy-to-ormar` auto-translator for models diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 356bbbe..af69e38 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -189,6 +189,10 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass new_kwargs.get(related), self, to_register=True, ) + if hasattr(self, "_init_private_attributes"): + # introduced in pydantic 1.7 + self._init_private_attributes() + def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 """ Overwrites setattr in object to allow for special behaviour of certain params. @@ -292,6 +296,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass value = object.__getattribute__(self, "__dict__").get(item, None) value = object.__getattribute__(self, "_convert_json")(item, value, "loads") return value + return object.__getattribute__(self, item) # pragma: no cover def _verify_model_can_be_initialized(self) -> None: @@ -590,18 +595,19 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass for field in fields: if not relation_map or field not in relation_map: continue - nested_model = getattr(self, field) - if isinstance(nested_model, MutableSequence): - dict_instance[field] = self._extract_nested_models_from_list( - relation_map=self._skip_ellipsis( # type: ignore - relation_map, field, default_return=dict() - ), - models=nested_model, - include=self._skip_ellipsis(include, field), - exclude=self._skip_ellipsis(exclude, field), - ) - elif nested_model is not None: - try: + try: + nested_model = getattr(self, field) + if isinstance(nested_model, MutableSequence): + dict_instance[field] = self._extract_nested_models_from_list( + relation_map=self._skip_ellipsis( # type: ignore + relation_map, field, default_return=dict() + ), + models=nested_model, + include=self._skip_ellipsis(include, field), + exclude=self._skip_ellipsis(exclude, field), + ) + elif nested_model is not None: + dict_instance[field] = nested_model.dict( relation_map=self._skip_ellipsis( relation_map, field, default_return=dict() @@ -609,9 +615,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass include=self._skip_ellipsis(include, field), exclude=self._skip_ellipsis(exclude, field), ) - except ReferenceError: + else: dict_instance[field] = None - else: + except ReferenceError: dict_instance[field] = None return dict_instance diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index c78b6d6..4578414 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: # pragma no cover from ormar.relations import Relation from ormar.models import Model, T from ormar.queryset import QuerySet - from ormar import RelationType, ForeignKeyField + from ormar import RelationType else: T = TypeVar("T", bound="Model") diff --git a/tests/test_model_definition/test_pydantic_private_attributes.py b/tests/test_model_definition/test_pydantic_private_attributes.py new file mode 100644 index 0000000..2765f60 --- /dev/null +++ b/tests/test_model_definition/test_pydantic_private_attributes.py @@ -0,0 +1,34 @@ +from typing import List + +import databases +import sqlalchemy +from pydantic import PrivateAttr + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class BaseMeta(ormar.ModelMeta): + metadata = metadata + database = database + + +class Subscription(ormar.Model): + class Meta(BaseMeta): + tablename = "subscriptions" + + id: int = ormar.Integer(primary_key=True) + stripe_subscription_id: str = ormar.String(nullable=False, max_length=256) + + _add_payments: List[str] = PrivateAttr(default_factory=list) + + def add_payment(self, payment: str): + self._add_payments.append(payment) + + +def test_private_attribute(): + sub = Subscription(stripe_subscription_id="2312312sad231") + sub.add_payment("test") From 854b27947a0593af31af61abf4d0ff13d7d34091 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 12 Apr 2021 17:39:42 +0200 Subject: [PATCH 3/6] modify save_related to be able to save whole tree from dict - including reverse fk and m2m relations - with correct order of saving --- docs/releases.md | 11 + ormar/models/mixins/relation_mixin.py | 7 +- ormar/models/mixins/save_mixin.py | 133 ++++++++- ormar/models/model.py | 106 ++++---- ormar/relations/querysetproxy.py | 17 +- tests/test_model_methods/test_save_related.py | 2 +- .../test_save_related_from_dict.py | 256 ++++++++++++++++++ 7 files changed, 474 insertions(+), 58 deletions(-) create mode 100644 tests/test_model_methods/test_save_related_from_dict.py diff --git a/docs/releases.md b/docs/releases.md index 69e725b..b9becc1 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -13,6 +13,17 @@ * even if you `select_related` from reverse side of the model the returned models won't be populated in reversed instance (the join is not prevented so you still can `filter` and `order_by`) * the relation won't be populated in `dict()` and `json()` * you cannot pass the nested related objects when populating from `dict()` or `json()` (also through `fastapi`). It will be either ignored or raise error depending on `extra` setting in pydantic `Config`. +* `Model.save_related()` now can save whole data tree in once [#148](https://github.com/collerek/ormar/discussions/148) + meaning: + * it knows if it should save main `Model` or related `Model` first to preserve the relation + * it saves main `Model` if + * it's not `saved`, + * has no `pk` value + * or `save_all=True` flag is set + + in those cases you don't have to split save into two calls (`save()` and `save_related()`) + * it supports also `ManyToMany` relations + * it supports also optional `Through` model values for m2m relations ## 🐛 Fixes diff --git a/ormar/models/mixins/relation_mixin.py b/ormar/models/mixins/relation_mixin.py index 43de0be..6a71382 100644 --- a/ormar/models/mixins/relation_mixin.py +++ b/ormar/models/mixins/relation_mixin.py @@ -4,9 +4,10 @@ from typing import ( Optional, Set, TYPE_CHECKING, + cast, ) -from ormar import BaseField +from ormar import BaseField, ForeignKeyField from ormar.models.traversible import NodeList @@ -39,7 +40,7 @@ class RelationMixin: return self_fields @classmethod - def extract_related_fields(cls) -> List: + def extract_related_fields(cls) -> List["ForeignKeyField"]: """ Returns List of ormar Fields for all relations declared on a model. List is cached in cls._related_fields for quicker access. @@ -52,7 +53,7 @@ class RelationMixin: related_fields = [] for name in cls.extract_related_names().union(cls.extract_through_names()): - related_fields.append(cls.Meta.model_fields[name]) + related_fields.append(cast("ForeignKeyField", cls.Meta.model_fields[name])) cls._related_fields = related_fields return related_fields diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index 5ade49f..52826b7 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -1,5 +1,5 @@ import uuid -from typing import Dict, Optional, Set, TYPE_CHECKING +from typing import Callable, Collection, Dict, Optional, Set, TYPE_CHECKING, cast import ormar from ormar.exceptions import ModelPersistenceError @@ -7,6 +7,9 @@ from ormar.models.helpers.validation import validate_choices from ormar.models.mixins import AliasMixin from ormar.models.mixins.relation_mixin import RelationMixin +if TYPE_CHECKING: # pragma: no cover + from ormar import ForeignKeyField, Model + class SavePrepareMixin(RelationMixin, AliasMixin): """ @@ -15,6 +18,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin): if TYPE_CHECKING: # pragma: nocover _choices_fields: Optional[Set] + _skip_ellipsis: Callable @classmethod def prepare_model_to_save(cls, new_kwargs: dict) -> dict: @@ -170,3 +174,130 @@ class SavePrepareMixin(RelationMixin, AliasMixin): if field_name in new_kwargs and field_name in cls._choices_fields: validate_choices(field=field, value=new_kwargs.get(field_name)) return new_kwargs + + @staticmethod + async def _upsert_model( + instance: "Model", + save_all: bool, + previous_model: Optional["Model"], + relation_field: Optional["ForeignKeyField"], + update_count: int, + ) -> int: + """ + Method updates given instance if: + + * instance is not saved or + * instance have no pk or + * save_all=True flag is set + + and instance is not __pk_only__. + + If relation leading to instance is a ManyToMany also the through model is saved + + :param instance: current model to upsert + :type instance: Model + :param save_all: flag if all models should be saved or only not saved ones + :type save_all: bool + :param relation_field: field with relation + :type relation_field: Optional[ForeignKeyField] + :param previous_model: previous model from which method came + :type previous_model: Model + :param update_count: no of updated models + :type update_count: int + :return: no of updated models + :rtype: int + """ + if ( + save_all or not instance.pk or not instance.saved + ) and not instance.__pk_only__: + await instance.upsert() + if relation_field and relation_field.is_multi: + await instance._upsert_through_model( + instance=instance, + relation_field=relation_field, + previous_model=cast("Model", previous_model), + ) + update_count += 1 + return update_count + + @staticmethod + async def _upsert_through_model( + instance: "Model", + previous_model: "Model", + relation_field: Optional["ForeignKeyField"], + ) -> None: + """ + Upsert through model for m2m relation. + + :param instance: current model to upsert + :type instance: Model + :param relation_field: field with relation + :type relation_field: Optional[ForeignKeyField] + :param previous_model: previous model from which method came + :type previous_model: Model + """ + through_name = previous_model.Meta.model_fields[ + relation_field.name + ].through.get_name() + through = getattr(instance, through_name) + if through: + through_dict = through.dict(exclude=through.extract_related_names()) + else: + through_dict = {} + await getattr( + previous_model, relation_field.name + ).queryset_proxy.upsert_through_instance(instance, **through_dict) + + async def _update_relation_list( + self, + fields_list: Collection["ForeignKeyField"], + follow: bool, + save_all: bool, + relation_map: Dict, + update_count: int, + ) -> int: + """ + Internal method used in save_related to follow deeper from + related models and update numbers of updated related instances. + + :type save_all: flag if all models should be saved + :type save_all: bool + :param fields_list: list of ormar fields to follow and save + :type fields_list: Collection["ForeignKeyField"] + :param relation_map: map of relations to follow + :type relation_map: Dict + :param follow: flag to trigger deep save - + by default only directly related models are saved + with follow=True also related models of related models are saved + :type follow: bool + :param update_count: internal parameter for recursive calls - + number of updated instances + :type update_count: int + :return: tuple of update count and visited + :rtype: int + """ + for field in fields_list: + value = getattr(self, field.name) or [] + if not isinstance(value, list): + value = [value] + for val in value: + if follow: + update_count = await val.save_related( + follow=follow, + save_all=save_all, + relation_map=self._skip_ellipsis( # type: ignore + relation_map, field.name, default_return={} + ), + update_count=update_count, + previous_model=self, + relation_field=field, + ) + else: + update_count = await val._upsert_model( + instance=val, + save_all=save_all, + previous_model=self, + relation_field=field, + update_count=update_count, + ) + return update_count diff --git a/ormar/models/model.py b/ormar/models/model.py index 5e27c69..7c5f834 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -2,6 +2,7 @@ from typing import ( Any, Dict, List, + Optional, Set, TYPE_CHECKING, TypeVar, @@ -17,6 +18,9 @@ from ormar.queryset.utils import subtract_dict, translate_list_to_dict T = TypeVar("T", bound="Model") +if TYPE_CHECKING: # pragma: no cover + from ormar import ForeignKeyField + class Model(ModelRow): __abstract__ = False @@ -110,6 +114,8 @@ class Model(ModelRow): relation_map: Dict = None, exclude: Union[Set, Dict] = None, update_count: int = 0, + previous_model: "Model" = None, + relation_field: Optional["ForeignKeyField"] = None, ) -> int: """ Triggers a upsert method on all related models @@ -126,6 +132,10 @@ class Model(ModelRow): Model A but will never follow into Model C. Nested relations of those kind need to be persisted manually. + :param relation_field: field with relation leading to this model + :type relation_field: Optional[ForeignKeyField] + :param previous_model: previous model from which method came + :type previous_model: Model :param exclude: items to exclude during saving of relations :type exclude: Union[Set, Dict] :param relation_map: map of relations to follow @@ -151,61 +161,53 @@ class Model(ModelRow): exclude = translate_list_to_dict(exclude) relation_map = subtract_dict(relation_map, exclude or {}) - for related in self.extract_related_names(): - if relation_map and related in relation_map: - value = getattr(self, related) - if value: - update_count = await self._update_and_follow( - value=value, - follow=follow, - save_all=save_all, - relation_map=self._skip_ellipsis( # type: ignore - relation_map, related, default_return={} - ), - update_count=update_count, - ) - return update_count + if relation_map: + fields_to_visit = { + field + for field in self.extract_related_fields() + if field.name in relation_map + } + pre_save = { + field + for field in fields_to_visit + if not field.virtual and not field.is_multi + } - @staticmethod - async def _update_and_follow( - value: Union["Model", List["Model"]], - follow: bool, - save_all: bool, - relation_map: Dict, - update_count: int, - ) -> int: - """ - Internal method used in save_related to follow related models and update numbers - of updated related instances. + update_count = await self._update_relation_list( + fields_list=pre_save, + follow=follow, + save_all=save_all, + relation_map=relation_map, + update_count=update_count, + ) - :param value: Model to follow - :type value: Model - :param relation_map: map of relations to follow - :type relation_map: Dict - :param follow: flag to trigger deep save - - by default only directly related models are saved - with follow=True also related models of related models are saved - :type follow: bool - :param update_count: internal parameter for recursive calls - - number of updated instances - :type update_count: int - :return: tuple of update count and visited - :rtype: int - """ - if not isinstance(value, list): - value = [value] + update_count = await self._upsert_model( + instance=self, + save_all=save_all, + previous_model=previous_model, + relation_field=relation_field, + update_count=update_count, + ) + + post_save = fields_to_visit - pre_save + + update_count = await self._update_relation_list( + fields_list=post_save, + follow=follow, + save_all=save_all, + relation_map=relation_map, + update_count=update_count, + ) + + else: + update_count = await self._upsert_model( + instance=self, + save_all=save_all, + previous_model=previous_model, + relation_field=relation_field, + update_count=update_count, + ) - for val in value: - if (not val.saved or save_all) and not val.__pk_only__: - await val.upsert() - update_count += 1 - if follow: - update_count = await val.save_related( - follow=follow, - save_all=save_all, - relation_map=relation_map, - update_count=update_count, - ) return update_count async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T: diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index 4578414..f60f263 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -16,7 +16,7 @@ from typing import ( # noqa: I100, I201 ) import ormar # noqa: I100, I202 -from ormar.exceptions import ModelPersistenceError, QueryDefinitionError +from ormar.exceptions import ModelPersistenceError, NoMatch, QueryDefinitionError if TYPE_CHECKING: # pragma no cover from ormar.relations import Relation @@ -152,6 +152,21 @@ class QuerysetProxy(Generic[T]): through_model = await model_cls.objects.get(**rel_kwargs) await through_model.update(**kwargs) + async def upsert_through_instance(self, child: "T", **kwargs: Any) -> None: + """ + Updates a through model instance in the database for m2m relations if + it already exists, else creates one. + + :param kwargs: dict of additional keyword arguments for through instance + :type kwargs: Any + :param child: child model instance + :type child: Model + """ + try: + await self.update_through_instance(child=child, **kwargs) + except NoMatch: + await self.create_through_instance(child=child, **kwargs) + async def delete_through_instance(self, child: "T") -> None: """ Removes through model instance from the database for m2m relations. diff --git a/tests/test_model_methods/test_save_related.py b/tests/test_model_methods/test_save_related.py index 774d167..207c53d 100644 --- a/tests/test_model_methods/test_save_related.py +++ b/tests/test_model_methods/test_save_related.py @@ -119,7 +119,7 @@ async def test_saving_many_to_many(): assert count == 0 count = await hq.save_related(save_all=True) - assert count == 2 + assert count == 3 hq.nicks[0].name = "Kabucha" hq.nicks[1].name = "Kabucha2" diff --git a/tests/test_model_methods/test_save_related_from_dict.py b/tests/test_model_methods/test_save_related_from_dict.py new file mode 100644 index 0000000..a545092 --- /dev/null +++ b/tests/test_model_methods/test_save_related_from_dict.py @@ -0,0 +1,256 @@ +from typing import List + +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class CringeLevel(ormar.Model): + class Meta: + tablename = "levels" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class NickName(ormar.Model): + class Meta: + tablename = "nicks" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + is_lame: bool = ormar.Boolean(nullable=True) + level: CringeLevel = ormar.ForeignKey(CringeLevel) + + +class NicksHq(ormar.Model): + class Meta: + tablename = "nicks_x_hq" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + new_field: str = ormar.String(max_length=200, nullable=True) + + +class HQ(ormar.Model): + class Meta: + tablename = "hqs" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="hq_name") + nicks: List[NickName] = ormar.ManyToMany(NickName, through=NicksHq) + + +class Company(ormar.Model): + class Meta: + tablename = "companies" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, nullable=False, name="company_name") + founded: int = ormar.Integer(nullable=True) + hq: HQ = ormar.ForeignKey(HQ, related_name="companies") + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_saving_related_reverse_fk(): + async with database: + async with database.transaction(force_rollback=True): + payload = {"companies": [{"name": "Banzai"}], "name": "Main"} + hq = HQ(**payload) + count = await hq.save_related(follow=True, save_all=True) + assert count == 2 + + hq_check = await HQ.objects.select_related("companies").get() + assert hq_check.pk is not None + assert hq_check.name == "Main" + assert len(hq_check.companies) == 1 + assert hq_check.companies[0].name == "Banzai" + assert hq_check.companies[0].pk is not None + + +@pytest.mark.asyncio +async def test_saving_related_reverse_fk_multiple(): + async with database: + async with database.transaction(force_rollback=True): + payload = { + "companies": [{"name": "Banzai"}, {"name": "Yamate"}], + "name": "Main", + } + hq = HQ(**payload) + count = await hq.save_related(follow=True, save_all=True) + assert count == 3 + + hq_check = await HQ.objects.select_related("companies").get() + assert hq_check.pk is not None + assert hq_check.name == "Main" + assert len(hq_check.companies) == 2 + assert hq_check.companies[0].name == "Banzai" + assert hq_check.companies[0].pk is not None + assert hq_check.companies[1].name == "Yamate" + assert hq_check.companies[1].pk is not None + + +@pytest.mark.asyncio +async def test_saving_related_fk(): + async with database: + async with database.transaction(force_rollback=True): + payload = {"hq": {"name": "Main"}, "name": "Banzai"} + comp = Company(**payload) + count = await comp.save_related(follow=True, save_all=True) + assert count == 2 + + comp_check = await Company.objects.select_related("hq").get() + assert comp_check.pk is not None + assert comp_check.name == "Banzai" + assert comp_check.hq.name == "Main" + assert comp_check.hq.pk is not None + + +@pytest.mark.asyncio +async def test_saving_many_to_many_wo_through(): + async with database: + async with database.transaction(force_rollback=True): + payload = { + "name": "Main", + "nicks": [ + {"name": "Bazinga0", "is_lame": False}, + {"name": "Bazinga20", "is_lame": True}, + ], + } + + hq = HQ(**payload) + count = await hq.save_related() + assert count == 3 + + hq_check = await HQ.objects.select_related("nicks").get() + assert hq_check.pk is not None + assert len(hq_check.nicks) == 2 + assert hq_check.nicks[0].name == "Bazinga0" + assert hq_check.nicks[1].name == "Bazinga20" + + +@pytest.mark.asyncio +async def test_saving_many_to_many_with_through(): + async with database: + async with database.transaction(force_rollback=True): + async with database.transaction(force_rollback=True): + payload = { + "name": "Main", + "nicks": [ + { + "name": "Bazinga0", + "is_lame": False, + "nickshq": {"new_field": "test"}, + }, + { + "name": "Bazinga20", + "is_lame": True, + "nickshq": {"new_field": "test2"}, + }, + ], + } + + hq = HQ(**payload) + count = await hq.save_related() + assert count == 3 + + hq_check = await HQ.objects.select_related("nicks").get() + assert hq_check.pk is not None + assert len(hq_check.nicks) == 2 + assert hq_check.nicks[0].name == "Bazinga0" + assert hq_check.nicks[0].nickshq.new_field == "test" + assert hq_check.nicks[1].name == "Bazinga20" + assert hq_check.nicks[1].nickshq.new_field == "test2" + + +@pytest.mark.asyncio +async def test_saving_nested_with_m2m_and_rev_fk(): + async with database: + async with database.transaction(force_rollback=True): + payload = { + "name": "Main", + "nicks": [ + {"name": "Bazinga0", "is_lame": False, "level": {"name": "High"}}, + {"name": "Bazinga20", "is_lame": True, "level": {"name": "Low"}}, + ], + } + + hq = HQ(**payload) + count = await hq.save_related(follow=True, save_all=True) + assert count == 5 + + hq_check = await HQ.objects.select_related("nicks__level").get() + assert hq_check.pk is not None + assert len(hq_check.nicks) == 2 + assert hq_check.nicks[0].name == "Bazinga0" + assert hq_check.nicks[0].level.name == "High" + assert hq_check.nicks[1].name == "Bazinga20" + assert hq_check.nicks[1].level.name == "Low" + + +@pytest.mark.asyncio +async def test_saving_nested_with_m2m_and_rev_fk_and_through(): + async with database: + async with database.transaction(force_rollback=True): + payload = { + "hq": { + "name": "Yoko", + "nicks": [ + { + "name": "Bazinga0", + "is_lame": False, + "nickshq": {"new_field": "test"}, + "level": {"name": "High"}, + }, + { + "name": "Bazinga20", + "is_lame": True, + "nickshq": {"new_field": "test2"}, + "level": {"name": "Low"}, + }, + ], + }, + "name": "Main", + } + + company = Company(**payload) + count = await company.save_related(follow=True, save_all=True) + assert count == 6 + + company_check = await Company.objects.select_related( + "hq__nicks__level" + ).get() + assert company_check.pk is not None + assert company_check.name == "Main" + assert company_check.hq.name == "Yoko" + assert len(company_check.hq.nicks) == 2 + assert company_check.hq.nicks[0].name == "Bazinga0" + assert company_check.hq.nicks[0].nickshq.new_field == "test" + assert company_check.hq.nicks[0].level.name == "High" + assert company_check.hq.nicks[1].name == "Bazinga20" + assert company_check.hq.nicks[1].level.name == "Low" + assert company_check.hq.nicks[1].nickshq.new_field == "test2" From d20198e6e12580b2b487d19c0dbf0b2fd9c36469 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 12 Apr 2021 17:45:19 +0200 Subject: [PATCH 4/6] fix mypy --- ormar/models/mixins/save_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index 52826b7..5ff29ad 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -224,7 +224,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin): async def _upsert_through_model( instance: "Model", previous_model: "Model", - relation_field: Optional["ForeignKeyField"], + relation_field: "ForeignKeyField", ) -> None: """ Upsert through model for m2m relation. From 1c24ade8c82f0473ba4b1b010e50660a2e484676 Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 16 Apr 2021 14:14:24 +0200 Subject: [PATCH 5/6] fix __all__ error in exclude, update docs --- docs/models/methods.md | 82 +++++++++++++- docs/relations/foreign-key.md | 60 +++++++++++ docs/relations/many-to-many.md | 132 ++++++++++++++++++++--- docs/releases.md | 2 + docs/transactions.md | 88 +++++++++++++++ mkdocs.yml | 1 + ormar/models/mixins/excludable_mixin.py | 16 ++- ormar/models/mixins/save_mixin.py | 4 +- ormar/models/newbasemodel.py | 26 ++++- tests/test_fastapi/test_nested_saving.py | 87 +++++++++++---- 10 files changed, 441 insertions(+), 57 deletions(-) create mode 100644 docs/transactions.md diff --git a/docs/models/methods.md b/docs/models/methods.md index 3b81559..e9cfe72 100644 --- a/docs/models/methods.md +++ b/docs/models/methods.md @@ -198,10 +198,88 @@ or it can be a dictionary that can also contain nested items. To read more about the structure of possible values passed to `exclude` check `Queryset.fields` method documentation. !!!warning - To avoid circular updates with `follow=True` set, `save_related` keeps a set of already visited Models, + To avoid circular updates with `follow=True` set, `save_related` keeps a set of already visited Models on each branch of relation tree, and won't perform nested `save_related` on Models that were already visited. - So if you have a diamond or circular relations types you need to perform the updates in a manual way. + So if you have circular relations types you need to perform the updates in a manual way. + +Note that with `save_all=True` and `follow=True` you can use `save_related()` to save whole relation tree at once. + +Example: + +```python +class Department(ormar.Model): + class Meta: + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + department_name: str = ormar.String(max_length=100) + + +class Course(ormar.Model): + class Meta: + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + course_name: str = ormar.String(max_length=100) + completed: bool = ormar.Boolean() + department: Optional[Department] = ormar.ForeignKey(Department) + + +class Student(ormar.Model): + class Meta: + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + courses = ormar.ManyToMany(Course) + +to_save = { + "department_name": "Ormar", + "courses": [ + {"course_name": "basic1", + "completed": True, + "students": [ + {"name": "Jack"}, + {"name": "Abi"} + ]}, + {"course_name": "basic2", + "completed": True, + "students": [ + {"name": "Kate"}, + {"name": "Miranda"} + ] + }, + ], + } +# initializa whole tree +department = Department(**to_save) + +# save all at once (one after another) +await department.save_related(follow=True, save_all=True) + +department_check = await Department.objects.select_all(follow=True).get() + +to_exclude = { + "id": ..., + "courses": { + "id": ..., + "students": {"id", "studentcourse"} + } +} +# after excluding ids and through models you get exact same payload used to +# construct whole tree +assert department_check.dict(exclude=to_exclude) == to_save + +``` + + +!!!warning + `save_related()` iterates all relations and all models and upserts() them one by one, + so it will save all models but might not be optimal in regard of number of database queries. [fields]: ../fields.md [relations]: ../relations/index.md diff --git a/docs/relations/foreign-key.md b/docs/relations/foreign-key.md index 73977f7..1bdc4f4 100644 --- a/docs/relations/foreign-key.md +++ b/docs/relations/foreign-key.md @@ -27,6 +27,66 @@ By default it's child (source) `Model` name + s, like courses in snippet below: Reverse relation exposes API to manage related objects also from parent side. +### Skipping reverse relation + +If you are sure you don't want the reverse relation you can use `skip_reverse=True` +flag of the `ForeignKey`. + + If you set `skip_reverse` flag internally the field is still registered on the other + side of the relationship so you can: + * `filter` by related models fields from reverse model + * `order_by` by related models fields from reverse model + + But you cannot: + * access the related field from reverse model with `related_name` + * even if you `select_related` from reverse side of the model the returned models won't be populated in reversed instance (the join is not prevented so you still can `filter` and `order_by` over the relation) + * the relation won't be populated in `dict()` and `json()` + * you cannot pass the nested related objects when populating from dictionary or json (also through `fastapi`). It will be either ignored or error will be raised depending on `extra` setting in pydantic `Config`. + +Example: + +```python +class Author(ormar.Model): + class Meta(BaseMeta): + pass + + id: int = ormar.Integer(primary_key=True) + first_name: str = ormar.String(max_length=80) + last_name: str = ormar.String(max_length=80) + + +class Post(ormar.Model): + class Meta(BaseMeta): + pass + + id: int = ormar.Integer(primary_key=True) + title: str = ormar.String(max_length=200) + author: Optional[Author] = ormar.ForeignKey(Author, skip_reverse=True) + +# create sample data +author = Author(first_name="Test", last_name="Author") +post = Post(title="Test Post", author=author) + +assert post.author == author # ok +assert author.posts # Attribute error! + +# but still can use in order_by +authors = ( + await Author.objects.select_related("posts").order_by("posts__title").all() +) +assert authors[0].first_name == "Test" + +# note that posts are not populated for author even if explicitly +# included in select_related - note no posts in dict() +assert author.dict(exclude={"id"}) == {"first_name": "Test", "last_name": "Author"} + +# still can filter through fields of related model +authors = await Author.objects.filter(posts__title="Test Post").all() +assert authors[0].first_name == "Test" +assert len(authors) == 1 +``` + + ### add Adding child model from parent side causes adding related model to currently loaded parent relation, diff --git a/docs/relations/many-to-many.md b/docs/relations/many-to-many.md index 24be745..414f0df 100644 --- a/docs/relations/many-to-many.md +++ b/docs/relations/many-to-many.md @@ -20,6 +20,122 @@ post = await Post.objects.create(title="Hello, M2M", author=guido) news = await Category.objects.create(name="News") ``` +## Reverse relation + +`ForeignKey` fields are automatically registering reverse side of the relation. + +By default it's child (source) `Model` name + s, like courses in snippet below: + +```python +class Category(ormar.Model): + class Meta(BaseMeta): + tablename = "categories" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=40) + + +class Post(ormar.Model): + class Meta(BaseMeta): + pass + + id: int = ormar.Integer(primary_key=True) + title: str = ormar.String(max_length=200) + categories: Optional[List[Category]] = ormar.ManyToMany(Category) + +# create some sample data +post = await Post.objects.create(title="Hello, M2M") +news = await Category.objects.create(name="News") +await post.categories.add(news) + +# now you can query and access from both sides: +post_check = Post.objects.select_related("categories").get() +assert post_check.categories[0] == news + +# query through auto registered reverse side +category_check = Category.objects.select_related("posts").get() +assert category_check.posts[0] == post +``` + +Reverse relation exposes API to manage related objects also from parent side. + +### related_name + +By default, the related_name is generated in the same way as for the `ForeignKey` relation (class.name.lower()+'s'), +but in the same way you can overwrite this name by providing `related_name` parameter like below: + +```Python +categories: Optional[Union[Category, List[Category]]] = ormar.ManyToMany( + Category, through=PostCategory, related_name="new_categories" + ) +``` + +!!!warning + When you provide multiple relations to the same model `ormar` can no longer auto generate + the `related_name` for you. Therefore, in that situation you **have to** provide `related_name` + for all but one (one can be default and generated) or all related fields. + + +### Skipping reverse relation + +If you are sure you don't want the reverse relation you can use `skip_reverse=True` +flag of the `ManyToMany`. + + If you set `skip_reverse` flag internally the field is still registered on the other + side of the relationship so you can: + * `filter` by related models fields from reverse model + * `order_by` by related models fields from reverse model + + But you cannot: + * access the related field from reverse model with `related_name` + * even if you `select_related` from reverse side of the model the returned models won't be populated in reversed instance (the join is not prevented so you still can `filter` and `order_by` over the relation) + * the relation won't be populated in `dict()` and `json()` + * you cannot pass the nested related objects when populating from dictionary or json (also through `fastapi`). It will be either ignored or error will be raised depending on `extra` setting in pydantic `Config`. + +Example: + +```python +class Category(ormar.Model): + class Meta(BaseMeta): + tablename = "categories" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=40) + + +class Post(ormar.Model): + class Meta(BaseMeta): + pass + + id: int = ormar.Integer(primary_key=True) + title: str = ormar.String(max_length=200) + categories: Optional[List[Category]] = ormar.ManyToMany(Category, skip_reverse=True) + +# create some sample data +post = await Post.objects.create(title="Hello, M2M") +news = await Category.objects.create(name="News") +await post.categories.add(news) + +assert post.categories[0] == news # ok +assert news.posts # Attribute error! + +# but still can use in order_by +categories = ( + await Category.objects.select_related("posts").order_by("posts__title").all() +) +assert categories[0].first_name == "Test" + +# note that posts are not populated for author even if explicitly +# included in select_related - note no posts in dict() +assert news.dict(exclude={"id"}) == {"name": "News"} + +# still can filter through fields of related model +categories = await Category.objects.filter(posts__title="Hello, M2M").all() +assert categories[0].name == "News" +assert len(categories) == 1 +``` + + ## Through Model Optionally if you want to add additional fields you can explicitly create and pass @@ -220,22 +336,6 @@ Reverse relation exposes QuerysetProxy API that allows you to query related mode To read which methods of QuerySet are available read below [querysetproxy][querysetproxy] -## related_name - -By default, the related_name is generated in the same way as for the `ForeignKey` relation (class.name.lower()+'s'), -but in the same way you can overwrite this name by providing `related_name` parameter like below: - -```Python -categories: Optional[Union[Category, List[Category]]] = ormar.ManyToMany( - Category, through=PostCategory, related_name="new_categories" - ) -``` - -!!!warning - When you provide multiple relations to the same model `ormar` can no longer auto generate - the `related_name` for you. Therefore, in that situation you **have to** provide `related_name` - for all but one (one can be default and generated) or all related fields. - [queries]: ./queries.md [querysetproxy]: ./queryset-proxy.md diff --git a/docs/releases.md b/docs/releases.md index b9becc1..667f170 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -30,10 +30,12 @@ * Fix weakref `ReferenceError` error [#118](https://github.com/collerek/ormar/issues/118) * Fix error raised by Through fields when pydantic `Config.extra="forbid"` is set * Fix bug with `pydantic.PrivateAttr` not being initialized at `__init__` [#149](https://github.com/collerek/ormar/issues/149) +* Fix bug with pydantic-type `exclude` in `dict()` with `__all__` key not working ## 💬 Other * Introduce link to `sqlalchemy-to-ormar` auto-translator for models * Provide links to fastapi ecosystem libraries that support `ormar` +* Add transactions to docs (supported with `databases`) # 0.10.2 diff --git a/docs/transactions.md b/docs/transactions.md new file mode 100644 index 0000000..2b8a904 --- /dev/null +++ b/docs/transactions.md @@ -0,0 +1,88 @@ +# Transactions + +Database transactions are supported thanks to `encode/databases` which is used to issue async queries. + +## Basic usage + +To use transactions use `database.transaction` as async context manager: + +```python +async with database.transaction(): + # everyting called here will be one transaction + await Model1().save() + await Model2().save() + ... +``` + +!!!note + Note that it has to be the same `database` that the one used in Model's `Meta` class. + +To avoid passing `database` instance around in your code you can extract the instance from each `Model`. +Database provided during declaration of `ormar.Model` is available through `Meta.database` and can +be reached from both class and instance. + +```python +import databases +import sqlalchemy +import ormar + +metadata = sqlalchemy.MetaData() +database = databases.Database("sqlite:///") + +class Author(ormar.Model): + class Meta: + database=database + metadata=metadata + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=255) + +# database is accessible from class +database = Author.Meta.database + +# as well as from instance +author = Author(name="Stephen King") +database = author.Meta.database + +``` + +You can also use `.transaction()` as a function decorator on any async function: + +```python +@database.transaction() +async def create_users(request): + ... +``` + +Transaction blocks are managed as task-local state. Nested transactions +are fully supported, and are implemented using database savepoints. + +## Manual commits/ rollbacks + +For a lower-level transaction API you can trigger it manually + +```python +transaction = await database.transaction() +try: + await transaction.start() + ... +except: + await transaction.rollback() +else: + await transaction.commit() +``` + + +## Testing + +Transactions can also be useful during testing when you can apply force rollback +and you do not have to clean the data after each test. + +```python +@pytest.mark.asyncio +async def sample_test(): + async with database: + async with database.transaction(force_rollback=True): + # your test code here + ... +``` \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 5432018..65d5215 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -31,6 +31,7 @@ nav: - queries/pagination-and-rows-number.md - queries/aggregations.md - Signals: signals.md + - Transactions: transactions.md - Use with Fastapi: fastapi.md - Use with mypy: mypy.md - PyCharm plugin: plugin.md diff --git a/ormar/models/mixins/excludable_mixin.py b/ormar/models/mixins/excludable_mixin.py index 3a2bb04..cbe4c25 100644 --- a/ormar/models/mixins/excludable_mixin.py +++ b/ormar/models/mixins/excludable_mixin.py @@ -14,7 +14,6 @@ from typing import ( from ormar.models.excludable import ExcludableItems from ormar.models.mixins.relation_mixin import RelationMixin -from ormar.queryset.utils import translate_list_to_dict, update if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -138,9 +137,7 @@ class ExcludableMixin(RelationMixin): return columns @classmethod - def _update_excluded_with_related( - cls, exclude: Union["AbstractSetIntStr", "MappingIntStrAny", None], - ) -> Union[Set, Dict]: + def _update_excluded_with_related(cls, exclude: Union[Set, Dict, None],) -> Set: """ Used during generation of the dict(). To avoid cyclical references and max recurrence limit nested models have to @@ -151,8 +148,6 @@ class ExcludableMixin(RelationMixin): :param exclude: set/dict with fields to exclude :type exclude: Union[Set, Dict, None] - :param nested: flag setting nested models (child of previous one, not main one) - :type nested: bool :return: set or dict with excluded fields added. :rtype: Union[Set, Dict] """ @@ -160,10 +155,11 @@ class ExcludableMixin(RelationMixin): related_set = cls.extract_related_names() if isinstance(exclude, set): exclude = {s for s in exclude} - exclude.union(related_set) - else: - related_dict = translate_list_to_dict(related_set) - exclude = update(related_dict, exclude) + exclude = exclude.union(related_set) + elif isinstance(exclude, dict): + # relations are handled in ormar - take only own fields (ellipsis in dict) + exclude = {k for k, v in exclude.items() if v is Ellipsis} + exclude = exclude.union(related_set) return exclude @classmethod diff --git a/ormar/models/mixins/save_mixin.py b/ormar/models/mixins/save_mixin.py index 5ff29ad..a7ac562 100644 --- a/ormar/models/mixins/save_mixin.py +++ b/ormar/models/mixins/save_mixin.py @@ -222,9 +222,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin): @staticmethod async def _upsert_through_model( - instance: "Model", - previous_model: "Model", - relation_field: "ForeignKeyField", + instance: "Model", previous_model: "Model", relation_field: "ForeignKeyField", ) -> None: """ Upsert through model for m2m relation. diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index af69e38..4e16fd0 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -514,7 +514,11 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass fields = [ field for field in fields - if field not in exclude or exclude.get(field) is not Ellipsis + if field not in exclude + or ( + exclude.get(field) is not Ellipsis + and exclude.get(field) != {"__all__"} + ) ] return fields @@ -567,6 +571,18 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass result = self.get_child(items, key) return result if result is not Ellipsis else default_return + def _convert_all(self, items: Union[Set, Dict, None]) -> Union[Set, Dict, None]: + """ + Helper to convert __all__ pydantic special index to ormar which does not + support index based exclusions. + + :param items: current include/exclude value + :type items: Union[Set, Dict, None] + """ + if isinstance(items, dict) and "__all__" in items: + return items.get("__all__") + return items + def _extract_nested_models( # noqa: CCR001 self, relation_map: Dict, @@ -603,8 +619,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass relation_map, field, default_return=dict() ), models=nested_model, - include=self._skip_ellipsis(include, field), - exclude=self._skip_ellipsis(exclude, field), + include=self._convert_all(self._skip_ellipsis(include, field)), + exclude=self._convert_all(self._skip_ellipsis(exclude, field)), ) elif nested_model is not None: @@ -612,8 +628,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass relation_map=self._skip_ellipsis( relation_map, field, default_return=dict() ), - include=self._skip_ellipsis(include, field), - exclude=self._skip_ellipsis(exclude, field), + include=self._convert_all(self._skip_ellipsis(include, field)), + exclude=self._convert_all(self._skip_ellipsis(exclude, field)), ) else: dict_instance[field] = None diff --git a/tests/test_fastapi/test_nested_saving.py b/tests/test_fastapi/test_nested_saving.py index bb388e7..c4d817b 100644 --- a/tests/test_fastapi/test_nested_saving.py +++ b/tests/test_fastapi/test_nested_saving.py @@ -1,5 +1,5 @@ import json -from typing import List, Optional +from typing import Optional import databases import pytest @@ -50,6 +50,16 @@ class Course(ormar.Model): department: Optional[Department] = ormar.ForeignKey(Department) +class Student(ormar.Model): + class Meta: + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + courses = ormar.ManyToMany(Course) + + # create db and tables @pytest.fixture(autouse=True, scope="module") def create_test_database(): @@ -59,19 +69,49 @@ def create_test_database(): metadata.drop_all(engine) -@app.post("/DepartmentWithCourses/", response_model=Department) +to_exclude = { + "id": ..., + "courses": { + "__all__": {"id": ..., "students": {"__all__": {"id", "studentcourse"}}} + }, +} + +exclude_all = {"id": ..., "courses": {"__all__"}} + +to_exclude_ormar = { + "id": ..., + "courses": {"id": ..., "students": {"id", "studentcourse"}}, +} + + +@app.post("/departments/", response_model=Department) async def create_department(department: Department): - # there is no save all - you need to split into save and save_related - await department.save() await department.save_related(follow=True, save_all=True) return department -@app.get("/DepartmentsAll/", response_model=List[Department]) -async def get_Courses(): - # if you don't provide default name it related model name + s so courses not course - departmentall = await Department.objects.select_related("courses").all() - return departmentall +@app.get("/departments/{department_name}") +async def get_department(department_name: str): + department = await Department.objects.select_all(follow=True).get( + department_name=department_name + ) + return department.dict(exclude=to_exclude) + + +@app.get("/departments/{department_name}/second") +async def get_department_exclude(department_name: str): + department = await Department.objects.select_all(follow=True).get( + department_name=department_name + ) + return department.dict(exclude=to_exclude_ormar) + + +@app.get("/departments/{department_name}/exclude") +async def get_department_exclude_all(department_name: str): + department = await Department.objects.select_all(follow=True).get( + department_name=department_name + ) + return department.dict(exclude=exclude_all) def test_saving_related_in_fastapi(): @@ -80,11 +120,19 @@ def test_saving_related_in_fastapi(): payload = { "department_name": "Ormar", "courses": [ - {"course_name": "basic1", "completed": True}, - {"course_name": "basic2", "completed": True}, + { + "course_name": "basic1", + "completed": True, + "students": [{"name": "Jack"}, {"name": "Abi"}], + }, + { + "course_name": "basic2", + "completed": True, + "students": [{"name": "Kate"}, {"name": "Miranda"}], + }, ], } - response = client.post("/DepartmentWithCourses/", data=json.dumps(payload)) + response = client.post("/departments/", data=json.dumps(payload)) department = Department(**response.json()) assert department.id is not None @@ -95,12 +143,9 @@ def test_saving_related_in_fastapi(): assert department.courses[1].course_name == "basic2" assert department.courses[1].completed - response = client.get("/DepartmentsAll/") - departments = [Department(**x) for x in response.json()] - assert departments[0].id is not None - assert len(departments[0].courses) == 2 - assert departments[0].department_name == "Ormar" - assert departments[0].courses[0].course_name == "basic1" - assert departments[0].courses[0].completed - assert departments[0].courses[1].course_name == "basic2" - assert departments[0].courses[1].completed + response = client.get("/departments/Ormar") + response2 = client.get("/departments/Ormar/second") + assert response.json() == response2.json() == payload + + response3 = client.get("/departments/Ormar/exclude") + assert response3.json() == {"department_name": "Ormar"} From 15e12ef55b73cd33f9cc194cfa4b8d4b6d91bcc5 Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 16 Apr 2021 16:27:07 +0200 Subject: [PATCH 6/6] allow customization of through model relation names --- docs/relations/many-to-many.md | 67 ++++++++++++++- docs/releases.md | 54 ++++++++++++ ormar/fields/base.py | 6 ++ ormar/fields/foreign_key.py | 18 ++-- ormar/fields/many_to_many.py | 5 ++ ormar/models/helpers/relations.py | 2 + ...ustomizing_through_model_relation_names.py | 84 +++++++++++++++++++ 7 files changed, 223 insertions(+), 13 deletions(-) create mode 100644 tests/test_relations/test_customizing_through_model_relation_names.py diff --git a/docs/relations/many-to-many.md b/docs/relations/many-to-many.md index 414f0df..be48cdb 100644 --- a/docs/relations/many-to-many.md +++ b/docs/relations/many-to-many.md @@ -161,7 +161,72 @@ The default naming convention is: it would be `PostCategory` * for table name it similar but with underscore in between and s in the end of class lowercase name, in example above would be `posts_categorys` - + +### Customizing Through relation names + +By default `Through` model relation names default to related model name in lowercase. + +So in example like this: +```python +... # course declaration ommited +class Student(ormar.Model): + class Meta: + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + courses = ormar.ManyToMany(Course) + +# will produce default Through model like follows (example simplified) +class StudentCourse(ormar.Model): + class Meta: + database = database + metadata = metadata + tablename = "students_courses" + + id: int = ormar.Integer(primary_key=True) + student = ormar.ForeignKey(Student) # default name + course = ormar.ForeignKey(Course) # default name +``` + +To customize the names of fields/relation in Through model now you can use new parameters to `ManyToMany`: + +* `through_relation_name` - name of the field leading to the model in which `ManyToMany` is declared +* `through_reverse_relation_name` - name of the field leading to the model to which `ManyToMany` leads to + +Example: + +```python +... # course declaration ommited +class Student(ormar.Model): + class Meta: + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + courses = ormar.ManyToMany(Course, + through_relation_name="student_id", + through_reverse_relation_name="course_id") + +# will produce Through model like follows (example simplified) +class StudentCourse(ormar.Model): + class Meta: + database = database + metadata = metadata + tablename = "students_courses" + + id: int = ormar.Integer(primary_key=True) + student_id = ormar.ForeignKey(Student) # set by through_relation_name + course_id = ormar.ForeignKey(Course) # set by through_reverse_relation_name +``` + +!!!note + Note that explicitly declaring relations in Through model is forbidden, so even if you + provide your own custom Through model you cannot change the names there and you need to use + same `through_relation_name` and `through_reverse_relation_name` parameters. + ## Through Fields The through field is auto added to the reverse side of the relation. diff --git a/docs/releases.md b/docs/releases.md index 667f170..f2328cc 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -24,6 +24,60 @@ in those cases you don't have to split save into two calls (`save()` and `save_related()`) * it supports also `ManyToMany` relations * it supports also optional `Through` model values for m2m relations +* Add possibility to customize `Through` model relation field names. + * By default `Through` model relation names default to related model name in lowercase. + So in example like this: + ```python + ... # course declaration ommited + class Student(ormar.Model): + class Meta: + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + courses = ormar.ManyToMany(Course) + + # will produce default Through model like follows (example simplified) + class StudentCourse(ormar.Model): + class Meta: + database = database + metadata = metadata + tablename = "students_courses" + + id: int = ormar.Integer(primary_key=True) + student = ormar.ForeignKey(Student) # default name + course = ormar.ForeignKey(Course) # default name + ``` + * To customize the names of fields/relation in Through model now you can use new parameters to `ManyToMany`: + * `through_relation_name` - name of the field leading to the model in which `ManyToMany` is declared + * `through_reverse_relation_name` - name of the field leading to the model to which `ManyToMany` leads to + + Example: + ```python + ... # course declaration ommited + class Student(ormar.Model): + class Meta: + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + courses = ormar.ManyToMany(Course, + through_relation_name="student_id", + through_reverse_relation_name="course_id") + + # will produce default Through model like follows (example simplified) + class StudentCourse(ormar.Model): + class Meta: + database = database + metadata = metadata + tablename = "students_courses" + + id: int = ormar.Integer(primary_key=True) + student_id = ormar.ForeignKey(Student) # set by through_relation_name + course_id = ormar.ForeignKey(Course) # set by through_reverse_relation_name + ``` ## 🐛 Fixes diff --git a/ormar/fields/base.py b/ormar/fields/base.py index a86a500..c435ac6 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -53,6 +53,12 @@ class BaseField(FieldInfo): "is_relation", None ) # ForeignKeyField + subclasses self.is_through: bool = kwargs.pop("is_through", False) # ThroughFields + + self.through_relation_name = kwargs.pop("through_relation_name", None) + self.through_reverse_relation_name = kwargs.pop( + "through_reverse_relation_name", None + ) + self.skip_reverse: bool = kwargs.pop("skip_reverse", False) self.skip_field: bool = kwargs.pop("skip_field", False) diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index fe7c812..feba37c 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -318,29 +318,23 @@ class ForeignKeyField(BaseField): """ return self.related_name or self.owner.get_name() + "s" - def default_target_field_name(self, reverse: bool = False) -> str: + def default_target_field_name(self) -> str: """ Returns default target model name on through model. - :param reverse: flag to grab name without accessing related field - :type reverse: bool :return: name of the field :rtype: str """ - self_rel_prefix = "from_" if not reverse else "to_" - prefix = self_rel_prefix if self.self_reference else "" - return f"{prefix}{self.to.get_name()}" + prefix = "from_" if self.self_reference else "" + return self.through_reverse_relation_name or f"{prefix}{self.to.get_name()}" - def default_source_field_name(self, reverse: bool = False) -> str: + def default_source_field_name(self) -> str: """ Returns default target model name on through model. - :param reverse: flag to grab name without accessing related field - :type reverse: bool :return: name of the field :rtype: str """ - self_rel_prefix = "to_" if not reverse else "from_" - prefix = self_rel_prefix if self.self_reference else "" - return f"{prefix}{self.owner.get_name()}" + prefix = "to_" if self.self_reference else "" + return self.through_relation_name or f"{prefix}{self.owner.get_name()}" def evaluate_forward_ref(self, globalns: Any, localns: Any) -> None: """ diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py index a70f623..5f98fd9 100644 --- a/ormar/fields/many_to_many.py +++ b/ormar/fields/many_to_many.py @@ -122,6 +122,9 @@ def ManyToMany( skip_reverse = kwargs.pop("skip_reverse", False) skip_field = kwargs.pop("skip_field", False) + through_relation_name = kwargs.pop("through_relation_name", None) + through_reverse_relation_name = kwargs.pop("through_reverse_relation_name", None) + if through is not None and through.__class__ != ForwardRef: forbid_through_relations(cast(Type["Model"], through)) @@ -158,6 +161,8 @@ def ManyToMany( related_orders_by=related_orders_by, skip_reverse=skip_reverse, skip_field=skip_field, + through_relation_name=through_relation_name, + through_reverse_relation_name=through_reverse_relation_name, ) Field = type("ManyToMany", (ManyToManyField, BaseField), {}) diff --git a/ormar/models/helpers/relations.py b/ormar/models/helpers/relations.py index 39e74e3..558c0b3 100644 --- a/ormar/models/helpers/relations.py +++ b/ormar/models/helpers/relations.py @@ -112,6 +112,8 @@ def register_reverse_model_fields(model_field: "ForeignKeyField") -> None: self_reference_primary=model_field.self_reference_primary, orders_by=model_field.related_orders_by, skip_field=model_field.skip_reverse, + through_relation_name=model_field.through_reverse_relation_name, + through_reverse_relation_name=model_field.through_relation_name, ) # register foreign keys on through model model_field = cast("ManyToManyField", model_field) diff --git a/tests/test_relations/test_customizing_through_model_relation_names.py b/tests/test_relations/test_customizing_through_model_relation_names.py new file mode 100644 index 0000000..ff99065 --- /dev/null +++ b/tests/test_relations/test_customizing_through_model_relation_names.py @@ -0,0 +1,84 @@ +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +metadata = sqlalchemy.MetaData() +database = databases.Database(DATABASE_URL, force_rollback=True) + + +class Course(ormar.Model): + class Meta: + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + course_name: str = ormar.String(max_length=100) + + +class Student(ormar.Model): + class Meta: + database = database + metadata = metadata + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + courses = ormar.ManyToMany( + Course, + through_relation_name="student_id", + through_reverse_relation_name="course_id", + ) + + +# create db and tables +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_tables_columns(): + through_meta = Student.Meta.model_fields["courses"].through.Meta + assert "course_id" in through_meta.table.c + assert "student_id" in through_meta.table.c + assert "course_id" in through_meta.model_fields + assert "student_id" in through_meta.model_fields + + +@pytest.mark.asyncio +async def test_working_with_changed_through_names(): + async with database: + async with database.transaction(force_rollback=True): + to_save = { + "course_name": "basic1", + "students": [{"name": "Jack"}, {"name": "Abi"}], + } + await Course(**to_save).save_related(follow=True, save_all=True) + course_check = await Course.objects.select_related("students").get() + + assert course_check.course_name == "basic1" + assert course_check.students[0].name == "Jack" + assert course_check.students[1].name == "Abi" + + students = await course_check.students.all() + assert len(students) == 2 + + student = await course_check.students.get(name="Jack") + assert student.name == "Jack" + + students = await Student.objects.select_related("courses").all( + courses__course_name="basic1" + ) + assert len(students) == 2 + + course_check = ( + await Course.objects.select_related("students") + .order_by("students__name") + .get() + ) + assert course_check.students[0].name == "Abi" + assert course_check.students[1].name == "Jack"