Merge pull request #150 from collerek/relations_params

One sided relations and more powerful save_related
This commit is contained in:
collerek
2021-04-16 16:39:21 +02:00
committed by GitHub
30 changed files with 1825 additions and 147 deletions

View File

@ -47,15 +47,35 @@ since they actually have to create and connect to database in most of the tests.
Yet remember that those are - well - tests and not all solutions are suitable to be used in real life applications.
### Part of the `fastapi` ecosystem
As part of the fastapi ecosystem `ormar` is supported in libraries that somehow work with databases.
As of now `ormar` is supported by:
* [`fastapi-users`](https://github.com/frankie567/fastapi-users)
* [`fastapi-crudrouter`](https://github.com/awtkns/fastapi-crudrouter)
* [`fastapi-pagination`](https://github.com/uriyyo/fastapi-pagination)
If you maintain or use different library and would like it to support `ormar` let us know how we can help.
### Dependencies
Ormar is built with:
* [`SQLAlchemy core`][sqlalchemy-core] for query building.
* [`sqlalchemy core`][sqlalchemy-core] for query building.
* [`databases`][databases] for cross-database async support.
* [`pydantic`][pydantic] for data validation.
* `typing_extensions` for python 3.6 - 3.7
### Migrating from `sqlalchemy`
If you currently use `sqlalchemy` and would like to switch to `ormar` check out the auto-translation
tool that can help you with translating existing sqlalchemy orm models so you do not have to do it manually.
**Beta** versions available at github: [`sqlalchemy-to-ormar`](https://github.com/collerek/sqlalchemy-to-ormar)
or simply `pip install sqlalchemy-to-ormar`
### Migrations & Database creation
Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide

View File

@ -47,15 +47,35 @@ since they actually have to create and connect to database in most of the tests.
Yet remember that those are - well - tests and not all solutions are suitable to be used in real life applications.
### Part of the `fastapi` ecosystem
As part of the fastapi ecosystem `ormar` is supported in libraries that somehow work with databases.
As of now `ormar` is supported by:
* [`fastapi-users`](https://github.com/frankie567/fastapi-users)
* [`fastapi-crudrouter`](https://github.com/awtkns/fastapi-crudrouter)
* [`fastapi-pagination`](https://github.com/uriyyo/fastapi-pagination)
If you maintain or use different library and would like it to support `ormar` let us know how we can help.
### Dependencies
Ormar is built with:
* [`SQLAlchemy core`][sqlalchemy-core] for query building.
* [`sqlalchemy core`][sqlalchemy-core] for query building.
* [`databases`][databases] for cross-database async support.
* [`pydantic`][pydantic] for data validation.
* `typing_extensions` for python 3.6 - 3.7
### Migrating from `sqlalchemy`
If you currently use `sqlalchemy` and would like to switch to `ormar` check out the auto-translation
tool that can help you with translating existing sqlalchemy orm models so you do not have to do it manually.
**Beta** versions available at github: [`sqlalchemy-to-ormar`](https://github.com/collerek/sqlalchemy-to-ormar)
or simply `pip install sqlalchemy-to-ormar`
### Migrations & Database creation
Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide

View File

@ -198,10 +198,88 @@ or it can be a dictionary that can also contain nested items.
To read more about the structure of possible values passed to `exclude` check `Queryset.fields` method documentation.
!!!warning
To avoid circular updates with `follow=True` set, `save_related` keeps a set of already visited Models,
To avoid circular updates with `follow=True` set, `save_related` keeps a set of already visited Models on each branch of relation tree,
and won't perform nested `save_related` on Models that were already visited.
So if you have a diamond or circular relations types you need to perform the updates in a manual way.
So if you have circular relations types you need to perform the updates in a manual way.
Note that with `save_all=True` and `follow=True` you can use `save_related()` to save whole relation tree at once.
Example:
```python
class Department(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
department_name: str = ormar.String(max_length=100)
class Course(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
course_name: str = ormar.String(max_length=100)
completed: bool = ormar.Boolean()
department: Optional[Department] = ormar.ForeignKey(Department)
class Student(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
courses = ormar.ManyToMany(Course)
to_save = {
"department_name": "Ormar",
"courses": [
{"course_name": "basic1",
"completed": True,
"students": [
{"name": "Jack"},
{"name": "Abi"}
]},
{"course_name": "basic2",
"completed": True,
"students": [
{"name": "Kate"},
{"name": "Miranda"}
]
},
],
}
# initializa whole tree
department = Department(**to_save)
# save all at once (one after another)
await department.save_related(follow=True, save_all=True)
department_check = await Department.objects.select_all(follow=True).get()
to_exclude = {
"id": ...,
"courses": {
"id": ...,
"students": {"id", "studentcourse"}
}
}
# after excluding ids and through models you get exact same payload used to
# construct whole tree
assert department_check.dict(exclude=to_exclude) == to_save
```
!!!warning
`save_related()` iterates all relations and all models and upserts() them one by one,
so it will save all models but might not be optimal in regard of number of database queries.
[fields]: ../fields.md
[relations]: ../relations/index.md

View File

@ -27,6 +27,66 @@ By default it's child (source) `Model` name + s, like courses in snippet below:
Reverse relation exposes API to manage related objects also from parent side.
### Skipping reverse relation
If you are sure you don't want the reverse relation you can use `skip_reverse=True`
flag of the `ForeignKey`.
If you set `skip_reverse` flag internally the field is still registered on the other
side of the relationship so you can:
* `filter` by related models fields from reverse model
* `order_by` by related models fields from reverse model
But you cannot:
* access the related field from reverse model with `related_name`
* even if you `select_related` from reverse side of the model the returned models won't be populated in reversed instance (the join is not prevented so you still can `filter` and `order_by` over the relation)
* the relation won't be populated in `dict()` and `json()`
* you cannot pass the nested related objects when populating from dictionary or json (also through `fastapi`). It will be either ignored or error will be raised depending on `extra` setting in pydantic `Config`.
Example:
```python
class Author(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
first_name: str = ormar.String(max_length=80)
last_name: str = ormar.String(max_length=80)
class Post(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=200)
author: Optional[Author] = ormar.ForeignKey(Author, skip_reverse=True)
# create sample data
author = Author(first_name="Test", last_name="Author")
post = Post(title="Test Post", author=author)
assert post.author == author # ok
assert author.posts # Attribute error!
# but still can use in order_by
authors = (
await Author.objects.select_related("posts").order_by("posts__title").all()
)
assert authors[0].first_name == "Test"
# note that posts are not populated for author even if explicitly
# included in select_related - note no posts in dict()
assert author.dict(exclude={"id"}) == {"first_name": "Test", "last_name": "Author"}
# still can filter through fields of related model
authors = await Author.objects.filter(posts__title="Test Post").all()
assert authors[0].first_name == "Test"
assert len(authors) == 1
```
### add
Adding child model from parent side causes adding related model to currently loaded parent relation,

View File

@ -20,6 +20,122 @@ post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
```
## Reverse relation
`ForeignKey` fields are automatically registering reverse side of the relation.
By default it's child (source) `Model` name + s, like courses in snippet below:
```python
class Category(ormar.Model):
class Meta(BaseMeta):
tablename = "categories"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=40)
class Post(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=200)
categories: Optional[List[Category]] = ormar.ManyToMany(Category)
# create some sample data
post = await Post.objects.create(title="Hello, M2M")
news = await Category.objects.create(name="News")
await post.categories.add(news)
# now you can query and access from both sides:
post_check = Post.objects.select_related("categories").get()
assert post_check.categories[0] == news
# query through auto registered reverse side
category_check = Category.objects.select_related("posts").get()
assert category_check.posts[0] == post
```
Reverse relation exposes API to manage related objects also from parent side.
### related_name
By default, the related_name is generated in the same way as for the `ForeignKey` relation (class.name.lower()+'s'),
but in the same way you can overwrite this name by providing `related_name` parameter like below:
```Python
categories: Optional[Union[Category, List[Category]]] = ormar.ManyToMany(
Category, through=PostCategory, related_name="new_categories"
)
```
!!!warning
When you provide multiple relations to the same model `ormar` can no longer auto generate
the `related_name` for you. Therefore, in that situation you **have to** provide `related_name`
for all but one (one can be default and generated) or all related fields.
### Skipping reverse relation
If you are sure you don't want the reverse relation you can use `skip_reverse=True`
flag of the `ManyToMany`.
If you set `skip_reverse` flag internally the field is still registered on the other
side of the relationship so you can:
* `filter` by related models fields from reverse model
* `order_by` by related models fields from reverse model
But you cannot:
* access the related field from reverse model with `related_name`
* even if you `select_related` from reverse side of the model the returned models won't be populated in reversed instance (the join is not prevented so you still can `filter` and `order_by` over the relation)
* the relation won't be populated in `dict()` and `json()`
* you cannot pass the nested related objects when populating from dictionary or json (also through `fastapi`). It will be either ignored or error will be raised depending on `extra` setting in pydantic `Config`.
Example:
```python
class Category(ormar.Model):
class Meta(BaseMeta):
tablename = "categories"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=40)
class Post(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=200)
categories: Optional[List[Category]] = ormar.ManyToMany(Category, skip_reverse=True)
# create some sample data
post = await Post.objects.create(title="Hello, M2M")
news = await Category.objects.create(name="News")
await post.categories.add(news)
assert post.categories[0] == news # ok
assert news.posts # Attribute error!
# but still can use in order_by
categories = (
await Category.objects.select_related("posts").order_by("posts__title").all()
)
assert categories[0].first_name == "Test"
# note that posts are not populated for author even if explicitly
# included in select_related - note no posts in dict()
assert news.dict(exclude={"id"}) == {"name": "News"}
# still can filter through fields of related model
categories = await Category.objects.filter(posts__title="Hello, M2M").all()
assert categories[0].name == "News"
assert len(categories) == 1
```
## Through Model
Optionally if you want to add additional fields you can explicitly create and pass
@ -46,6 +162,71 @@ The default naming convention is:
* for table name it similar but with underscore in between and s in the end of class
lowercase name, in example above would be `posts_categorys`
### Customizing Through relation names
By default `Through` model relation names default to related model name in lowercase.
So in example like this:
```python
... # course declaration ommited
class Student(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
courses = ormar.ManyToMany(Course)
# will produce default Through model like follows (example simplified)
class StudentCourse(ormar.Model):
class Meta:
database = database
metadata = metadata
tablename = "students_courses"
id: int = ormar.Integer(primary_key=True)
student = ormar.ForeignKey(Student) # default name
course = ormar.ForeignKey(Course) # default name
```
To customize the names of fields/relation in Through model now you can use new parameters to `ManyToMany`:
* `through_relation_name` - name of the field leading to the model in which `ManyToMany` is declared
* `through_reverse_relation_name` - name of the field leading to the model to which `ManyToMany` leads to
Example:
```python
... # course declaration ommited
class Student(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
courses = ormar.ManyToMany(Course,
through_relation_name="student_id",
through_reverse_relation_name="course_id")
# will produce Through model like follows (example simplified)
class StudentCourse(ormar.Model):
class Meta:
database = database
metadata = metadata
tablename = "students_courses"
id: int = ormar.Integer(primary_key=True)
student_id = ormar.ForeignKey(Student) # set by through_relation_name
course_id = ormar.ForeignKey(Course) # set by through_reverse_relation_name
```
!!!note
Note that explicitly declaring relations in Through model is forbidden, so even if you
provide your own custom Through model you cannot change the names there and you need to use
same `through_relation_name` and `through_reverse_relation_name` parameters.
## Through Fields
The through field is auto added to the reverse side of the relation.
@ -220,22 +401,6 @@ Reverse relation exposes QuerysetProxy API that allows you to query related mode
To read which methods of QuerySet are available read below [querysetproxy][querysetproxy]
## related_name
By default, the related_name is generated in the same way as for the `ForeignKey` relation (class.name.lower()+'s'),
but in the same way you can overwrite this name by providing `related_name` parameter like below:
```Python
categories: Optional[Union[Category, List[Category]]] = ormar.ManyToMany(
Category, through=PostCategory, related_name="new_categories"
)
```
!!!warning
When you provide multiple relations to the same model `ormar` can no longer auto generate
the `related_name` for you. Therefore, in that situation you **have to** provide `related_name`
for all but one (one can be default and generated) or all related fields.
[queries]: ./queries.md
[querysetproxy]: ./queryset-proxy.md

View File

@ -1,3 +1,97 @@
# 0.10.3
## ✨ Features
* `ForeignKey` and `ManyToMany` now support `skip_reverse: bool = False` flag [#118](https://github.com/collerek/ormar/issues/118).
If you set `skip_reverse` flag internally the field is still registered on the other
side of the relationship so you can:
* `filter` by related models fields from reverse model
* `order_by` by related models fields from reverse model
But you cannot:
* access the related field from reverse model with `related_name`
* even if you `select_related` from reverse side of the model the returned models won't be populated in reversed instance (the join is not prevented so you still can `filter` and `order_by`)
* the relation won't be populated in `dict()` and `json()`
* you cannot pass the nested related objects when populating from `dict()` or `json()` (also through `fastapi`). It will be either ignored or raise error depending on `extra` setting in pydantic `Config`.
* `Model.save_related()` now can save whole data tree in once [#148](https://github.com/collerek/ormar/discussions/148)
meaning:
* it knows if it should save main `Model` or related `Model` first to preserve the relation
* it saves main `Model` if
* it's not `saved`,
* has no `pk` value
* or `save_all=True` flag is set
in those cases you don't have to split save into two calls (`save()` and `save_related()`)
* it supports also `ManyToMany` relations
* it supports also optional `Through` model values for m2m relations
* Add possibility to customize `Through` model relation field names.
* By default `Through` model relation names default to related model name in lowercase.
So in example like this:
```python
... # course declaration ommited
class Student(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
courses = ormar.ManyToMany(Course)
# will produce default Through model like follows (example simplified)
class StudentCourse(ormar.Model):
class Meta:
database = database
metadata = metadata
tablename = "students_courses"
id: int = ormar.Integer(primary_key=True)
student = ormar.ForeignKey(Student) # default name
course = ormar.ForeignKey(Course) # default name
```
* To customize the names of fields/relation in Through model now you can use new parameters to `ManyToMany`:
* `through_relation_name` - name of the field leading to the model in which `ManyToMany` is declared
* `through_reverse_relation_name` - name of the field leading to the model to which `ManyToMany` leads to
Example:
```python
... # course declaration ommited
class Student(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
courses = ormar.ManyToMany(Course,
through_relation_name="student_id",
through_reverse_relation_name="course_id")
# will produce default Through model like follows (example simplified)
class StudentCourse(ormar.Model):
class Meta:
database = database
metadata = metadata
tablename = "students_courses"
id: int = ormar.Integer(primary_key=True)
student_id = ormar.ForeignKey(Student) # set by through_relation_name
course_id = ormar.ForeignKey(Course) # set by through_reverse_relation_name
```
## 🐛 Fixes
* Fix weakref `ReferenceError` error [#118](https://github.com/collerek/ormar/issues/118)
* Fix error raised by Through fields when pydantic `Config.extra="forbid"` is set
* Fix bug with `pydantic.PrivateAttr` not being initialized at `__init__` [#149](https://github.com/collerek/ormar/issues/149)
* Fix bug with pydantic-type `exclude` in `dict()` with `__all__` key not working
## 💬 Other
* Introduce link to `sqlalchemy-to-ormar` auto-translator for models
* Provide links to fastapi ecosystem libraries that support `ormar`
* Add transactions to docs (supported with `databases`)
# 0.10.2
## ✨ Features

88
docs/transactions.md Normal file
View File

@ -0,0 +1,88 @@
# Transactions
Database transactions are supported thanks to `encode/databases` which is used to issue async queries.
## Basic usage
To use transactions use `database.transaction` as async context manager:
```python
async with database.transaction():
# everyting called here will be one transaction
await Model1().save()
await Model2().save()
...
```
!!!note
Note that it has to be the same `database` that the one used in Model's `Meta` class.
To avoid passing `database` instance around in your code you can extract the instance from each `Model`.
Database provided during declaration of `ormar.Model` is available through `Meta.database` and can
be reached from both class and instance.
```python
import databases
import sqlalchemy
import ormar
metadata = sqlalchemy.MetaData()
database = databases.Database("sqlite:///")
class Author(ormar.Model):
class Meta:
database=database
metadata=metadata
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=255)
# database is accessible from class
database = Author.Meta.database
# as well as from instance
author = Author(name="Stephen King")
database = author.Meta.database
```
You can also use `.transaction()` as a function decorator on any async function:
```python
@database.transaction()
async def create_users(request):
...
```
Transaction blocks are managed as task-local state. Nested transactions
are fully supported, and are implemented using database savepoints.
## Manual commits/ rollbacks
For a lower-level transaction API you can trigger it manually
```python
transaction = await database.transaction()
try:
await transaction.start()
...
except:
await transaction.rollback()
else:
await transaction.commit()
```
## Testing
Transactions can also be useful during testing when you can apply force rollback
and you do not have to clean the data after each test.
```python
@pytest.mark.asyncio
async def sample_test():
async with database:
async with database.transaction(force_rollback=True):
# your test code here
...
```

View File

@ -31,6 +31,7 @@ nav:
- queries/pagination-and-rows-number.md
- queries/aggregations.md
- Signals: signals.md
- Transactions: transactions.md
- Use with Fastapi: fastapi.md
- Use with mypy: mypy.md
- PyCharm plugin: plugin.md

View File

@ -75,7 +75,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType()
__version__ = "0.10.2"
__version__ = "0.10.3"
__all__ = [
"Integer",
"BigInteger",

View File

@ -54,6 +54,14 @@ class BaseField(FieldInfo):
) # ForeignKeyField + subclasses
self.is_through: bool = kwargs.pop("is_through", False) # ThroughFields
self.through_relation_name = kwargs.pop("through_relation_name", None)
self.through_reverse_relation_name = kwargs.pop(
"through_reverse_relation_name", None
)
self.skip_reverse: bool = kwargs.pop("skip_reverse", False)
self.skip_field: bool = kwargs.pop("skip_field", False)
self.owner: Type["Model"] = kwargs.pop("owner", None)
self.to: Type["Model"] = kwargs.pop("to", None)
self.through: Type["Model"] = kwargs.pop("through", None)

View File

@ -233,9 +233,13 @@ def ForeignKey( # noqa CFQ002
owner = kwargs.pop("owner", None)
self_reference = kwargs.pop("self_reference", False)
orders_by = kwargs.pop("orders_by", None)
related_orders_by = kwargs.pop("related_orders_by", None)
skip_reverse = kwargs.pop("skip_reverse", False)
skip_field = kwargs.pop("skip_field", False)
validate_not_allowed_fields(kwargs)
if to.__class__ == ForwardRef:
@ -274,6 +278,8 @@ def ForeignKey( # noqa CFQ002
is_relation=True,
orders_by=orders_by,
related_orders_by=related_orders_by,
skip_reverse=skip_reverse,
skip_field=skip_field,
)
Field = type("ForeignKey", (ForeignKeyField, BaseField), {})
@ -312,6 +318,24 @@ class ForeignKeyField(BaseField):
"""
return self.related_name or self.owner.get_name() + "s"
def default_target_field_name(self) -> str:
"""
Returns default target model name on through model.
:return: name of the field
:rtype: str
"""
prefix = "from_" if self.self_reference else ""
return self.through_reverse_relation_name or f"{prefix}{self.to.get_name()}"
def default_source_field_name(self) -> str:
"""
Returns default target model name on through model.
:return: name of the field
:rtype: str
"""
prefix = "to_" if self.self_reference else ""
return self.through_relation_name or f"{prefix}{self.owner.get_name()}"
def evaluate_forward_ref(self, globalns: Any, localns: Any) -> None:
"""
Evaluates the ForwardRef to actual Field based on global and local namespaces

View File

@ -112,11 +112,19 @@ def ManyToMany(
"""
related_name = kwargs.pop("related_name", None)
nullable = kwargs.pop("nullable", True)
owner = kwargs.pop("owner", None)
self_reference = kwargs.pop("self_reference", False)
orders_by = kwargs.pop("orders_by", None)
related_orders_by = kwargs.pop("related_orders_by", None)
skip_reverse = kwargs.pop("skip_reverse", False)
skip_field = kwargs.pop("skip_field", False)
through_relation_name = kwargs.pop("through_relation_name", None)
through_reverse_relation_name = kwargs.pop("through_reverse_relation_name", None)
if through is not None and through.__class__ != ForwardRef:
forbid_through_relations(cast(Type["Model"], through))
@ -151,6 +159,10 @@ def ManyToMany(
is_multi=True,
orders_by=orders_by,
related_orders_by=related_orders_by,
skip_reverse=skip_reverse,
skip_field=skip_field,
through_relation_name=through_relation_name,
through_reverse_relation_name=through_reverse_relation_name,
)
Field = type("ManyToMany", (ManyToManyField, BaseField), {})
@ -184,24 +196,6 @@ class ManyToManyField(ForeignKeyField, ormar.QuerySetProtocol, ormar.RelationPro
or self.name
)
def default_target_field_name(self) -> str:
"""
Returns default target model name on through model.
:return: name of the field
:rtype: str
"""
prefix = "from_" if self.self_reference else ""
return f"{prefix}{self.to.get_name()}"
def default_source_field_name(self) -> str:
"""
Returns default target model name on through model.
:return: name of the field
:rtype: str
"""
prefix = "to_" if self.self_reference else ""
return f"{prefix}{self.owner.get_name()}"
def has_unresolved_forward_refs(self) -> bool:
"""
Verifies if the filed has any ForwardRefs that require updating before the

View File

@ -111,6 +111,9 @@ def register_reverse_model_fields(model_field: "ForeignKeyField") -> None:
self_reference=model_field.self_reference,
self_reference_primary=model_field.self_reference_primary,
orders_by=model_field.related_orders_by,
skip_field=model_field.skip_reverse,
through_relation_name=model_field.through_reverse_relation_name,
through_reverse_relation_name=model_field.through_relation_name,
)
# register foreign keys on through model
model_field = cast("ManyToManyField", model_field)
@ -125,6 +128,7 @@ def register_reverse_model_fields(model_field: "ForeignKeyField") -> None:
owner=model_field.to,
self_reference=model_field.self_reference,
orders_by=model_field.related_orders_by,
skip_field=model_field.skip_reverse,
)
@ -145,6 +149,7 @@ def register_through_shortcut_fields(model_field: "ManyToManyField") -> None:
virtual=True,
related_name=model_field.name,
owner=model_field.owner,
nullable=True,
)
model_field.to.Meta.model_fields[through_name] = Through(
@ -153,6 +158,7 @@ def register_through_shortcut_fields(model_field: "ManyToManyField") -> None:
virtual=True,
related_name=related_name,
owner=model_field.to,
nullable=True,
)

View File

@ -90,6 +90,7 @@ def add_cached_properties(new_model: Type["Model"]) -> None:
"""
new_model._quick_access_fields = quick_access_set
new_model._related_names = None
new_model._through_names = None
new_model._related_fields = None
new_model._pydantic_fields = {name for name in new_model.__fields__}
new_model._choices_fields = set()
@ -536,6 +537,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
new_model = populate_meta_tablename_columns_and_pk(name, new_model)
populate_meta_sqlalchemy_table_if_required(new_model.Meta)
expand_reverse_relationships(new_model)
# TODO: iterate only related fields
for field in new_model.Meta.model_fields.values():
register_relation_in_alias_manager(field=field)

View File

@ -14,7 +14,6 @@ from typing import (
from ormar.models.excludable import ExcludableItems
from ormar.models.mixins.relation_mixin import RelationMixin
from ormar.queryset.utils import translate_list_to_dict, update
if TYPE_CHECKING: # pragma no cover
from ormar import Model
@ -138,9 +137,7 @@ class ExcludableMixin(RelationMixin):
return columns
@classmethod
def _update_excluded_with_related(
cls, exclude: Union["AbstractSetIntStr", "MappingIntStrAny", None],
) -> Union[Set, Dict]:
def _update_excluded_with_related(cls, exclude: Union[Set, Dict, None],) -> Set:
"""
Used during generation of the dict().
To avoid cyclical references and max recurrence limit nested models have to
@ -151,8 +148,6 @@ class ExcludableMixin(RelationMixin):
:param exclude: set/dict with fields to exclude
:type exclude: Union[Set, Dict, None]
:param nested: flag setting nested models (child of previous one, not main one)
:type nested: bool
:return: set or dict with excluded fields added.
:rtype: Union[Set, Dict]
"""
@ -160,10 +155,11 @@ class ExcludableMixin(RelationMixin):
related_set = cls.extract_related_names()
if isinstance(exclude, set):
exclude = {s for s in exclude}
exclude.union(related_set)
else:
related_dict = translate_list_to_dict(related_set)
exclude = update(related_dict, exclude)
exclude = exclude.union(related_set)
elif isinstance(exclude, dict):
# relations are handled in ormar - take only own fields (ellipsis in dict)
exclude = {k for k, v in exclude.items() if v is Ellipsis}
exclude = exclude.union(related_set)
return exclude
@classmethod

View File

@ -4,9 +4,10 @@ from typing import (
Optional,
Set,
TYPE_CHECKING,
cast,
)
from ormar import BaseField
from ormar import BaseField, ForeignKeyField
from ormar.models.traversible import NodeList
@ -20,6 +21,7 @@ class RelationMixin:
Meta: ModelMeta
_related_names: Optional[Set]
_through_names: Optional[Set]
_related_fields: Optional[List]
get_name: Callable
@ -38,7 +40,7 @@ class RelationMixin:
return self_fields
@classmethod
def extract_related_fields(cls) -> List:
def extract_related_fields(cls) -> List["ForeignKeyField"]:
"""
Returns List of ormar Fields for all relations declared on a model.
List is cached in cls._related_fields for quicker access.
@ -51,25 +53,29 @@ class RelationMixin:
related_fields = []
for name in cls.extract_related_names().union(cls.extract_through_names()):
related_fields.append(cls.Meta.model_fields[name])
related_fields.append(cast("ForeignKeyField", cls.Meta.model_fields[name]))
cls._related_fields = related_fields
return related_fields
@classmethod
def extract_through_names(cls) -> Set:
def extract_through_names(cls) -> Set[str]:
"""
Extracts related fields through names which are shortcuts to through models.
:return: set of related through fields names
:rtype: Set
"""
related_fields = set()
for name in cls.extract_related_names():
field = cls.Meta.model_fields[name]
if field.is_multi:
related_fields.add(field.through.get_name(lower=True))
return related_fields
if isinstance(cls._through_names, Set):
return cls._through_names
related_names = set()
for name, field in cls.Meta.model_fields.items():
if isinstance(field, BaseField) and field.is_through:
related_names.add(name)
cls._through_names = related_names
return related_names
@classmethod
def extract_related_names(cls) -> Set[str]:
@ -89,6 +95,7 @@ class RelationMixin:
isinstance(field, BaseField)
and field.is_relation
and not field.is_through
and not field.skip_field
):
related_names.add(name)
cls._related_names = related_names

View File

@ -1,5 +1,5 @@
import uuid
from typing import Dict, Optional, Set, TYPE_CHECKING
from typing import Callable, Collection, Dict, Optional, Set, TYPE_CHECKING, cast
import ormar
from ormar.exceptions import ModelPersistenceError
@ -7,6 +7,9 @@ from ormar.models.helpers.validation import validate_choices
from ormar.models.mixins import AliasMixin
from ormar.models.mixins.relation_mixin import RelationMixin
if TYPE_CHECKING: # pragma: no cover
from ormar import ForeignKeyField, Model
class SavePrepareMixin(RelationMixin, AliasMixin):
"""
@ -15,6 +18,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if TYPE_CHECKING: # pragma: nocover
_choices_fields: Optional[Set]
_skip_ellipsis: Callable
@classmethod
def prepare_model_to_save(cls, new_kwargs: dict) -> dict:
@ -170,3 +174,128 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if field_name in new_kwargs and field_name in cls._choices_fields:
validate_choices(field=field, value=new_kwargs.get(field_name))
return new_kwargs
@staticmethod
async def _upsert_model(
instance: "Model",
save_all: bool,
previous_model: Optional["Model"],
relation_field: Optional["ForeignKeyField"],
update_count: int,
) -> int:
"""
Method updates given instance if:
* instance is not saved or
* instance have no pk or
* save_all=True flag is set
and instance is not __pk_only__.
If relation leading to instance is a ManyToMany also the through model is saved
:param instance: current model to upsert
:type instance: Model
:param save_all: flag if all models should be saved or only not saved ones
:type save_all: bool
:param relation_field: field with relation
:type relation_field: Optional[ForeignKeyField]
:param previous_model: previous model from which method came
:type previous_model: Model
:param update_count: no of updated models
:type update_count: int
:return: no of updated models
:rtype: int
"""
if (
save_all or not instance.pk or not instance.saved
) and not instance.__pk_only__:
await instance.upsert()
if relation_field and relation_field.is_multi:
await instance._upsert_through_model(
instance=instance,
relation_field=relation_field,
previous_model=cast("Model", previous_model),
)
update_count += 1
return update_count
@staticmethod
async def _upsert_through_model(
instance: "Model", previous_model: "Model", relation_field: "ForeignKeyField",
) -> None:
"""
Upsert through model for m2m relation.
:param instance: current model to upsert
:type instance: Model
:param relation_field: field with relation
:type relation_field: Optional[ForeignKeyField]
:param previous_model: previous model from which method came
:type previous_model: Model
"""
through_name = previous_model.Meta.model_fields[
relation_field.name
].through.get_name()
through = getattr(instance, through_name)
if through:
through_dict = through.dict(exclude=through.extract_related_names())
else:
through_dict = {}
await getattr(
previous_model, relation_field.name
).queryset_proxy.upsert_through_instance(instance, **through_dict)
async def _update_relation_list(
self,
fields_list: Collection["ForeignKeyField"],
follow: bool,
save_all: bool,
relation_map: Dict,
update_count: int,
) -> int:
"""
Internal method used in save_related to follow deeper from
related models and update numbers of updated related instances.
:type save_all: flag if all models should be saved
:type save_all: bool
:param fields_list: list of ormar fields to follow and save
:type fields_list: Collection["ForeignKeyField"]
:param relation_map: map of relations to follow
:type relation_map: Dict
:param follow: flag to trigger deep save -
by default only directly related models are saved
with follow=True also related models of related models are saved
:type follow: bool
:param update_count: internal parameter for recursive calls -
number of updated instances
:type update_count: int
:return: tuple of update count and visited
:rtype: int
"""
for field in fields_list:
value = getattr(self, field.name) or []
if not isinstance(value, list):
value = [value]
for val in value:
if follow:
update_count = await val.save_related(
follow=follow,
save_all=save_all,
relation_map=self._skip_ellipsis( # type: ignore
relation_map, field.name, default_return={}
),
update_count=update_count,
previous_model=self,
relation_field=field,
)
else:
update_count = await val._upsert_model(
instance=val,
save_all=save_all,
previous_model=self,
relation_field=field,
update_count=update_count,
)
return update_count

View File

@ -2,6 +2,7 @@ from typing import (
Any,
Dict,
List,
Optional,
Set,
TYPE_CHECKING,
TypeVar,
@ -17,6 +18,9 @@ from ormar.queryset.utils import subtract_dict, translate_list_to_dict
T = TypeVar("T", bound="Model")
if TYPE_CHECKING: # pragma: no cover
from ormar import ForeignKeyField
class Model(ModelRow):
__abstract__ = False
@ -24,7 +28,11 @@ class Model(ModelRow):
Meta: ModelMeta
def __repr__(self) -> str: # pragma nocover
_repr = {k: getattr(self, k) for k, v in self.Meta.model_fields.items()}
_repr = {
k: getattr(self, k)
for k, v in self.Meta.model_fields.items()
if not v.skip_field
}
return f"{self.__class__.__name__}({str(_repr)})"
async def upsert(self: T, **kwargs: Any) -> T:
@ -106,6 +114,8 @@ class Model(ModelRow):
relation_map: Dict = None,
exclude: Union[Set, Dict] = None,
update_count: int = 0,
previous_model: "Model" = None,
relation_field: Optional["ForeignKeyField"] = None,
) -> int:
"""
Triggers a upsert method on all related models
@ -122,6 +132,10 @@ class Model(ModelRow):
Model A but will never follow into Model C.
Nested relations of those kind need to be persisted manually.
:param relation_field: field with relation leading to this model
:type relation_field: Optional[ForeignKeyField]
:param previous_model: previous model from which method came
:type previous_model: Model
:param exclude: items to exclude during saving of relations
:type exclude: Union[Set, Dict]
:param relation_map: map of relations to follow
@ -147,61 +161,53 @@ class Model(ModelRow):
exclude = translate_list_to_dict(exclude)
relation_map = subtract_dict(relation_map, exclude or {})
for related in self.extract_related_names():
if relation_map and related in relation_map:
value = getattr(self, related)
if value:
update_count = await self._update_and_follow(
value=value,
follow=follow,
save_all=save_all,
relation_map=self._skip_ellipsis( # type: ignore
relation_map, related, default_return={}
),
update_count=update_count,
)
return update_count
if relation_map:
fields_to_visit = {
field
for field in self.extract_related_fields()
if field.name in relation_map
}
pre_save = {
field
for field in fields_to_visit
if not field.virtual and not field.is_multi
}
@staticmethod
async def _update_and_follow(
value: Union["Model", List["Model"]],
follow: bool,
save_all: bool,
relation_map: Dict,
update_count: int,
) -> int:
"""
Internal method used in save_related to follow related models and update numbers
of updated related instances.
:param value: Model to follow
:type value: Model
:param relation_map: map of relations to follow
:type relation_map: Dict
:param follow: flag to trigger deep save -
by default only directly related models are saved
with follow=True also related models of related models are saved
:type follow: bool
:param update_count: internal parameter for recursive calls -
number of updated instances
:type update_count: int
:return: tuple of update count and visited
:rtype: int
"""
if not isinstance(value, list):
value = [value]
for val in value:
if (not val.saved or save_all) and not val.__pk_only__:
await val.upsert()
update_count += 1
if follow:
update_count = await val.save_related(
update_count = await self._update_relation_list(
fields_list=pre_save,
follow=follow,
save_all=save_all,
relation_map=relation_map,
update_count=update_count,
)
update_count = await self._upsert_model(
instance=self,
save_all=save_all,
previous_model=previous_model,
relation_field=relation_field,
update_count=update_count,
)
post_save = fields_to_visit - pre_save
update_count = await self._update_relation_list(
fields_list=post_save,
follow=follow,
save_all=save_all,
relation_map=relation_map,
update_count=update_count,
)
else:
update_count = await self._upsert_model(
instance=self,
save_all=save_all,
previous_model=previous_model,
relation_field=relation_field,
update_count=update_count,
)
return update_count
async def update(self: T, _columns: List[str] = None, **kwargs: Any) -> T:

View File

@ -81,6 +81,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
_orm_id: int
_orm_saved: bool
_related_names: Optional[Set]
_through_names: Optional[Set]
_related_names_hash: str
_choices_fields: Optional[Set]
_pydantic_fields: Set
@ -165,6 +166,11 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
for field_to_nullify in excluded:
new_kwargs[field_to_nullify] = None
# extract through fields
through_tmp_dict = dict()
for field_name in self.extract_through_names():
through_tmp_dict[field_name] = new_kwargs.pop(field_name, None)
values, fields_set, validation_error = pydantic.validate_model(
self, new_kwargs # type: ignore
)
@ -174,12 +180,19 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
object.__setattr__(self, "__dict__", values)
object.__setattr__(self, "__fields_set__", fields_set)
# add back through fields
new_kwargs.update(through_tmp_dict)
# register the columns models after initialization
for related in self.extract_related_names().union(self.extract_through_names()):
self.Meta.model_fields[related].expand_relationship(
new_kwargs.get(related), self, to_register=True,
)
if hasattr(self, "_init_private_attributes"):
# introduced in pydantic 1.7
self._init_private_attributes()
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
"""
Overwrites setattr in object to allow for special behaviour of certain params.
@ -283,6 +296,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
value = object.__getattribute__(self, "__dict__").get(item, None)
value = object.__getattribute__(self, "_convert_json")(item, value, "loads")
return value
return object.__getattribute__(self, item) # pragma: no cover
def _verify_model_can_be_initialized(self) -> None:
@ -500,7 +514,11 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
fields = [
field
for field in fields
if field not in exclude or exclude.get(field) is not Ellipsis
if field not in exclude
or (
exclude.get(field) is not Ellipsis
and exclude.get(field) != {"__all__"}
)
]
return fields
@ -553,6 +571,18 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
result = self.get_child(items, key)
return result if result is not Ellipsis else default_return
def _convert_all(self, items: Union[Set, Dict, None]) -> Union[Set, Dict, None]:
"""
Helper to convert __all__ pydantic special index to ormar which does not
support index based exclusions.
:param items: current include/exclude value
:type items: Union[Set, Dict, None]
"""
if isinstance(items, dict) and "__all__" in items:
return items.get("__all__")
return items
def _extract_nested_models( # noqa: CCR001
self,
relation_map: Dict,
@ -581,6 +611,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
for field in fields:
if not relation_map or field not in relation_map:
continue
try:
nested_model = getattr(self, field)
if isinstance(nested_model, MutableSequence):
dict_instance[field] = self._extract_nested_models_from_list(
@ -588,19 +619,22 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
relation_map, field, default_return=dict()
),
models=nested_model,
include=self._skip_ellipsis(include, field),
exclude=self._skip_ellipsis(exclude, field),
include=self._convert_all(self._skip_ellipsis(include, field)),
exclude=self._convert_all(self._skip_ellipsis(exclude, field)),
)
elif nested_model is not None:
dict_instance[field] = nested_model.dict(
relation_map=self._skip_ellipsis(
relation_map, field, default_return=dict()
),
include=self._skip_ellipsis(include, field),
exclude=self._skip_ellipsis(exclude, field),
include=self._convert_all(self._skip_ellipsis(include, field)),
exclude=self._convert_all(self._skip_ellipsis(exclude, field)),
)
else:
dict_instance[field] = None
except ReferenceError:
dict_instance[field] = None
return dict_instance
def dict( # type: ignore # noqa A003

View File

@ -16,7 +16,7 @@ from typing import ( # noqa: I100, I201
)
import ormar # noqa: I100, I202
from ormar.exceptions import ModelPersistenceError, QueryDefinitionError
from ormar.exceptions import ModelPersistenceError, NoMatch, QueryDefinitionError
if TYPE_CHECKING: # pragma no cover
from ormar.relations import Relation
@ -152,6 +152,21 @@ class QuerysetProxy(Generic[T]):
through_model = await model_cls.objects.get(**rel_kwargs)
await through_model.update(**kwargs)
async def upsert_through_instance(self, child: "T", **kwargs: Any) -> None:
"""
Updates a through model instance in the database for m2m relations if
it already exists, else creates one.
:param kwargs: dict of additional keyword arguments for through instance
:type kwargs: Any
:param child: child model instance
:type child: Model
"""
try:
await self.update_through_instance(child=child, **kwargs)
except NoMatch:
await self.create_through_instance(child=child, **kwargs)
async def delete_through_instance(self, child: "T") -> None:
"""
Removes through model instance from the database for m2m relations.
@ -251,7 +266,7 @@ class QuerysetProxy(Generic[T]):
owner_column = self._owner.get_name()
else:
queryset = ormar.QuerySet(model_cls=self.relation.to) # type: ignore
owner_column = self.related_field.name
owner_column = self.related_field_name
kwargs = {owner_column: self._owner}
self._clean_items_on_load()
if keep_reversed and self.type_ == ormar.RelationType.REVERSE:
@ -367,7 +382,7 @@ class QuerysetProxy(Generic[T]):
"""
through_kwargs = kwargs.pop(self.through_model_name, {})
if self.type_ == ormar.RelationType.REVERSE:
kwargs[self.related_field.name] = self._owner
kwargs[self.related_field_name] = self._owner
created = await self.queryset.create(**kwargs)
self._register_related(created)
if self.type_ == ormar.RelationType.MULTIPLE:

View File

@ -124,15 +124,14 @@ class RelationProxy(Generic[T], list):
:rtype: QuerySet
"""
related_field_name = self.related_field_name
related_field = self.relation.to.Meta.model_fields[related_field_name]
pkname = self._owner.get_column_alias(self._owner.Meta.pkname)
self._check_if_model_saved()
kwargs = {f"{related_field.name}__{pkname}": self._owner.pk}
kwargs = {f"{related_field_name}__{pkname}": self._owner.pk}
queryset = (
ormar.QuerySet(
model_cls=self.relation.to, proxy_source_model=self._owner.__class__
)
.select_related(related_field.name)
.select_related(related_field_name)
.filter(**kwargs)
)
return queryset
@ -168,10 +167,11 @@ class RelationProxy(Generic[T], list):
super().remove(item)
relation_name = self.related_field_name
relation = item._orm._get(relation_name)
if relation is None: # pragma nocover
raise ValueError(
f"{self._owner.get_name()} does not have relation {relation_name}"
)
# if relation is None: # pragma nocover
# raise ValueError(
# f"{self._owner.get_name()} does not have relation {relation_name}"
# )
if relation:
relation.remove(self._owner)
self.relation.remove(item)
if self.type_ == ormar.RelationType.MULTIPLE:
@ -211,7 +211,7 @@ class RelationProxy(Generic[T], list):
self._check_if_model_saved()
if self.type_ == ormar.RelationType.MULTIPLE:
await self.queryset_proxy.create_through_instance(item, **kwargs)
setattr(item, relation_name, self._owner)
setattr(self._owner, self.field_name, item)
else:
setattr(item, relation_name, self._owner)
await item.update()

View File

@ -1,4 +1,5 @@
import datetime
from typing import List
import pytest
import sqlalchemy
@ -59,6 +60,12 @@ async def get_bus(item_id: int):
return bus
@app.get("/buses/", response_model=List[Bus])
async def get_buses():
buses = await Bus.objects.select_related(["owner", "co_owner"]).all()
return buses
@app.post("/trucks/", response_model=Truck)
async def create_truck(truck: Truck):
await truck.save()
@ -84,6 +91,12 @@ async def add_bus_coowner(item_id: int, person: Person):
return bus
@app.get("/buses2/", response_model=List[Bus2])
async def get_buses2():
buses = await Bus2.objects.select_related(["owner", "co_owners"]).all()
return buses
@app.post("/trucks2/", response_model=Truck2)
async def create_truck2(truck: Truck2):
await truck.save()
@ -172,6 +185,10 @@ def test_inheritance_with_relation():
assert unicorn2.co_owner.name == "Joe"
assert unicorn2.max_persons == 50
buses = [Bus(**x) for x in client.get("/buses/").json()]
assert len(buses) == 1
assert buses[0].name == "Unicorn"
def test_inheritance_with_m2m_relation():
client = TestClient(app)
@ -217,3 +234,7 @@ def test_inheritance_with_m2m_relation():
assert shelby.co_owners[0] == alex
assert shelby.co_owners[1] == joe
assert shelby.max_capacity == 2000
buses = [Bus2(**x) for x in client.get("/buses2/").json()]
assert len(buses) == 1
assert buses[0].name == "Unicorn"

View File

@ -0,0 +1,151 @@
import json
from typing import Optional
import databases
import pytest
import sqlalchemy
from fastapi import FastAPI
from starlette.testclient import TestClient
import ormar
from tests.settings import DATABASE_URL
app = FastAPI()
metadata = sqlalchemy.MetaData()
database = databases.Database(DATABASE_URL, force_rollback=True)
app.state.database = database
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
class Department(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
department_name: str = ormar.String(max_length=100)
class Course(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
course_name: str = ormar.String(max_length=100)
completed: bool = ormar.Boolean()
department: Optional[Department] = ormar.ForeignKey(Department)
class Student(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
courses = ormar.ManyToMany(Course)
# create db and tables
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
to_exclude = {
"id": ...,
"courses": {
"__all__": {"id": ..., "students": {"__all__": {"id", "studentcourse"}}}
},
}
exclude_all = {"id": ..., "courses": {"__all__"}}
to_exclude_ormar = {
"id": ...,
"courses": {"id": ..., "students": {"id", "studentcourse"}},
}
@app.post("/departments/", response_model=Department)
async def create_department(department: Department):
await department.save_related(follow=True, save_all=True)
return department
@app.get("/departments/{department_name}")
async def get_department(department_name: str):
department = await Department.objects.select_all(follow=True).get(
department_name=department_name
)
return department.dict(exclude=to_exclude)
@app.get("/departments/{department_name}/second")
async def get_department_exclude(department_name: str):
department = await Department.objects.select_all(follow=True).get(
department_name=department_name
)
return department.dict(exclude=to_exclude_ormar)
@app.get("/departments/{department_name}/exclude")
async def get_department_exclude_all(department_name: str):
department = await Department.objects.select_all(follow=True).get(
department_name=department_name
)
return department.dict(exclude=exclude_all)
def test_saving_related_in_fastapi():
client = TestClient(app)
with client as client:
payload = {
"department_name": "Ormar",
"courses": [
{
"course_name": "basic1",
"completed": True,
"students": [{"name": "Jack"}, {"name": "Abi"}],
},
{
"course_name": "basic2",
"completed": True,
"students": [{"name": "Kate"}, {"name": "Miranda"}],
},
],
}
response = client.post("/departments/", data=json.dumps(payload))
department = Department(**response.json())
assert department.id is not None
assert len(department.courses) == 2
assert department.department_name == "Ormar"
assert department.courses[0].course_name == "basic1"
assert department.courses[0].completed
assert department.courses[1].course_name == "basic2"
assert department.courses[1].completed
response = client.get("/departments/Ormar")
response2 = client.get("/departments/Ormar/second")
assert response.json() == response2.json() == payload
response3 = client.get("/departments/Ormar/exclude")
assert response3.json() == {"department_name": "Ormar"}

View File

@ -0,0 +1,148 @@
import json
from typing import List, Optional
import databases
import pytest
import sqlalchemy
from fastapi import FastAPI
from starlette.testclient import TestClient
import ormar
from tests.settings import DATABASE_URL
app = FastAPI()
metadata = sqlalchemy.MetaData()
database = databases.Database(DATABASE_URL, force_rollback=True)
app.state.database = database
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
class BaseMeta(ormar.ModelMeta):
database = database
metadata = metadata
class Author(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
first_name: str = ormar.String(max_length=80)
last_name: str = ormar.String(max_length=80)
class Category(ormar.Model):
class Meta(BaseMeta):
tablename = "categories"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=40)
class Post(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=200)
categories = ormar.ManyToMany(Category, skip_reverse=True)
author: Optional[Author] = ormar.ForeignKey(Author, skip_reverse=True)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@app.post("/categories/", response_model=Category)
async def create_category(category: Category):
await category.save()
await category.save_related(follow=True, save_all=True)
return category
@app.post("/posts/", response_model=Post)
async def create_post(post: Post):
if post.author:
await post.author.save()
await post.save()
await post.save_related(follow=True, save_all=True)
for category in [cat for cat in post.categories]:
await post.categories.add(category)
return post
@app.get("/categories/", response_model=List[Category])
async def get_categories():
return await Category.objects.select_related("posts").all()
@app.get("/posts/", response_model=List[Post])
async def get_posts():
posts = await Post.objects.select_related(["categories", "author"]).all()
return posts
def test_queries():
client = TestClient(app)
with client as client:
right_category = {"name": "Test category"}
wrong_category = {"name": "Test category2", "posts": [{"title": "Test Post"}]}
# cannot add posts if skipped, will be ignored (with extra=ignore by default)
response = client.post("/categories/", data=json.dumps(wrong_category))
assert response.status_code == 200
response = client.get("/categories/")
assert response.status_code == 200
assert not "posts" in response.json()
categories = [Category(**x) for x in response.json()]
assert categories[0] is not None
assert categories[0].name == "Test category2"
response = client.post("/categories/", data=json.dumps(right_category))
assert response.status_code == 200
response = client.get("/categories/")
assert response.status_code == 200
categories = [Category(**x) for x in response.json()]
assert categories[1] is not None
assert categories[1].name == "Test category"
right_post = {
"title": "ok post",
"author": {"first_name": "John", "last_name": "Smith"},
"categories": [{"name": "New cat"}],
}
response = client.post("/posts/", data=json.dumps(right_post))
assert response.status_code == 200
Category.__config__.extra = "allow"
response = client.get("/posts/")
assert response.status_code == 200
posts = [Post(**x) for x in response.json()]
assert posts[0].title == "ok post"
assert posts[0].author.first_name == "John"
assert posts[0].categories[0].name == "New cat"
wrong_category = {"name": "Test category3", "posts": [{"title": "Test Post"}]}
# cannot add posts if skipped, will be error with extra forbid
Category.__config__.extra = "forbid"
response = client.post("/categories/", data=json.dumps(wrong_category))
assert response.status_code == 422

View File

@ -123,6 +123,16 @@ async def get_test_5(thing_id: UUID):
return await Thing.objects.all(other_thing__id=thing_id)
@app.get(
"/test/error", response_model=List[Thing], response_model_exclude={"other_thing"}
)
async def get_weakref():
ots = await OtherThing.objects.all()
ot = ots[0]
ts = await ot.things.all()
return ts
def test_endpoints():
client = TestClient(app)
with client:
@ -145,3 +155,7 @@ def test_endpoints():
resp5 = client.get(f"/test/5/{ot.id}")
assert resp5.status_code == 200
assert len(resp5.json()) == 3
resp6 = client.get("/test/error")
assert resp6.status_code == 200
assert len(resp6.json()) == 3

View File

@ -0,0 +1,34 @@
from typing import List
import databases
import sqlalchemy
from pydantic import PrivateAttr
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class BaseMeta(ormar.ModelMeta):
metadata = metadata
database = database
class Subscription(ormar.Model):
class Meta(BaseMeta):
tablename = "subscriptions"
id: int = ormar.Integer(primary_key=True)
stripe_subscription_id: str = ormar.String(nullable=False, max_length=256)
_add_payments: List[str] = PrivateAttr(default_factory=list)
def add_payment(self, payment: str):
self._add_payments.append(payment)
def test_private_attribute():
sub = Subscription(stripe_subscription_id="2312312sad231")
sub.add_payment("test")

View File

@ -119,7 +119,7 @@ async def test_saving_many_to_many():
assert count == 0
count = await hq.save_related(save_all=True)
assert count == 2
assert count == 3
hq.nicks[0].name = "Kabucha"
hq.nicks[1].name = "Kabucha2"

View File

@ -0,0 +1,256 @@
from typing import List
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class CringeLevel(ormar.Model):
class Meta:
tablename = "levels"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class NickName(ormar.Model):
class Meta:
tablename = "nicks"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
is_lame: bool = ormar.Boolean(nullable=True)
level: CringeLevel = ormar.ForeignKey(CringeLevel)
class NicksHq(ormar.Model):
class Meta:
tablename = "nicks_x_hq"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
new_field: str = ormar.String(max_length=200, nullable=True)
class HQ(ormar.Model):
class Meta:
tablename = "hqs"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
nicks: List[NickName] = ormar.ManyToMany(NickName, through=NicksHq)
class Company(ormar.Model):
class Meta:
tablename = "companies"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="company_name")
founded: int = ormar.Integer(nullable=True)
hq: HQ = ormar.ForeignKey(HQ, related_name="companies")
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_saving_related_reverse_fk():
async with database:
async with database.transaction(force_rollback=True):
payload = {"companies": [{"name": "Banzai"}], "name": "Main"}
hq = HQ(**payload)
count = await hq.save_related(follow=True, save_all=True)
assert count == 2
hq_check = await HQ.objects.select_related("companies").get()
assert hq_check.pk is not None
assert hq_check.name == "Main"
assert len(hq_check.companies) == 1
assert hq_check.companies[0].name == "Banzai"
assert hq_check.companies[0].pk is not None
@pytest.mark.asyncio
async def test_saving_related_reverse_fk_multiple():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"companies": [{"name": "Banzai"}, {"name": "Yamate"}],
"name": "Main",
}
hq = HQ(**payload)
count = await hq.save_related(follow=True, save_all=True)
assert count == 3
hq_check = await HQ.objects.select_related("companies").get()
assert hq_check.pk is not None
assert hq_check.name == "Main"
assert len(hq_check.companies) == 2
assert hq_check.companies[0].name == "Banzai"
assert hq_check.companies[0].pk is not None
assert hq_check.companies[1].name == "Yamate"
assert hq_check.companies[1].pk is not None
@pytest.mark.asyncio
async def test_saving_related_fk():
async with database:
async with database.transaction(force_rollback=True):
payload = {"hq": {"name": "Main"}, "name": "Banzai"}
comp = Company(**payload)
count = await comp.save_related(follow=True, save_all=True)
assert count == 2
comp_check = await Company.objects.select_related("hq").get()
assert comp_check.pk is not None
assert comp_check.name == "Banzai"
assert comp_check.hq.name == "Main"
assert comp_check.hq.pk is not None
@pytest.mark.asyncio
async def test_saving_many_to_many_wo_through():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"name": "Main",
"nicks": [
{"name": "Bazinga0", "is_lame": False},
{"name": "Bazinga20", "is_lame": True},
],
}
hq = HQ(**payload)
count = await hq.save_related()
assert count == 3
hq_check = await HQ.objects.select_related("nicks").get()
assert hq_check.pk is not None
assert len(hq_check.nicks) == 2
assert hq_check.nicks[0].name == "Bazinga0"
assert hq_check.nicks[1].name == "Bazinga20"
@pytest.mark.asyncio
async def test_saving_many_to_many_with_through():
async with database:
async with database.transaction(force_rollback=True):
async with database.transaction(force_rollback=True):
payload = {
"name": "Main",
"nicks": [
{
"name": "Bazinga0",
"is_lame": False,
"nickshq": {"new_field": "test"},
},
{
"name": "Bazinga20",
"is_lame": True,
"nickshq": {"new_field": "test2"},
},
],
}
hq = HQ(**payload)
count = await hq.save_related()
assert count == 3
hq_check = await HQ.objects.select_related("nicks").get()
assert hq_check.pk is not None
assert len(hq_check.nicks) == 2
assert hq_check.nicks[0].name == "Bazinga0"
assert hq_check.nicks[0].nickshq.new_field == "test"
assert hq_check.nicks[1].name == "Bazinga20"
assert hq_check.nicks[1].nickshq.new_field == "test2"
@pytest.mark.asyncio
async def test_saving_nested_with_m2m_and_rev_fk():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"name": "Main",
"nicks": [
{"name": "Bazinga0", "is_lame": False, "level": {"name": "High"}},
{"name": "Bazinga20", "is_lame": True, "level": {"name": "Low"}},
],
}
hq = HQ(**payload)
count = await hq.save_related(follow=True, save_all=True)
assert count == 5
hq_check = await HQ.objects.select_related("nicks__level").get()
assert hq_check.pk is not None
assert len(hq_check.nicks) == 2
assert hq_check.nicks[0].name == "Bazinga0"
assert hq_check.nicks[0].level.name == "High"
assert hq_check.nicks[1].name == "Bazinga20"
assert hq_check.nicks[1].level.name == "Low"
@pytest.mark.asyncio
async def test_saving_nested_with_m2m_and_rev_fk_and_through():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"hq": {
"name": "Yoko",
"nicks": [
{
"name": "Bazinga0",
"is_lame": False,
"nickshq": {"new_field": "test"},
"level": {"name": "High"},
},
{
"name": "Bazinga20",
"is_lame": True,
"nickshq": {"new_field": "test2"},
"level": {"name": "Low"},
},
],
},
"name": "Main",
}
company = Company(**payload)
count = await company.save_related(follow=True, save_all=True)
assert count == 6
company_check = await Company.objects.select_related(
"hq__nicks__level"
).get()
assert company_check.pk is not None
assert company_check.name == "Main"
assert company_check.hq.name == "Yoko"
assert len(company_check.hq.nicks) == 2
assert company_check.hq.nicks[0].name == "Bazinga0"
assert company_check.hq.nicks[0].nickshq.new_field == "test"
assert company_check.hq.nicks[0].level.name == "High"
assert company_check.hq.nicks[1].name == "Bazinga20"
assert company_check.hq.nicks[1].level.name == "Low"
assert company_check.hq.nicks[1].nickshq.new_field == "test2"

View File

@ -0,0 +1,84 @@
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
metadata = sqlalchemy.MetaData()
database = databases.Database(DATABASE_URL, force_rollback=True)
class Course(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
course_name: str = ormar.String(max_length=100)
class Student(ormar.Model):
class Meta:
database = database
metadata = metadata
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
courses = ormar.ManyToMany(
Course,
through_relation_name="student_id",
through_reverse_relation_name="course_id",
)
# create db and tables
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
def test_tables_columns():
through_meta = Student.Meta.model_fields["courses"].through.Meta
assert "course_id" in through_meta.table.c
assert "student_id" in through_meta.table.c
assert "course_id" in through_meta.model_fields
assert "student_id" in through_meta.model_fields
@pytest.mark.asyncio
async def test_working_with_changed_through_names():
async with database:
async with database.transaction(force_rollback=True):
to_save = {
"course_name": "basic1",
"students": [{"name": "Jack"}, {"name": "Abi"}],
}
await Course(**to_save).save_related(follow=True, save_all=True)
course_check = await Course.objects.select_related("students").get()
assert course_check.course_name == "basic1"
assert course_check.students[0].name == "Jack"
assert course_check.students[1].name == "Abi"
students = await course_check.students.all()
assert len(students) == 2
student = await course_check.students.get(name="Jack")
assert student.name == "Jack"
students = await Student.objects.select_related("courses").all(
courses__course_name="basic1"
)
assert len(students) == 2
course_check = (
await Course.objects.select_related("students")
.order_by("students__name")
.get()
)
assert course_check.students[0].name == "Abi"
assert course_check.students[1].name == "Jack"

View File

@ -0,0 +1,223 @@
from typing import List, Optional
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
class BaseMeta(ormar.ModelMeta):
database = database
metadata = metadata
class Author(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
first_name: str = ormar.String(max_length=80)
last_name: str = ormar.String(max_length=80)
class Category(ormar.Model):
class Meta(BaseMeta):
tablename = "categories"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=40)
class Post(ormar.Model):
class Meta(BaseMeta):
pass
id: int = ormar.Integer(primary_key=True)
title: str = ormar.String(max_length=200)
categories: Optional[List[Category]] = ormar.ManyToMany(Category, skip_reverse=True)
author: Optional[Author] = ormar.ForeignKey(Author, skip_reverse=True)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.fixture(scope="function")
async def cleanup():
yield
async with database:
PostCategory = Post.Meta.model_fields["categories"].through
await PostCategory.objects.delete(each=True)
await Post.objects.delete(each=True)
await Category.objects.delete(each=True)
await Author.objects.delete(each=True)
def test_model_definition():
category = Category(name="Test")
author = Author(first_name="Test", last_name="Author")
post = Post(title="Test Post", author=author)
post.categories = category
assert post.categories[0] == category
assert post.author == author
with pytest.raises(AttributeError):
assert author.posts
with pytest.raises(AttributeError):
assert category.posts
assert "posts" not in category._orm
@pytest.mark.asyncio
async def test_assigning_related_objects(cleanup):
async with database:
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
# Add a category to a post.
await post.categories.add(news)
# other way is disabled
with pytest.raises(AttributeError):
await news.posts.add(post)
assert await post.categories.get_or_none(name="no exist") is None
assert await post.categories.get_or_none(name="News") == news
# Creating columns object from instance:
await post.categories.create(name="Tips")
assert len(post.categories) == 2
post_categories = await post.categories.all()
assert len(post_categories) == 2
category = await Category.objects.select_related("posts").get(name="News")
with pytest.raises(AttributeError):
assert category.posts
@pytest.mark.asyncio
async def test_quering_of_related_model_works_but_no_result(cleanup):
async with database:
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
await post.categories.add(news)
post_categories = await post.categories.all()
assert len(post_categories) == 1
assert "posts" not in post.dict().get("categories", [])[0]
assert news == await post.categories.get(name="News")
posts_about_python = await Post.objects.filter(categories__name="python").all()
assert len(posts_about_python) == 0
# relation not in dict
category = (
await Category.objects.select_related("posts")
.filter(posts__author=guido)
.get()
)
assert category == news
assert "posts" not in category.dict()
# relation not in json
category2 = (
await Category.objects.select_related("posts")
.filter(posts__author__first_name="Guido")
.get()
)
assert category2 == news
assert "posts" not in category2.json()
assert "posts" not in Category.schema().get("properties")
@pytest.mark.asyncio
async def test_removal_of_the_relations(cleanup):
async with database:
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
post = await Post.objects.create(title="Hello, M2M", author=guido)
news = await Category.objects.create(name="News")
await post.categories.add(news)
assert len(await post.categories.all()) == 1
await post.categories.remove(news)
assert len(await post.categories.all()) == 0
with pytest.raises(AttributeError):
await news.posts.add(post)
with pytest.raises(AttributeError):
await news.posts.remove(post)
await post.categories.add(news)
await post.categories.clear()
assert len(await post.categories.all()) == 0
await post.categories.add(news)
await news.delete()
assert len(await post.categories.all()) == 0
@pytest.mark.asyncio
async def test_selecting_related(cleanup):
async with database:
guido = await Author.objects.create(first_name="Guido", last_name="Van Rossum")
guido2 = await Author.objects.create(
first_name="Guido2", last_name="Van Rossum"
)
post = await Post.objects.create(title="Hello, M2M", author=guido)
post2 = await Post.objects.create(title="Bye, M2M", author=guido2)
news = await Category.objects.create(name="News")
recent = await Category.objects.create(name="Recent")
await post.categories.add(news)
await post.categories.add(recent)
await post2.categories.add(recent)
assert len(await post.categories.all()) == 2
assert (await post.categories.limit(1).all())[0] == news
assert (await post.categories.offset(1).limit(1).all())[0] == recent
assert await post.categories.first() == news
assert await post.categories.exists()
# still can order
categories = (
await Category.objects.select_related("posts")
.order_by("posts__title")
.all()
)
assert categories[0].name == "Recent"
assert categories[1].name == "News"
# still can filter
categories = await Category.objects.filter(posts__title="Bye, M2M").all()
assert categories[0].name == "Recent"
assert len(categories) == 1
# same for reverse fk
authors = (
await Author.objects.select_related("posts").order_by("posts__title").all()
)
assert authors[0].first_name == "Guido2"
assert authors[1].first_name == "Guido"
authors = await Author.objects.filter(posts__title="Bye, M2M").all()
assert authors[0].first_name == "Guido2"
assert len(authors) == 1