diff --git a/.coverage b/.coverage index 317b35b..a5c0d9f 100644 Binary files a/.coverage and b/.coverage differ diff --git a/README.md b/README.md index 5098d19..d6ff26c 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,9 @@ await Track.objects.create(album=malibu, title="The Bird", position=1) await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2) await Track.objects.create(album=malibu, title="The Waters", position=3) -fantasies = await Album.objects.create(name="Fantasies") +# alternative creation of object divided into 2 steps +fantasies = Album.objects.create(name="Fantasies") +await fantasies.save() await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) await Track.objects.create(album=fantasies, title="Sick Muse", position=2) @@ -137,12 +139,33 @@ tracks = await Track.objects.limit(1).all() assert len(tracks) == 1 ``` -## Data types +## Ormar Specification + +### QuerySet methods + +* `create(**kwargs): -> Model` +* `get(**kwargs): -> Model` +* `get_or_create(**kwargs) -> Model` +* `update(each: bool = False, **kwargs) -> int` +* `update_or_create(**kwargs) -> Model` +* `bulk_create(objects: List[Model]) -> None` +* `bulk_update(objects: List[Model], columns: List[str] = None) -> None` +* `delete(each: bool = False, **kwargs) -> int` +* `all(self, **kwargs) -> List[Optional[Model]]` +* `filter(**kwargs) -> QuerySet` +* `exclude(**kwargs) -> QuerySet` +* `select_related(related: Union[List, str]) -> QuerySet` +* `limit(limit_count: int) -> QuerySet` +* `offset(offset: int) -> QuerySet` +* `count() -> int` +* `exists() -> bool` +* `fields(columns: Union[List, str]) -> QuerySet` + #### Relation types -* One to many - with `ForeignKey` -* Many to many - with `Many2Many` +* One to many - with `ForeignKey(to: Model)` +* Many to many - with `ManyToMany(to: Model, through: Model)` #### Model fields types @@ -161,7 +184,7 @@ Available Model Fields (with required args - optional ones in docs): * `Decimal(scale, precision)` * `UUID()` * `ForeignKey(to)` -* `Many2Many(to, through)` +* `ManyToMany(to, through)` ### Available fields options The following keyword arguments are supported on all field types. diff --git a/docs/fastapi.md b/docs/fastapi.md index 399a2c3..fbb6f11 100644 --- a/docs/fastapi.md +++ b/docs/fastapi.md @@ -6,97 +6,66 @@ you need to do is substitute pydantic models with ormar models. Here you can find a very simple sample application code. +## Imports and initialization + +First take care of the imports and initialization +```python hl_lines="1-12" +--8<-- "../docs_src/fastapi/docs001.py" +``` + +## Database connection + +Next define startup and shutdown events (or use middleware) +- note that this is `databases` specific setting not the ormar one +```python hl_lines="15-26" +--8<-- "../docs_src/fastapi/docs001.py" +``` + +!!!info + You can read more on connecting to databases in [fastapi][fastapi] documentation + +## Models definition + +Define ormar models with appropriate fields. + +Those models will be used insted of pydantic ones. +```python hl_lines="29-47" +--8<-- "../docs_src/fastapi/docs001.py" +``` + +!!!tip + You can read more on defining `Models` in [models][models] section. + +## Fastapi endpoints definition + +Define your desired endpoints, note how `ormar` models are used both +as `response_model` and as a requests parameters. + +```python hl_lines="50-77" +--8<-- "../docs_src/fastapi/docs001.py" +``` + +!!!note + Note how ormar `Model` methods like save() are available straight out of the box after fastapi initializes it for you. + +!!!note + Note that you can return a `Model` (or list of `Models`) directly - fastapi will jsonize it for you + +## Test the application + +Here you have a sample test that will prove that everything works as intended. + ```python -from typing import List - -import databases -import pytest -import sqlalchemy -from fastapi import FastAPI -from starlette.testclient import TestClient - -import ormar -from tests.settings import DATABASE_URL - -app = FastAPI() -metadata = sqlalchemy.MetaData() -database = databases.Database(DATABASE_URL, force_rollback=True) -app.state.database = database - -# define startup and shutdown events -@app.on_event("startup") -async def startup() -> None: - database_ = app.state.database - if not database_.is_connected: - await database_.connect() - - -@app.on_event("shutdown") -async def shutdown() -> None: - database_ = app.state.database - if database_.is_connected: - await database_.disconnect() - -# define ormar models -class Category(ormar.Model): - class Meta: - tablename = "categories" - metadata = metadata - database = database - - id: ormar.Integer(primary_key=True) - name: ormar.String(max_length=100) - - -class Item(ormar.Model): - class Meta: - tablename = "items" - metadata = metadata - database = database - - id: ormar.Integer(primary_key=True) - name: ormar.String(max_length=100) - category: ormar.ForeignKey(Category, nullable=True) - -# define endpoints in fastapi -@app.get("/items/", response_model=List[Item]) -async def get_items(): - items = await Item.objects.select_related("category").all() - # not that you can return a model directly - fastapi will json-ize it - return items - - -@app.post("/items/", response_model=Item) -async def create_item(item: Item): - # note how ormar methods like save() are available streight out of the box - await item.save() - return item - - -@app.post("/categories/", response_model=Category) -async def create_category(category: Category): - await category.save() - return category - - -@app.put("/items/{item_id}") -async def get_item(item_id: int, item: Item): - # you can work both with item_id or item - item_db = await Item.objects.get(pk=item_id) - return await item_db.update(**item.dict()) - - -@app.delete("/items/{item_id}") -async def delete_item(item_id: int, item: Item): - item_db = await Item.objects.get(pk=item_id) - return {"deleted_rows": await item_db.delete()} # here is a sample test to check the working of the ormar with fastapi + +from starlette.testclient import TestClient + def test_all_endpoints(): # note that TestClient is only sync, don't use asyns here client = TestClient(app) # note that you need to connect to database manually - # or use client as contextmanager + # or use client as contextmanager during tests with client as client: response = client.post("/categories/", json={"name": "test cat"}) category = response.json() @@ -123,4 +92,10 @@ def test_all_endpoints(): response = client.get("/items/") items = response.json() assert len(items) == 0 -``` \ No newline at end of file +``` + +!!!info + You can read more on testing fastapi in [fastapi][fastapi] docs. + +[fastapi]: https://fastapi.tiangolo.com/ +[models]: ./models.md \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 5098d19..ca77b54 100644 --- a/docs/index.md +++ b/docs/index.md @@ -97,7 +97,9 @@ await Track.objects.create(album=malibu, title="The Bird", position=1) await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2) await Track.objects.create(album=malibu, title="The Waters", position=3) -fantasies = await Album.objects.create(name="Fantasies") +# alternative creation of object divided into 2 steps +fantasies = Album.objects.create(name="Fantasies") +await fantasies.save() await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) await Track.objects.create(album=fantasies, title="Sick Muse", position=2) @@ -137,12 +139,33 @@ tracks = await Track.objects.limit(1).all() assert len(tracks) == 1 ``` -## Data types +## Ormar Specification + +### QuerySet methods + +* `create(**kwargs): -> Model` +* `get(**kwargs): -> Model` +* `get_or_create(**kwargs) -> Model` +* `update(each: bool = False, **kwargs) -> int` +* `update_or_create(**kwargs) -> Model` +* `bulk_create(objects: List[Model]) -> None` +* `bulk_update(objects: List[Model], columns: List[str] = None) -> None` +* `delete(each: bool = False, **kwargs) -> int` +* `all(self, **kwargs) -> List[Optional[Model]]` +* `filter(**kwargs) -> QuerySet` +* `exclude(**kwargs) -> QuerySet` +* `select_related(related: Union[List, str]) -> QuerySet` +* `limit(limit_count: int) -> QuerySet` +* `offset(offset: int) -> QuerySet` +* `count() -> int` +* `exists() -> bool` +* `fields(columns: Union[List, str]) -> QuerySet` + #### Relation types -* One to many - with `ForeignKey` -* Many to many - with `Many2Many` +* One to many - with `ForeignKey(to: Model)` +* Many to many - with `ManyToMany(to: Model, through: Model)` #### Model fields types @@ -161,7 +184,7 @@ Available Model Fields (with required args - optional ones in docs): * `Decimal(scale, precision)` * `UUID()` * `ForeignKey(to)` -* `Many2Many(to, through)` +* `ManyToMany(to, through)` ### Available fields options The following keyword arguments are supported on all field types. @@ -173,6 +196,7 @@ The following keyword arguments are supported on all field types. * `index: bool` * `unique: bool` * `choices: typing.Sequence` + * `name: str` All fields are required unless one of the following is set: diff --git a/docs/models.md b/docs/models.md index 57b7998..6e2c73b 100644 --- a/docs/models.md +++ b/docs/models.md @@ -37,7 +37,27 @@ You can disable by passing `autoincremant=False`. id: ormar.Integer(primary_key=True, autoincrement=False) ``` -Names of the fields will be used for both the underlying `pydantic` model and `sqlalchemy` table. +### Fields names vs Column names + +By default names of the fields will be used for both the underlying `pydantic` model and `sqlalchemy` table. + +If for whatever reason you prefer to change the name in the database but keep the name in the model you can do this +with specifying `name` parameter during Field declaration + +Here you have a sample model with changed names +```Python hl_lines="16-19" +--8<-- "../docs_src/models/docs008.py" +``` + +Note that you can also change the ForeignKey column name +```Python hl_lines="9" +--8<-- "../docs_src/models/docs009.py" +``` + +But for now you cannot change the ManyToMany column names as they go through other Model anyway. +```Python hl_lines="18" +--8<-- "../docs_src/models/docs010.py" +``` ### Dependencies @@ -128,7 +148,9 @@ Each model has a `QuerySet` initialised as `objects` parameter ### load By default when you query a table without prefetching related models, the ormar will still construct -your related models, but populate them only with the pk value. +your related models, but populate them only with the pk value. You can load the related model by calling `load()` method. + +`load()` can also be used to refresh the model from the database (if it was changed by some other process). ```python track = await Track.objects.get(name='The Bird') @@ -142,10 +164,36 @@ track.album.name # will return 'Malibu' ### save +You can create new models by using `QuerySet.create()` method or by initializing your model as a normal pydantic model +and later calling `save()` method. + +`save()` can also be used to persist changes that you made to the model. + +```python +track = Track(name='The Bird') +await track.save() # will persist the model in database +``` + ### delete +You can delete models by using `QuerySet.delete()` method or by using your model and calling `delete()` method. + +```python +track = await Track.objects.get(name='The Bird') +await track.delete() # will delete the model from database +``` + +!!!tip + Note that that `track` object stays the same, only record in the database is removed. + ### update +You can delete models by using `QuerySet.update()` method or by using your model and calling `update()` method. + +```python +track = await Track.objects.get(name='The Bird') +await track.update(name='The Bird Strikes Again') +``` ## Internals diff --git a/docs/relations.md b/docs/relations.md index 0d227e9..fd6a3e1 100644 --- a/docs/relations.md +++ b/docs/relations.md @@ -85,9 +85,9 @@ Finally you can explicitly set it to None (default behavior if no value passed). Otherwise an IntegrityError will be raised by your database driver library. -### Many2Many +### ManyToMany -`Many2Many(to, through)` has required parameters `to` and `through` that takes target and relation `Model` classes. +`ManyToMany(to, through)` has required parameters `to` and `through` that takes target and relation `Model` classes. Sqlalchemy column and Type are automatically taken from target `Model`. @@ -131,7 +131,7 @@ assert len(await post.categories.all()) == 2 ``` !!!note - Note that when accessing QuerySet API methods through Many2Many relation you don't + Note that when accessing QuerySet API methods through ManyToMany relation you don't need to use objects attribute like in normal queries. To learn more about available QuerySet methods visit [queries][queries] @@ -146,7 +146,7 @@ await news.posts.clear() #### All other queryset methods -When access directly the related `Many2Many` field returns the list of related models. +When access directly the related `ManyToMany` field returns the list of related models. But at the same time it exposes full QuerySet API, so you can filter, create, select related etc. diff --git a/docs/releases.md b/docs/releases.md new file mode 100644 index 0000000..c674fdb --- /dev/null +++ b/docs/releases.md @@ -0,0 +1,45 @@ +# 0.3.8 + +* Added possibility to provide alternative database column names with name parameter to all fields. +* Fix bug with selecting related ManyToMany fields with `fields()` if they are empty. +* Updated documentation + +# 0.3.7 + +* Publish documentation and update readme + +# 0.3.6 + +* Add fields() method to limit the selected columns from database - only nullable columns can be excluded. +* Added UniqueColumns and constraints list in model Meta to build unique constraints on list of columns. +* Added UUID field type based on Char(32) column type. + +# 0.3.5 + +* Added bulk_create and bulk_update for operations on multiple objects. + +# 0.3.4 + +Add queryset level methods +* delete +* update +* get_or_create +* update_or_create + +# 0.3.3 + +* Add additional filters - startswith and endswith + +# 0.3.2 + +* Add choices parameter to all fields - limiting the accepted values to ones provided + +# 0.3.1 + +* Added exclude to filter where not conditions. +* Added tests for mysql and postgres with fixes for postgres. +* Rafactors and cleanup. + +# 0.3.0 + +* Added ManyToMany field and support for many to many relations \ No newline at end of file diff --git a/docs/testing.md b/docs/testing.md deleted file mode 100644 index e69de29..0000000 diff --git a/docs_src/fastapi/docs001.py b/docs_src/fastapi/docs001.py new file mode 100644 index 0000000..a1d13c5 --- /dev/null +++ b/docs_src/fastapi/docs001.py @@ -0,0 +1,77 @@ +from typing import List + +import databases +import sqlalchemy +from fastapi import FastAPI + +import ormar + +app = FastAPI() +metadata = sqlalchemy.MetaData() +database = databases.Database("sqlite:///test.db", force_rollback=True) +app.state.database = database + + +@app.on_event("startup") +async def startup() -> None: + database_ = app.state.database + if not database_.is_connected: + await database_.connect() + + +@app.on_event("shutdown") +async def shutdown() -> None: + database_ = app.state.database + if database_.is_connected: + await database_.disconnect() + + +class Category(ormar.Model): + class Meta: + tablename = "categories" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + + +class Item(ormar.Model): + class Meta: + tablename = "items" + metadata = metadata + database = database + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + category: ormar.ForeignKey(Category, nullable=True) + + +@app.get("/items/", response_model=List[Item]) +async def get_items(): + items = await Item.objects.select_related("category").all() + return items + + +@app.post("/items/", response_model=Item) +async def create_item(item: Item): + await item.save() + return item + + +@app.post("/categories/", response_model=Category) +async def create_category(category: Category): + await category.save() + return category + + +@app.put("/items/{item_id}") +async def get_item(item_id: int, item: Item): + item_db = await Item.objects.get(pk=item_id) + return await item_db.update(**item.dict()) + + +@app.delete("/items/{item_id}") +async def delete_item(item_id: int, item: Item): + item_db = await Item.objects.get(pk=item_id) + return {"deleted_rows": await item_db.delete()} diff --git a/docs_src/models/docs008.py b/docs_src/models/docs008.py new file mode 100644 index 0000000..9a3d063 --- /dev/null +++ b/docs_src/models/docs008.py @@ -0,0 +1,19 @@ +import databases +import sqlalchemy + +import ormar + +database = databases.Database("sqlite:///test.db", force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Child(ormar.Model): + class Meta: + tablename = "children" + metadata = metadata + database = database + + id: ormar.Integer(name='child_id', primary_key=True) + first_name: ormar.String(name='fname', max_length=100) + last_name: ormar.String(name='lname', max_length=100) + born_year: ormar.Integer(name='year_born', nullable=True) diff --git a/docs_src/models/docs009.py b/docs_src/models/docs009.py new file mode 100644 index 0000000..0204feb --- /dev/null +++ b/docs_src/models/docs009.py @@ -0,0 +1,9 @@ +class Album(ormar.Model): + class Meta: + tablename = "music_albums" + metadata = metadata + database = database + + id: ormar.Integer(name='album_id', primary_key=True) + name: ormar.String(name='album_name', max_length=100) + artist: ormar.ForeignKey(Artist, name='artist_id') diff --git a/docs_src/models/docs010.py b/docs_src/models/docs010.py new file mode 100644 index 0000000..57febef --- /dev/null +++ b/docs_src/models/docs010.py @@ -0,0 +1,18 @@ +class ArtistChildren(ormar.Model): + class Meta: + tablename = "children_x_artists" + metadata = metadata + database = database + + +class Artist(ormar.Model): + class Meta: + tablename = "artists" + metadata = metadata + database = database + + id: ormar.Integer(name='artist_id', primary_key=True) + first_name: ormar.String(name='fname', max_length=100) + last_name: ormar.String(name='lname', max_length=100) + born_year: ormar.Integer(name='year') + children: ormar.ManyToMany(Child, through=ArtistChildren) diff --git a/mkdocs.yml b/mkdocs.yml index 26a8ed4..a7b99e3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -9,6 +9,7 @@ nav: - Queries: queries.md - Use with Fastapi: fastapi.md - Contributing: contributing.md + - Release Notes: releases.md repo_name: collerek/ormar repo_url: https://github.com/collerek/ormar google_analytics: diff --git a/ormar/__init__.py b/ormar/__init__.py index 16db570..159fda5 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -28,7 +28,7 @@ class UndefinedType: # pragma no cover Undefined = UndefinedType() -__version__ = "0.3.7" +__version__ = "0.3.8" __all__ = [ "Integer", "BigInteger", diff --git a/ormar/fields/base.py b/ormar/fields/base.py index d3f5e6c..88e1313 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -64,7 +64,7 @@ class BaseField: @classmethod def get_column(cls, name: str) -> sqlalchemy.Column: return sqlalchemy.Column( - name, + cls.name or name, cls.column_type, *cls.constraints, primary_key=cls.primary_key, diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 957c9e5..2959f32 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -38,7 +38,7 @@ def ForeignKey( # noqa CFQ002 onupdate: str = None, ondelete: str = None, ) -> Type["ForeignKeyField"]: - fk_string = to.Meta.tablename + "." + to.Meta.pkname + fk_string = to.Meta.tablename + "." + to.get_column_alias(to.Meta.pkname) to_field = to.__fields__[to.Meta.pkname] namespace = dict( to=to, diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index 587d92f..5462865 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -12,7 +12,7 @@ from ormar.fields.base import BaseField # noqa I101 def is_field_nullable( - nullable: Optional[bool], default: Any, server_default: Any + nullable: Optional[bool], default: Any, server_default: Any ) -> bool: if nullable is None: return default is not None or server_default is not None @@ -61,15 +61,15 @@ class String(ModelFieldFactory): _type = str def __new__( # type: ignore # noqa CFQ002 - cls, - *, - allow_blank: bool = True, - strip_whitespace: bool = False, - min_length: int = None, - max_length: int = None, - curtail_length: int = None, - regex: str = None, - **kwargs: Any + cls, + *, + allow_blank: bool = True, + strip_whitespace: bool = False, + min_length: int = None, + max_length: int = None, + curtail_length: int = None, + regex: str = None, + **kwargs: Any ) -> Type[BaseField]: # type: ignore kwargs = { **kwargs, @@ -79,7 +79,7 @@ class String(ModelFieldFactory): if k not in ["cls", "__class__", "kwargs"] }, } - kwargs['allow_blank'] = kwargs.get('nullable', True) + kwargs["allow_blank"] = kwargs.get("nullable", True) return super().__new__(cls, **kwargs) @classmethod @@ -100,12 +100,12 @@ class Integer(ModelFieldFactory): _type = int def __new__( # type: ignore - cls, - *, - minimum: int = None, - maximum: int = None, - multiple_of: int = None, - **kwargs: Any + cls, + *, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + **kwargs: Any ) -> Type[BaseField]: autoincrement = kwargs.pop("autoincrement", None) autoincrement = ( @@ -135,7 +135,7 @@ class Text(ModelFieldFactory): _type = str def __new__( # type: ignore - cls, *, allow_blank: bool = True, strip_whitespace: bool = False, **kwargs: Any + cls, *, allow_blank: bool = True, strip_whitespace: bool = False, **kwargs: Any ) -> Type[BaseField]: kwargs = { **kwargs, @@ -145,7 +145,7 @@ class Text(ModelFieldFactory): if k not in ["cls", "__class__", "kwargs"] }, } - kwargs['allow_blank'] = kwargs.get('nullable', True) + kwargs["allow_blank"] = kwargs.get("nullable", True) return super().__new__(cls, **kwargs) @classmethod @@ -158,12 +158,12 @@ class Float(ModelFieldFactory): _type = float def __new__( # type: ignore - cls, - *, - minimum: float = None, - maximum: float = None, - multiple_of: int = None, - **kwargs: Any + cls, + *, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + **kwargs: Any ) -> Type[BaseField]: kwargs = { **kwargs, @@ -232,12 +232,12 @@ class BigInteger(Integer): _type = int def __new__( # type: ignore - cls, - *, - minimum: int = None, - maximum: int = None, - multiple_of: int = None, - **kwargs: Any + cls, + *, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + **kwargs: Any ) -> Type[BaseField]: autoincrement = kwargs.pop("autoincrement", None) autoincrement = ( @@ -267,16 +267,16 @@ class Decimal(ModelFieldFactory): _type = decimal.Decimal def __new__( # type: ignore # noqa CFQ002 - cls, - *, - minimum: float = None, - maximum: float = None, - multiple_of: int = None, - precision: int = None, - scale: int = None, - max_digits: int = None, - decimal_places: int = None, - **kwargs: Any + cls, + *, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + precision: int = None, + scale: int = None, + max_digits: int = None, + decimal_places: int = None, + **kwargs: Any ) -> Type[BaseField]: kwargs = { **kwargs, diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 0bd711d..d7c57b9 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -41,7 +41,7 @@ def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> def register_many_to_many_relation_on_build( - table_name: str, field: Type[ManyToManyField] + table_name: str, field: Type[ManyToManyField] ) -> None: alias_manager.add_relation_type(field.through.Meta.tablename, table_name) alias_manager.add_relation_type( @@ -50,11 +50,11 @@ def register_many_to_many_relation_on_build( def reverse_field_not_already_registered( - child: Type["Model"], child_model_name: str, parent_model: Type["Model"] + child: Type["Model"], child_model_name: str, parent_model: Type["Model"] ) -> bool: return ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ) @@ -65,7 +65,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if reverse_field_not_already_registered( - child, child_model_name, parent_model + child, child_model_name, parent_model ): register_reverse_model_fields( parent_model, child, child_model_name, model_field @@ -73,10 +73,10 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: def register_reverse_model_fields( - model: Type["Model"], - child: Type["Model"], - child_model_name: str, - model_field: Type["ForeignKeyField"], + model: Type["Model"], + child: Type["Model"], + child_model_name: str, + model_field: Type["ForeignKeyField"], ) -> None: if issubclass(model_field, ManyToManyField): model.Meta.model_fields[child_model_name] = ManyToMany( @@ -91,7 +91,7 @@ def register_reverse_model_fields( def adjust_through_many_to_many_model( - model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( model, name=model.get_name(), ondelete="CASCADE" @@ -108,7 +108,7 @@ def adjust_through_many_to_many_model( def create_pydantic_field( - field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] + field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.__fields__[field_name] = ModelField( name=field_name, @@ -120,13 +120,13 @@ def create_pydantic_field( def create_and_append_m2m_fk( - model: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: column = sqlalchemy.Column( model.get_name(), - model.Meta.table.columns.get(model.Meta.pkname).type, + model.Meta.table.columns.get(model.get_column_alias(model.Meta.pkname)).type, sqlalchemy.schema.ForeignKey( - model.Meta.tablename + "." + model.Meta.pkname, + model.Meta.tablename + "." + model.get_column_alias(model.Meta.pkname), ondelete="CASCADE", onupdate="CASCADE", ), @@ -136,7 +136,7 @@ def create_and_append_m2m_fk( def check_pk_column_validity( - field_name: str, field: BaseField, pkname: Optional[str] + field_name: str, field: BaseField, pkname: Optional[str] ) -> Optional[str]: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") @@ -146,7 +146,7 @@ def check_pk_column_validity( def sqlalchemy_columns_from_model_fields( - model_fields: Dict, table_name: str + model_fields: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column]]: columns = [] pkname = None @@ -160,9 +160,9 @@ def sqlalchemy_columns_from_model_fields( if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) if ( - not field.pydantic_only - and not field.virtual - and not issubclass(field, ManyToManyField) + not field.pydantic_only + and not field.virtual + and not issubclass(field, ManyToManyField) ): columns.append(field.get_column(field_name)) register_relation_in_alias_manager(table_name, field) @@ -170,7 +170,7 @@ def sqlalchemy_columns_from_model_fields( def register_relation_in_alias_manager( - table_name: str, field: Type[ForeignKeyField] + table_name: str, field: Type[ForeignKeyField] ) -> None: if issubclass(field, ManyToManyField): register_many_to_many_relation_on_build(table_name, field) @@ -179,7 +179,7 @@ def register_relation_in_alias_manager( def populate_default_pydantic_field_value( - type_: Type[BaseField], field: str, attrs: dict + type_: Type[BaseField], field: str, attrs: dict ) -> dict: def_value = type_.default_value() curr_def_value = attrs.get(field, "NONE") @@ -208,7 +208,7 @@ def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict: def populate_meta_orm_model_fields( - attrs: dict, new_model: Type["Model"] + attrs: dict, new_model: Type["Model"] ) -> Type["Model"]: model_fields = { field_name: field @@ -220,10 +220,12 @@ def populate_meta_orm_model_fields( def populate_meta_tablename_columns_and_pk( - name: str, new_model: Type["Model"] + name: str, new_model: Type["Model"] ) -> Type["Model"]: tablename = name.lower() + "s" - new_model.Meta.tablename = new_model.Meta.tablename if hasattr(new_model.Meta, 'tablename') else tablename + new_model.Meta.tablename = ( + new_model.Meta.tablename if hasattr(new_model.Meta, "tablename") else tablename + ) pkname: Optional[str] if hasattr(new_model.Meta, "columns"): @@ -244,7 +246,7 @@ def populate_meta_tablename_columns_and_pk( def populate_meta_sqlalchemy_table_if_required( - new_model: Type["Model"], + new_model: Type["Model"], ) -> Type["Model"]: if not hasattr(new_model.Meta, "table"): new_model.Meta.table = sqlalchemy.Table( @@ -286,7 +288,7 @@ def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, A def populate_choices_validators( # noqa CCR001 - model: Type["Model"], attrs: Dict + model: Type["Model"], attrs: Dict ) -> None: if model_initialized_and_has_model_fields(model): for _, field in model.Meta.model_fields.items(): @@ -299,7 +301,7 @@ def populate_choices_validators( # noqa CCR001 class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__( # type: ignore - mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict + mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict ) -> "ModelMetaclass": attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name diff --git a/ormar/models/model.py b/ormar/models/model.py index fd71efe..3805a00 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -43,7 +43,6 @@ class Model(NewBaseModel): if select_related: related_models = group_related_list(select_related) - # breakpoint() if ( previous_table and previous_table in cls.Meta.model_fields @@ -90,13 +89,13 @@ class Model(NewBaseModel): previous_table=previous_table, fields=fields, ) - item[first_part] = child + item[model_cls.get_column_name_from_alias(first_part)] = child else: model_cls = cls.Meta.model_fields[related].to child = model_cls.from_row( row, previous_table=previous_table, fields=fields ) - item[related] = child + item[model_cls.get_column_name_from_alias(related)] = child return item @@ -113,13 +112,16 @@ class Model(NewBaseModel): # databases does not keep aliases in Record for postgres, change to raw row source = row._row if isinstance(row, Record) else row - selected_columns = cls.own_table_columns(cls, fields or [], nested=nested) + selected_columns = cls.own_table_columns( + cls, fields or [], nested=nested, use_alias=True + ) for column in cls.Meta.table.columns: - if column.name not in item and column.name in selected_columns: + alias = cls.get_column_name_from_alias(column.name) + if alias not in item and alias in selected_columns: prefixed_name = ( f'{table_prefix + "_" if table_prefix else ""}{column.name}' ) - item[column.name] = source[prefixed_name] + item[alias] = source[prefixed_name] return item @@ -142,7 +144,8 @@ class Model(NewBaseModel): self.from_dict(new_values) self_fields = self._extract_model_db_fields() - self_fields.pop(self.Meta.pkname) + self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname)) + self_fields = self.objects._translate_columns_to_aliases(self_fields) expr = self.Meta.table.update().values(**self_fields) expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname)) @@ -162,5 +165,7 @@ class Model(NewBaseModel): raise ValueError( "Instance was deleted from database and cannot be refreshed" ) - self.from_dict(dict(row)) + kwargs = dict(row) + kwargs = self.objects._translate_aliases_to_columns(kwargs) + self.from_dict(kwargs) return self diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 0dadbb2..760232d 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -3,7 +3,7 @@ from typing import Dict, List, Set, TYPE_CHECKING, Type, TypeVar, Union import ormar from ormar.exceptions import RelationshipInstanceError -from ormar.fields import BaseField +from ormar.fields import BaseField, ManyToManyField from ormar.fields.foreign_key import ForeignKeyField from ormar.models.metaclass import ModelMeta @@ -35,7 +35,7 @@ class ModelTableProxy: return self_fields @classmethod - def substitute_models_with_pks(cls, model_dict: Dict) -> Dict: + def substitute_models_with_pks(cls, model_dict: Dict) -> Dict: # noqa CCR001 for field in cls.extract_related_names(): field_value = model_dict.get(field, None) if field_value is not None: @@ -43,10 +43,26 @@ class ModelTableProxy: target_pkname = target_field.to.Meta.pkname if isinstance(field_value, ormar.Model): model_dict[field] = getattr(field_value, target_pkname) - else: + elif field_value: model_dict[field] = field_value.get(target_pkname) + else: + model_dict.pop(field, None) return model_dict + @classmethod + def get_column_alias(cls, field_name: str) -> str: + field = cls.Meta.model_fields.get(field_name) + if field and field.name is not None and field.name != field_name: + return field.name + return field_name + + @classmethod + def get_column_name_from_alias(cls, alias: str) -> str: + for field_name, field in cls.Meta.model_fields.items(): + if field and field.name == alias: + return field_name + return alias # if not found it's not an alias but actual name + @classmethod def extract_related_names(cls) -> Set: related_names = set() @@ -62,6 +78,7 @@ class ModelTableProxy: if ( inspect.isclass(field) and issubclass(field, ForeignKeyField) + and not issubclass(field, ManyToManyField) and not field.virtual ): related_names.add(name) @@ -84,7 +101,9 @@ class ModelTableProxy: def _extract_model_db_fields(self) -> Dict: self_fields = self._extract_own_model_fields() self_fields = { - k: v for k, v in self_fields.items() if k in self.Meta.table.columns + k: v + for k, v in self_fields.items() + if self.get_column_alias(k) in self.Meta.table.columns } for field in self._extract_db_related_names(): target_pk_name = self.Meta.model_fields[field].to.Meta.pkname @@ -125,7 +144,7 @@ class ModelTableProxy: def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: merged_rows: List["Model"] = [] for index, model in enumerate(result_rows): - if index > 0 and model.pk == merged_rows[-1].pk: + if index > 0 and model is not None and model.pk == merged_rows[-1].pk: merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) else: merged_rows.append(model) @@ -151,30 +170,62 @@ class ModelTableProxy: return other @staticmethod - def own_table_columns( - model: Type["Model"], fields: List, nested: bool = False + def _get_not_nested_columns_from_fields( + model: Type["Model"], + fields: List, + column_names: List[str], + use_alias: bool = False, ) -> List[str]: - column_names = [col.name for col in model.Meta.table.columns] + fields = [model.get_column_alias(k) if not use_alias else k for k in fields] + columns = [name for name in fields if "__" not in name and name in column_names] + return columns + + @staticmethod + def _get_nested_columns_from_fields( + model: Type["Model"], fields: List, use_alias: bool = False, + ) -> List[str]: + model_name = f"{model.get_name()}__" + columns = [ + name[(name.find(model_name) + len(model_name)) :] # noqa: E203 + for name in fields + if f"{model.get_name()}__" in name + ] + columns = [model.get_column_alias(k) if not use_alias else k for k in columns] + return columns + + @staticmethod + def own_table_columns( + model: Type["Model"], + fields: List, + nested: bool = False, + use_alias: bool = False, + ) -> List[str]: + column_names = [ + model.get_column_name_from_alias(col.name) if use_alias else col.name + for col in model.Meta.table.columns + ] if not fields: return column_names if not nested: - columns = [ - name for name in fields if "__" not in name and name in column_names - ] + columns = ModelTableProxy._get_not_nested_columns_from_fields( + model, fields, column_names, use_alias + ) else: - model_name = f"{model.get_name()}__" - columns = [ - name[(name.find(model_name) + len(model_name)) :] # noqa: E203 - for name in fields - if f"{model.get_name()}__" in name - ] + columns = ModelTableProxy._get_nested_columns_from_fields( + model, fields, use_alias + ) # if the model is in select and no columns in fields, all implied if not columns: columns = column_names # always has to return pk column - if model.Meta.pkname not in columns: - columns.append(model.Meta.pkname) + pk_alias = ( + model.get_column_alias(model.Meta.pkname) + if not use_alias + else model.Meta.pkname + ) + if pk_alias not in columns: + columns.append(pk_alias) return columns diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 140d118..88664fb 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -134,8 +134,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass def _extract_related_model_instead_of_field( self, item: str ) -> Optional[Union["Model", List["Model"]]]: - if item in self._orm: - return self._orm.get(item) + alias = self.get_column_alias(item) + if alias in self._orm: + return self._orm.get(alias) return None def __eq__(self, other: object) -> bool: diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index 6b9d5a4..362ba85 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -44,7 +44,7 @@ class QueryClause: ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: if kwargs.get("pk"): - pk_name = self.model_cls.Meta.pkname + pk_name = self.model_cls.get_column_alias(self.model_cls.Meta.pkname) kwargs[pk_name] = kwargs.pop("pk") filter_clauses, select_related = self._populate_filter_clauses(**kwargs) @@ -83,7 +83,7 @@ class QueryClause: else: op = "exact" - column = self.table.columns[key] + column = self.table.columns[self.model_cls.get_column_alias(key)] table = self.table clause = self._process_column_clause_for_operator_and_value( diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index f79f353..c045e16 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -106,9 +106,11 @@ class SqlJoin: self.select_from = sqlalchemy.sql.outerjoin( self.select_from, target_table, on_clause ) - self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}")) + + pkname_alias = model_cls.get_column_alias(model_cls.Meta.pkname) + self.order_bys.append(text(f"{alias}_{to_table}.{pkname_alias}")) self_related_fields = model_cls.own_table_columns( - model_cls, self.fields, nested=True + model_cls, self.fields, nested=True, ) self.columns.extend( self.relation_manager(model_cls).prefixed_columns( @@ -125,12 +127,13 @@ class SqlJoin: part: str, ) -> Tuple[str, str]: if join_params.prev_model.Meta.model_fields[part].virtual or is_multi: - to_field = model_cls.resolve_relation_field( + to_field = model_cls.resolve_relation_name( model_cls, join_params.prev_model ) - to_key = to_field.name - from_key = model_cls.Meta.pkname + to_key = model_cls.get_column_alias(to_field) + from_key = join_params.prev_model.get_column_alias(model_cls.Meta.pkname) else: - to_key = model_cls.Meta.pkname - from_key = part + to_key = model_cls.get_column_alias(model_cls.Meta.pkname) + from_key = join_params.prev_model.get_column_alias(part) + return to_key, from_key diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index ada3437..880b3c1 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -40,7 +40,8 @@ class Query: @property def prefixed_pk_name(self) -> str: - return f"{self.table.name}.{self.model_cls.Meta.pkname}" + pkname_alias = self.model_cls.get_column_alias(self.model_cls.Meta.pkname) + return f"{self.table.name}.{pkname_alias}" def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: self_related_fields = self.model_cls.own_table_columns( diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index adece00..df1bc01 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -70,12 +70,35 @@ class QuerySet: return self.model.merge_instances_list(result_rows) # type: ignore return result_rows + def _prepare_model_to_save(self, new_kwargs: dict) -> dict: + new_kwargs = self._remove_pk_from_kwargs(new_kwargs) + new_kwargs = self.model.substitute_models_with_pks(new_kwargs) + new_kwargs = self._populate_default_values(new_kwargs) + new_kwargs = self._translate_columns_to_aliases(new_kwargs) + return new_kwargs + def _populate_default_values(self, new_kwargs: dict) -> dict: for field_name, field in self.model_meta.model_fields.items(): if field_name not in new_kwargs and field.has_default(): new_kwargs[field_name] = field.get_default() return new_kwargs + def _translate_columns_to_aliases(self, new_kwargs: dict) -> dict: + for field_name, field in self.model_meta.model_fields.items(): + if ( + field_name in new_kwargs + and field.name is not None + and field.name != field_name + ): + new_kwargs[field.name] = new_kwargs.pop(field_name) + return new_kwargs + + def _translate_aliases_to_columns(self, new_kwargs: dict) -> dict: + for field_name, field in self.model_meta.model_fields.items(): + if field.name in new_kwargs and field.name != field_name: + new_kwargs[field_name] = new_kwargs.pop(field.name) + return new_kwargs + def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict: pkname = self.model_meta.pkname pk = self.model_meta.model_fields[pkname] @@ -184,6 +207,7 @@ class QuerySet: async def update(self, each: bool = False, **kwargs: Any) -> int: self_fields = self.model.extract_db_own_fields() updates = {k: v for k, v in kwargs.items() if k in self_fields} + updates = self._translate_columns_to_aliases(updates) if not each and not self.filter_clauses: raise QueryDefinitionError( "You cannot update without filtering the queryset first. " @@ -278,9 +302,7 @@ class QuerySet: async def create(self, **kwargs: Any) -> "Model": new_kwargs = dict(**kwargs) - new_kwargs = self._remove_pk_from_kwargs(new_kwargs) - new_kwargs = self.model.substitute_models_with_pks(new_kwargs) - new_kwargs = self._populate_default_values(new_kwargs) + new_kwargs = self._prepare_model_to_save(new_kwargs) expr = self.table.insert() expr = expr.values(**new_kwargs) @@ -288,7 +310,7 @@ class QuerySet: instance = self.model(**kwargs) pk = await self.database.execute(expr) - pk_name = self.model_meta.pkname + pk_name = self.model.get_column_alias(self.model_meta.pkname) if pk_name not in kwargs and pk_name in new_kwargs: instance.pk = new_kwargs[self.model_meta.pkname] if pk and isinstance(pk, self.model.pk_type()): @@ -300,9 +322,7 @@ class QuerySet: ready_objects = [] for objt in objects: new_kwargs = objt.dict() - new_kwargs = self._remove_pk_from_kwargs(new_kwargs) - new_kwargs = self.model.substitute_models_with_pks(new_kwargs) - new_kwargs = self._populate_default_values(new_kwargs) + new_kwargs = self._prepare_model_to_save(new_kwargs) ready_objects.append(new_kwargs) expr = self.table.insert() @@ -323,6 +343,8 @@ class QuerySet: if pk_name not in columns: columns.append(pk_name) + columns = [self.model.get_column_alias(k) for k in columns] + for objt in objects: new_kwargs = objt.dict() if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None: @@ -331,15 +353,24 @@ class QuerySet: f"{self.model.__name__} has to have {pk_name} filled." ) new_kwargs = self.model.substitute_models_with_pks(new_kwargs) + new_kwargs = self._translate_columns_to_aliases(new_kwargs) new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns} ready_objects.append(new_kwargs) - pk_column = self.model_meta.table.c.get(pk_name) - expr = self.table.update().where(pk_column == bindparam("new_" + pk_name)) + pk_column = self.model_meta.table.c.get(self.model.get_column_alias(pk_name)) + pk_column_name = self.model.get_column_alias(pk_name) + table_columns = [c.name for c in self.model_meta.table.c] + expr = self.table.update().where( + pk_column == bindparam("new_" + pk_column_name) + ) expr = expr.values( - **{k: bindparam("new_" + k) for k in columns if k != pk_name} + **{ + k: bindparam("new_" + k) + for k in columns + if k != pk_column_name and k in table_columns + } ) # databases bind params only where query is passed as string - # otherwise it just pases all data to values and results in unconsumed columns + # otherwise it just passes all data to values and results in unconsumed columns expr = str(expr) await self.database.execute_many(expr, ready_objects) diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index 3863679..88130d5 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -39,7 +39,7 @@ class RelationProxy(list): def _set_queryset(self) -> "QuerySet": owner_table = self.relation._owner.Meta.tablename - pkname = self.relation._owner.Meta.pkname + pkname = self.relation._owner.get_column_alias(self.relation._owner.Meta.pkname) pk_value = self.relation._owner.pk if not pk_value: raise RelationshipInstanceError( diff --git a/tests/test_aliases.py b/tests/test_aliases.py new file mode 100644 index 0000000..f169c30 --- /dev/null +++ b/tests/test_aliases.py @@ -0,0 +1,157 @@ +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Child(ormar.Model): + class Meta: + tablename = "children" + metadata = metadata + database = database + + id: ormar.Integer(name='child_id', primary_key=True) + first_name: ormar.String(name='fname', max_length=100) + last_name: ormar.String(name='lname', max_length=100) + born_year: ormar.Integer(name='year_born', nullable=True) + + +class ArtistChildren(ormar.Model): + class Meta: + tablename = "children_x_artists" + metadata = metadata + database = database + + +class Artist(ormar.Model): + class Meta: + tablename = "artists" + metadata = metadata + database = database + + id: ormar.Integer(name='artist_id', primary_key=True) + first_name: ormar.String(name='fname', max_length=100) + last_name: ormar.String(name='lname', max_length=100) + born_year: ormar.Integer(name='year') + children: ormar.ManyToMany(Child, through=ArtistChildren) + + +class Album(ormar.Model): + class Meta: + tablename = "music_albums" + metadata = metadata + database = database + + id: ormar.Integer(name='album_id', primary_key=True) + name: ormar.String(name='album_name', max_length=100) + artist: ormar.ForeignKey(Artist, name='artist_id') + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_table_structure(): + assert 'album_id' in [x.name for x in Album.Meta.table.columns] + assert 'album_name' in [x.name for x in Album.Meta.table.columns] + assert 'fname' in [x.name for x in Artist.Meta.table.columns] + assert 'lname' in [x.name for x in Artist.Meta.table.columns] + assert 'year' in [x.name for x in Artist.Meta.table.columns] + + +@pytest.mark.asyncio +async def test_working_with_aliases(): + async with database: + async with database.transaction(force_rollback=True): + artist = await Artist.objects.create(first_name='Ted', last_name='Mosbey', born_year=1975) + await Album.objects.create(name="Aunt Robin", artist=artist) + + await artist.children.create(first_name='Son', last_name='1', born_year=1990) + await artist.children.create(first_name='Son', last_name='2', born_year=1995) + + album = await Album.objects.select_related('artist').first() + assert album.artist.last_name == 'Mosbey' + + assert album.artist.id is not None + assert album.artist.first_name == 'Ted' + assert album.artist.born_year == 1975 + + assert album.name == 'Aunt Robin' + + artist = await Artist.objects.select_related('children').get() + assert len(artist.children) == 2 + assert artist.children[0].first_name == 'Son' + assert artist.children[1].last_name == '2' + + await artist.update(last_name='Bundy') + await Artist.objects.filter(pk=artist.pk).update(born_year=1974) + + artist = await Artist.objects.select_related('children').get() + assert artist.last_name == 'Bundy' + assert artist.born_year == 1974 + + artist = await Artist.objects.select_related('children').fields( + ['first_name', 'last_name', 'born_year', 'child__first_name', 'child__last_name']).get() + assert artist.children[0].born_year is None + + +@pytest.mark.asyncio +async def test_bulk_operations_and_fields(): + async with database: + d1 = Child(first_name='Daughter', last_name='1', born_year=1990) + d2 = Child(first_name='Daughter', last_name='2', born_year=1991) + await Child.objects.bulk_create([d1, d2]) + + children = await Child.objects.filter(first_name='Daughter').all() + assert len(children) == 2 + assert children[0].last_name == '1' + + for child in children: + child.born_year = child.born_year - 100 + + await Child.objects.bulk_update(children) + + children = await Child.objects.filter(first_name='Daughter').all() + assert len(children) == 2 + assert children[0].born_year == 1890 + + children = await Child.objects.fields(['first_name', 'last_name']).all() + assert len(children) == 2 + for child in children: + assert child.born_year is None + + await children[0].load() + await children[0].delete() + children = await Child.objects.all() + assert len(children) == 1 + + +@pytest.mark.asyncio +async def test_working_with_aliases_get_or_create(): + async with database: + async with database.transaction(force_rollback=True): + artist = await Artist.objects.get_or_create(first_name='Teddy', last_name='Bear', born_year=2020) + assert artist.pk is not None + + artist2 = await Artist.objects.get_or_create(first_name='Teddy', last_name='Bear', born_year=2020) + assert artist == artist2 + + art3 = artist2.dict() + art3['born_year'] = 2019 + await Artist.objects.update_or_create(**art3) + + artist3 = await Artist.objects.get(last_name='Bear') + assert artist3.born_year == 2019 + + artists = await Artist.objects.all() + assert len(artists) == 1