Merge pull request #17 from collerek/aliases

Provide name parameter for specifying alternative name for column in the database
This commit is contained in:
collerek
2020-10-22 19:11:38 +07:00
committed by GitHub
27 changed files with 713 additions and 223 deletions

BIN
.coverage

Binary file not shown.

View File

@ -97,7 +97,9 @@ await Track.objects.create(album=malibu, title="The Bird", position=1)
await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2)
await Track.objects.create(album=malibu, title="The Waters", position=3)
fantasies = await Album.objects.create(name="Fantasies")
# alternative creation of object divided into 2 steps
fantasies = Album.objects.create(name="Fantasies")
await fantasies.save()
await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1)
await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
@ -137,12 +139,33 @@ tracks = await Track.objects.limit(1).all()
assert len(tracks) == 1
```
## Data types
## Ormar Specification
### QuerySet methods
* `create(**kwargs): -> Model`
* `get(**kwargs): -> Model`
* `get_or_create(**kwargs) -> Model`
* `update(each: bool = False, **kwargs) -> int`
* `update_or_create(**kwargs) -> Model`
* `bulk_create(objects: List[Model]) -> None`
* `bulk_update(objects: List[Model], columns: List[str] = None) -> None`
* `delete(each: bool = False, **kwargs) -> int`
* `all(self, **kwargs) -> List[Optional[Model]]`
* `filter(**kwargs) -> QuerySet`
* `exclude(**kwargs) -> QuerySet`
* `select_related(related: Union[List, str]) -> QuerySet`
* `limit(limit_count: int) -> QuerySet`
* `offset(offset: int) -> QuerySet`
* `count() -> int`
* `exists() -> bool`
* `fields(columns: Union[List, str]) -> QuerySet`
#### Relation types
* One to many - with `ForeignKey`
* Many to many - with `Many2Many`
* One to many - with `ForeignKey(to: Model)`
* Many to many - with `ManyToMany(to: Model, through: Model)`
#### Model fields types
@ -161,7 +184,7 @@ Available Model Fields (with required args - optional ones in docs):
* `Decimal(scale, precision)`
* `UUID()`
* `ForeignKey(to)`
* `Many2Many(to, through)`
* `ManyToMany(to, through)`
### Available fields options
The following keyword arguments are supported on all field types.

View File

@ -6,97 +6,66 @@ you need to do is substitute pydantic models with ormar models.
Here you can find a very simple sample application code.
## Imports and initialization
First take care of the imports and initialization
```python hl_lines="1-12"
--8<-- "../docs_src/fastapi/docs001.py"
```
## Database connection
Next define startup and shutdown events (or use middleware)
- note that this is `databases` specific setting not the ormar one
```python hl_lines="15-26"
--8<-- "../docs_src/fastapi/docs001.py"
```
!!!info
You can read more on connecting to databases in [fastapi][fastapi] documentation
## Models definition
Define ormar models with appropriate fields.
Those models will be used insted of pydantic ones.
```python hl_lines="29-47"
--8<-- "../docs_src/fastapi/docs001.py"
```
!!!tip
You can read more on defining `Models` in [models][models] section.
## Fastapi endpoints definition
Define your desired endpoints, note how `ormar` models are used both
as `response_model` and as a requests parameters.
```python hl_lines="50-77"
--8<-- "../docs_src/fastapi/docs001.py"
```
!!!note
Note how ormar `Model` methods like save() are available straight out of the box after fastapi initializes it for you.
!!!note
Note that you can return a `Model` (or list of `Models`) directly - fastapi will jsonize it for you
## Test the application
Here you have a sample test that will prove that everything works as intended.
```python
from typing import List
import databases
import pytest
import sqlalchemy
from fastapi import FastAPI
from starlette.testclient import TestClient
import ormar
from tests.settings import DATABASE_URL
app = FastAPI()
metadata = sqlalchemy.MetaData()
database = databases.Database(DATABASE_URL, force_rollback=True)
app.state.database = database
# define startup and shutdown events
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
# define ormar models
class Category(ormar.Model):
class Meta:
tablename = "categories"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
class Item(ormar.Model):
class Meta:
tablename = "items"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
category: ormar.ForeignKey(Category, nullable=True)
# define endpoints in fastapi
@app.get("/items/", response_model=List[Item])
async def get_items():
items = await Item.objects.select_related("category").all()
# not that you can return a model directly - fastapi will json-ize it
return items
@app.post("/items/", response_model=Item)
async def create_item(item: Item):
# note how ormar methods like save() are available streight out of the box
await item.save()
return item
@app.post("/categories/", response_model=Category)
async def create_category(category: Category):
await category.save()
return category
@app.put("/items/{item_id}")
async def get_item(item_id: int, item: Item):
# you can work both with item_id or item
item_db = await Item.objects.get(pk=item_id)
return await item_db.update(**item.dict())
@app.delete("/items/{item_id}")
async def delete_item(item_id: int, item: Item):
item_db = await Item.objects.get(pk=item_id)
return {"deleted_rows": await item_db.delete()}
# here is a sample test to check the working of the ormar with fastapi
from starlette.testclient import TestClient
def test_all_endpoints():
# note that TestClient is only sync, don't use asyns here
client = TestClient(app)
# note that you need to connect to database manually
# or use client as contextmanager
# or use client as contextmanager during tests
with client as client:
response = client.post("/categories/", json={"name": "test cat"})
category = response.json()
@ -124,3 +93,9 @@ def test_all_endpoints():
items = response.json()
assert len(items) == 0
```
!!!info
You can read more on testing fastapi in [fastapi][fastapi] docs.
[fastapi]: https://fastapi.tiangolo.com/
[models]: ./models.md

View File

@ -97,7 +97,9 @@ await Track.objects.create(album=malibu, title="The Bird", position=1)
await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2)
await Track.objects.create(album=malibu, title="The Waters", position=3)
fantasies = await Album.objects.create(name="Fantasies")
# alternative creation of object divided into 2 steps
fantasies = Album.objects.create(name="Fantasies")
await fantasies.save()
await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1)
await Track.objects.create(album=fantasies, title="Sick Muse", position=2)
@ -137,12 +139,33 @@ tracks = await Track.objects.limit(1).all()
assert len(tracks) == 1
```
## Data types
## Ormar Specification
### QuerySet methods
* `create(**kwargs): -> Model`
* `get(**kwargs): -> Model`
* `get_or_create(**kwargs) -> Model`
* `update(each: bool = False, **kwargs) -> int`
* `update_or_create(**kwargs) -> Model`
* `bulk_create(objects: List[Model]) -> None`
* `bulk_update(objects: List[Model], columns: List[str] = None) -> None`
* `delete(each: bool = False, **kwargs) -> int`
* `all(self, **kwargs) -> List[Optional[Model]]`
* `filter(**kwargs) -> QuerySet`
* `exclude(**kwargs) -> QuerySet`
* `select_related(related: Union[List, str]) -> QuerySet`
* `limit(limit_count: int) -> QuerySet`
* `offset(offset: int) -> QuerySet`
* `count() -> int`
* `exists() -> bool`
* `fields(columns: Union[List, str]) -> QuerySet`
#### Relation types
* One to many - with `ForeignKey`
* Many to many - with `Many2Many`
* One to many - with `ForeignKey(to: Model)`
* Many to many - with `ManyToMany(to: Model, through: Model)`
#### Model fields types
@ -161,7 +184,7 @@ Available Model Fields (with required args - optional ones in docs):
* `Decimal(scale, precision)`
* `UUID()`
* `ForeignKey(to)`
* `Many2Many(to, through)`
* `ManyToMany(to, through)`
### Available fields options
The following keyword arguments are supported on all field types.
@ -173,6 +196,7 @@ The following keyword arguments are supported on all field types.
* `index: bool`
* `unique: bool`
* `choices: typing.Sequence`
* `name: str`
All fields are required unless one of the following is set:

View File

@ -37,7 +37,27 @@ You can disable by passing `autoincremant=False`.
id: ormar.Integer(primary_key=True, autoincrement=False)
```
Names of the fields will be used for both the underlying `pydantic` model and `sqlalchemy` table.
### Fields names vs Column names
By default names of the fields will be used for both the underlying `pydantic` model and `sqlalchemy` table.
If for whatever reason you prefer to change the name in the database but keep the name in the model you can do this
with specifying `name` parameter during Field declaration
Here you have a sample model with changed names
```Python hl_lines="16-19"
--8<-- "../docs_src/models/docs008.py"
```
Note that you can also change the ForeignKey column name
```Python hl_lines="9"
--8<-- "../docs_src/models/docs009.py"
```
But for now you cannot change the ManyToMany column names as they go through other Model anyway.
```Python hl_lines="18"
--8<-- "../docs_src/models/docs010.py"
```
### Dependencies
@ -128,7 +148,9 @@ Each model has a `QuerySet` initialised as `objects` parameter
### load
By default when you query a table without prefetching related models, the ormar will still construct
your related models, but populate them only with the pk value.
your related models, but populate them only with the pk value. You can load the related model by calling `load()` method.
`load()` can also be used to refresh the model from the database (if it was changed by some other process).
```python
track = await Track.objects.get(name='The Bird')
@ -142,10 +164,36 @@ track.album.name # will return 'Malibu'
### save
You can create new models by using `QuerySet.create()` method or by initializing your model as a normal pydantic model
and later calling `save()` method.
`save()` can also be used to persist changes that you made to the model.
```python
track = Track(name='The Bird')
await track.save() # will persist the model in database
```
### delete
You can delete models by using `QuerySet.delete()` method or by using your model and calling `delete()` method.
```python
track = await Track.objects.get(name='The Bird')
await track.delete() # will delete the model from database
```
!!!tip
Note that that `track` object stays the same, only record in the database is removed.
### update
You can delete models by using `QuerySet.update()` method or by using your model and calling `update()` method.
```python
track = await Track.objects.get(name='The Bird')
await track.update(name='The Bird Strikes Again')
```
## Internals

View File

@ -85,9 +85,9 @@ Finally you can explicitly set it to None (default behavior if no value passed).
Otherwise an IntegrityError will be raised by your database driver library.
### Many2Many
### ManyToMany
`Many2Many(to, through)` has required parameters `to` and `through` that takes target and relation `Model` classes.
`ManyToMany(to, through)` has required parameters `to` and `through` that takes target and relation `Model` classes.
Sqlalchemy column and Type are automatically taken from target `Model`.
@ -131,7 +131,7 @@ assert len(await post.categories.all()) == 2
```
!!!note
Note that when accessing QuerySet API methods through Many2Many relation you don't
Note that when accessing QuerySet API methods through ManyToMany relation you don't
need to use objects attribute like in normal queries.
To learn more about available QuerySet methods visit [queries][queries]
@ -146,7 +146,7 @@ await news.posts.clear()
#### All other queryset methods
When access directly the related `Many2Many` field returns the list of related models.
When access directly the related `ManyToMany` field returns the list of related models.
But at the same time it exposes full QuerySet API, so you can filter, create, select related etc.

45
docs/releases.md Normal file
View File

@ -0,0 +1,45 @@
# 0.3.8
* Added possibility to provide alternative database column names with name parameter to all fields.
* Fix bug with selecting related ManyToMany fields with `fields()` if they are empty.
* Updated documentation
# 0.3.7
* Publish documentation and update readme
# 0.3.6
* Add fields() method to limit the selected columns from database - only nullable columns can be excluded.
* Added UniqueColumns and constraints list in model Meta to build unique constraints on list of columns.
* Added UUID field type based on Char(32) column type.
# 0.3.5
* Added bulk_create and bulk_update for operations on multiple objects.
# 0.3.4
Add queryset level methods
* delete
* update
* get_or_create
* update_or_create
# 0.3.3
* Add additional filters - startswith and endswith
# 0.3.2
* Add choices parameter to all fields - limiting the accepted values to ones provided
# 0.3.1
* Added exclude to filter where not conditions.
* Added tests for mysql and postgres with fixes for postgres.
* Rafactors and cleanup.
# 0.3.0
* Added ManyToMany field and support for many to many relations

View File

View File

@ -0,0 +1,77 @@
from typing import List
import databases
import sqlalchemy
from fastapi import FastAPI
import ormar
app = FastAPI()
metadata = sqlalchemy.MetaData()
database = databases.Database("sqlite:///test.db", force_rollback=True)
app.state.database = database
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
class Category(ormar.Model):
class Meta:
tablename = "categories"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
class Item(ormar.Model):
class Meta:
tablename = "items"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
category: ormar.ForeignKey(Category, nullable=True)
@app.get("/items/", response_model=List[Item])
async def get_items():
items = await Item.objects.select_related("category").all()
return items
@app.post("/items/", response_model=Item)
async def create_item(item: Item):
await item.save()
return item
@app.post("/categories/", response_model=Category)
async def create_category(category: Category):
await category.save()
return category
@app.put("/items/{item_id}")
async def get_item(item_id: int, item: Item):
item_db = await Item.objects.get(pk=item_id)
return await item_db.update(**item.dict())
@app.delete("/items/{item_id}")
async def delete_item(item_id: int, item: Item):
item_db = await Item.objects.get(pk=item_id)
return {"deleted_rows": await item_db.delete()}

View File

@ -0,0 +1,19 @@
import databases
import sqlalchemy
import ormar
database = databases.Database("sqlite:///test.db", force_rollback=True)
metadata = sqlalchemy.MetaData()
class Child(ormar.Model):
class Meta:
tablename = "children"
metadata = metadata
database = database
id: ormar.Integer(name='child_id', primary_key=True)
first_name: ormar.String(name='fname', max_length=100)
last_name: ormar.String(name='lname', max_length=100)
born_year: ormar.Integer(name='year_born', nullable=True)

View File

@ -0,0 +1,9 @@
class Album(ormar.Model):
class Meta:
tablename = "music_albums"
metadata = metadata
database = database
id: ormar.Integer(name='album_id', primary_key=True)
name: ormar.String(name='album_name', max_length=100)
artist: ormar.ForeignKey(Artist, name='artist_id')

View File

@ -0,0 +1,18 @@
class ArtistChildren(ormar.Model):
class Meta:
tablename = "children_x_artists"
metadata = metadata
database = database
class Artist(ormar.Model):
class Meta:
tablename = "artists"
metadata = metadata
database = database
id: ormar.Integer(name='artist_id', primary_key=True)
first_name: ormar.String(name='fname', max_length=100)
last_name: ormar.String(name='lname', max_length=100)
born_year: ormar.Integer(name='year')
children: ormar.ManyToMany(Child, through=ArtistChildren)

View File

@ -9,6 +9,7 @@ nav:
- Queries: queries.md
- Use with Fastapi: fastapi.md
- Contributing: contributing.md
- Release Notes: releases.md
repo_name: collerek/ormar
repo_url: https://github.com/collerek/ormar
google_analytics:

View File

@ -28,7 +28,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType()
__version__ = "0.3.7"
__version__ = "0.3.8"
__all__ = [
"Integer",
"BigInteger",

View File

@ -64,7 +64,7 @@ class BaseField:
@classmethod
def get_column(cls, name: str) -> sqlalchemy.Column:
return sqlalchemy.Column(
name,
cls.name or name,
cls.column_type,
*cls.constraints,
primary_key=cls.primary_key,

View File

@ -38,7 +38,7 @@ def ForeignKey( # noqa CFQ002
onupdate: str = None,
ondelete: str = None,
) -> Type["ForeignKeyField"]:
fk_string = to.Meta.tablename + "." + to.Meta.pkname
fk_string = to.Meta.tablename + "." + to.get_column_alias(to.Meta.pkname)
to_field = to.__fields__[to.Meta.pkname]
namespace = dict(
to=to,

View File

@ -12,7 +12,7 @@ from ormar.fields.base import BaseField # noqa I101
def is_field_nullable(
nullable: Optional[bool], default: Any, server_default: Any
nullable: Optional[bool], default: Any, server_default: Any
) -> bool:
if nullable is None:
return default is not None or server_default is not None
@ -61,15 +61,15 @@ class String(ModelFieldFactory):
_type = str
def __new__( # type: ignore # noqa CFQ002
cls,
*,
allow_blank: bool = True,
strip_whitespace: bool = False,
min_length: int = None,
max_length: int = None,
curtail_length: int = None,
regex: str = None,
**kwargs: Any
cls,
*,
allow_blank: bool = True,
strip_whitespace: bool = False,
min_length: int = None,
max_length: int = None,
curtail_length: int = None,
regex: str = None,
**kwargs: Any
) -> Type[BaseField]: # type: ignore
kwargs = {
**kwargs,
@ -79,7 +79,7 @@ class String(ModelFieldFactory):
if k not in ["cls", "__class__", "kwargs"]
},
}
kwargs['allow_blank'] = kwargs.get('nullable', True)
kwargs["allow_blank"] = kwargs.get("nullable", True)
return super().__new__(cls, **kwargs)
@classmethod
@ -100,12 +100,12 @@ class Integer(ModelFieldFactory):
_type = int
def __new__( # type: ignore
cls,
*,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
**kwargs: Any
cls,
*,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
**kwargs: Any
) -> Type[BaseField]:
autoincrement = kwargs.pop("autoincrement", None)
autoincrement = (
@ -135,7 +135,7 @@ class Text(ModelFieldFactory):
_type = str
def __new__( # type: ignore
cls, *, allow_blank: bool = True, strip_whitespace: bool = False, **kwargs: Any
cls, *, allow_blank: bool = True, strip_whitespace: bool = False, **kwargs: Any
) -> Type[BaseField]:
kwargs = {
**kwargs,
@ -145,7 +145,7 @@ class Text(ModelFieldFactory):
if k not in ["cls", "__class__", "kwargs"]
},
}
kwargs['allow_blank'] = kwargs.get('nullable', True)
kwargs["allow_blank"] = kwargs.get("nullable", True)
return super().__new__(cls, **kwargs)
@classmethod
@ -158,12 +158,12 @@ class Float(ModelFieldFactory):
_type = float
def __new__( # type: ignore
cls,
*,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
**kwargs: Any
cls,
*,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
**kwargs: Any
) -> Type[BaseField]:
kwargs = {
**kwargs,
@ -232,12 +232,12 @@ class BigInteger(Integer):
_type = int
def __new__( # type: ignore
cls,
*,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
**kwargs: Any
cls,
*,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
**kwargs: Any
) -> Type[BaseField]:
autoincrement = kwargs.pop("autoincrement", None)
autoincrement = (
@ -267,16 +267,16 @@ class Decimal(ModelFieldFactory):
_type = decimal.Decimal
def __new__( # type: ignore # noqa CFQ002
cls,
*,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
precision: int = None,
scale: int = None,
max_digits: int = None,
decimal_places: int = None,
**kwargs: Any
cls,
*,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
precision: int = None,
scale: int = None,
max_digits: int = None,
decimal_places: int = None,
**kwargs: Any
) -> Type[BaseField]:
kwargs = {
**kwargs,

View File

@ -41,7 +41,7 @@ def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) ->
def register_many_to_many_relation_on_build(
table_name: str, field: Type[ManyToManyField]
table_name: str, field: Type[ManyToManyField]
) -> None:
alias_manager.add_relation_type(field.through.Meta.tablename, table_name)
alias_manager.add_relation_type(
@ -50,11 +50,11 @@ def register_many_to_many_relation_on_build(
def reverse_field_not_already_registered(
child: Type["Model"], child_model_name: str, parent_model: Type["Model"]
child: Type["Model"], child_model_name: str, parent_model: Type["Model"]
) -> bool:
return (
child_model_name not in parent_model.__fields__
and child.get_name() not in parent_model.__fields__
child_model_name not in parent_model.__fields__
and child.get_name() not in parent_model.__fields__
)
@ -65,7 +65,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
parent_model = model_field.to
child = model
if reverse_field_not_already_registered(
child, child_model_name, parent_model
child, child_model_name, parent_model
):
register_reverse_model_fields(
parent_model, child, child_model_name, model_field
@ -73,10 +73,10 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
def register_reverse_model_fields(
model: Type["Model"],
child: Type["Model"],
child_model_name: str,
model_field: Type["ForeignKeyField"],
model: Type["Model"],
child: Type["Model"],
child_model_name: str,
model_field: Type["ForeignKeyField"],
) -> None:
if issubclass(model_field, ManyToManyField):
model.Meta.model_fields[child_model_name] = ManyToMany(
@ -91,7 +91,7 @@ def register_reverse_model_fields(
def adjust_through_many_to_many_model(
model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField]
model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField]
) -> None:
model_field.through.Meta.model_fields[model.get_name()] = ForeignKey(
model, name=model.get_name(), ondelete="CASCADE"
@ -108,7 +108,7 @@ def adjust_through_many_to_many_model(
def create_pydantic_field(
field_name: str, model: Type["Model"], model_field: Type[ManyToManyField]
field_name: str, model: Type["Model"], model_field: Type[ManyToManyField]
) -> None:
model_field.through.__fields__[field_name] = ModelField(
name=field_name,
@ -120,13 +120,13 @@ def create_pydantic_field(
def create_and_append_m2m_fk(
model: Type["Model"], model_field: Type[ManyToManyField]
model: Type["Model"], model_field: Type[ManyToManyField]
) -> None:
column = sqlalchemy.Column(
model.get_name(),
model.Meta.table.columns.get(model.Meta.pkname).type,
model.Meta.table.columns.get(model.get_column_alias(model.Meta.pkname)).type,
sqlalchemy.schema.ForeignKey(
model.Meta.tablename + "." + model.Meta.pkname,
model.Meta.tablename + "." + model.get_column_alias(model.Meta.pkname),
ondelete="CASCADE",
onupdate="CASCADE",
),
@ -136,7 +136,7 @@ def create_and_append_m2m_fk(
def check_pk_column_validity(
field_name: str, field: BaseField, pkname: Optional[str]
field_name: str, field: BaseField, pkname: Optional[str]
) -> Optional[str]:
if pkname is not None:
raise ModelDefinitionError("Only one primary key column is allowed.")
@ -146,7 +146,7 @@ def check_pk_column_validity(
def sqlalchemy_columns_from_model_fields(
model_fields: Dict, table_name: str
model_fields: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
columns = []
pkname = None
@ -160,9 +160,9 @@ def sqlalchemy_columns_from_model_fields(
if field.primary_key:
pkname = check_pk_column_validity(field_name, field, pkname)
if (
not field.pydantic_only
and not field.virtual
and not issubclass(field, ManyToManyField)
not field.pydantic_only
and not field.virtual
and not issubclass(field, ManyToManyField)
):
columns.append(field.get_column(field_name))
register_relation_in_alias_manager(table_name, field)
@ -170,7 +170,7 @@ def sqlalchemy_columns_from_model_fields(
def register_relation_in_alias_manager(
table_name: str, field: Type[ForeignKeyField]
table_name: str, field: Type[ForeignKeyField]
) -> None:
if issubclass(field, ManyToManyField):
register_many_to_many_relation_on_build(table_name, field)
@ -179,7 +179,7 @@ def register_relation_in_alias_manager(
def populate_default_pydantic_field_value(
type_: Type[BaseField], field: str, attrs: dict
type_: Type[BaseField], field: str, attrs: dict
) -> dict:
def_value = type_.default_value()
curr_def_value = attrs.get(field, "NONE")
@ -208,7 +208,7 @@ def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict:
def populate_meta_orm_model_fields(
attrs: dict, new_model: Type["Model"]
attrs: dict, new_model: Type["Model"]
) -> Type["Model"]:
model_fields = {
field_name: field
@ -220,10 +220,12 @@ def populate_meta_orm_model_fields(
def populate_meta_tablename_columns_and_pk(
name: str, new_model: Type["Model"]
name: str, new_model: Type["Model"]
) -> Type["Model"]:
tablename = name.lower() + "s"
new_model.Meta.tablename = new_model.Meta.tablename if hasattr(new_model.Meta, 'tablename') else tablename
new_model.Meta.tablename = (
new_model.Meta.tablename if hasattr(new_model.Meta, "tablename") else tablename
)
pkname: Optional[str]
if hasattr(new_model.Meta, "columns"):
@ -244,7 +246,7 @@ def populate_meta_tablename_columns_and_pk(
def populate_meta_sqlalchemy_table_if_required(
new_model: Type["Model"],
new_model: Type["Model"],
) -> Type["Model"]:
if not hasattr(new_model.Meta, "table"):
new_model.Meta.table = sqlalchemy.Table(
@ -286,7 +288,7 @@ def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, A
def populate_choices_validators( # noqa CCR001
model: Type["Model"], attrs: Dict
model: Type["Model"], attrs: Dict
) -> None:
if model_initialized_and_has_model_fields(model):
for _, field in model.Meta.model_fields.items():
@ -299,7 +301,7 @@ def populate_choices_validators( # noqa CCR001
class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__( # type: ignore
mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict
mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict
) -> "ModelMetaclass":
attrs["Config"] = get_pydantic_base_orm_config()
attrs["__name__"] = name

View File

@ -43,7 +43,6 @@ class Model(NewBaseModel):
if select_related:
related_models = group_related_list(select_related)
# breakpoint()
if (
previous_table
and previous_table in cls.Meta.model_fields
@ -90,13 +89,13 @@ class Model(NewBaseModel):
previous_table=previous_table,
fields=fields,
)
item[first_part] = child
item[model_cls.get_column_name_from_alias(first_part)] = child
else:
model_cls = cls.Meta.model_fields[related].to
child = model_cls.from_row(
row, previous_table=previous_table, fields=fields
)
item[related] = child
item[model_cls.get_column_name_from_alias(related)] = child
return item
@ -113,13 +112,16 @@ class Model(NewBaseModel):
# databases does not keep aliases in Record for postgres, change to raw row
source = row._row if isinstance(row, Record) else row
selected_columns = cls.own_table_columns(cls, fields or [], nested=nested)
selected_columns = cls.own_table_columns(
cls, fields or [], nested=nested, use_alias=True
)
for column in cls.Meta.table.columns:
if column.name not in item and column.name in selected_columns:
alias = cls.get_column_name_from_alias(column.name)
if alias not in item and alias in selected_columns:
prefixed_name = (
f'{table_prefix + "_" if table_prefix else ""}{column.name}'
)
item[column.name] = source[prefixed_name]
item[alias] = source[prefixed_name]
return item
@ -142,7 +144,8 @@ class Model(NewBaseModel):
self.from_dict(new_values)
self_fields = self._extract_model_db_fields()
self_fields.pop(self.Meta.pkname)
self_fields.pop(self.get_column_name_from_alias(self.Meta.pkname))
self_fields = self.objects._translate_columns_to_aliases(self_fields)
expr = self.Meta.table.update().values(**self_fields)
expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname))
@ -162,5 +165,7 @@ class Model(NewBaseModel):
raise ValueError(
"Instance was deleted from database and cannot be refreshed"
)
self.from_dict(dict(row))
kwargs = dict(row)
kwargs = self.objects._translate_aliases_to_columns(kwargs)
self.from_dict(kwargs)
return self

View File

@ -3,7 +3,7 @@ from typing import Dict, List, Set, TYPE_CHECKING, Type, TypeVar, Union
import ormar
from ormar.exceptions import RelationshipInstanceError
from ormar.fields import BaseField
from ormar.fields import BaseField, ManyToManyField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta
@ -35,7 +35,7 @@ class ModelTableProxy:
return self_fields
@classmethod
def substitute_models_with_pks(cls, model_dict: Dict) -> Dict:
def substitute_models_with_pks(cls, model_dict: Dict) -> Dict: # noqa CCR001
for field in cls.extract_related_names():
field_value = model_dict.get(field, None)
if field_value is not None:
@ -43,10 +43,26 @@ class ModelTableProxy:
target_pkname = target_field.to.Meta.pkname
if isinstance(field_value, ormar.Model):
model_dict[field] = getattr(field_value, target_pkname)
else:
elif field_value:
model_dict[field] = field_value.get(target_pkname)
else:
model_dict.pop(field, None)
return model_dict
@classmethod
def get_column_alias(cls, field_name: str) -> str:
field = cls.Meta.model_fields.get(field_name)
if field and field.name is not None and field.name != field_name:
return field.name
return field_name
@classmethod
def get_column_name_from_alias(cls, alias: str) -> str:
for field_name, field in cls.Meta.model_fields.items():
if field and field.name == alias:
return field_name
return alias # if not found it's not an alias but actual name
@classmethod
def extract_related_names(cls) -> Set:
related_names = set()
@ -62,6 +78,7 @@ class ModelTableProxy:
if (
inspect.isclass(field)
and issubclass(field, ForeignKeyField)
and not issubclass(field, ManyToManyField)
and not field.virtual
):
related_names.add(name)
@ -84,7 +101,9 @@ class ModelTableProxy:
def _extract_model_db_fields(self) -> Dict:
self_fields = self._extract_own_model_fields()
self_fields = {
k: v for k, v in self_fields.items() if k in self.Meta.table.columns
k: v
for k, v in self_fields.items()
if self.get_column_alias(k) in self.Meta.table.columns
}
for field in self._extract_db_related_names():
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
@ -125,7 +144,7 @@ class ModelTableProxy:
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
merged_rows: List["Model"] = []
for index, model in enumerate(result_rows):
if index > 0 and model.pk == merged_rows[-1].pk:
if index > 0 and model is not None and model.pk == merged_rows[-1].pk:
merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
else:
merged_rows.append(model)
@ -151,30 +170,62 @@ class ModelTableProxy:
return other
@staticmethod
def own_table_columns(
model: Type["Model"], fields: List, nested: bool = False
def _get_not_nested_columns_from_fields(
model: Type["Model"],
fields: List,
column_names: List[str],
use_alias: bool = False,
) -> List[str]:
column_names = [col.name for col in model.Meta.table.columns]
fields = [model.get_column_alias(k) if not use_alias else k for k in fields]
columns = [name for name in fields if "__" not in name and name in column_names]
return columns
@staticmethod
def _get_nested_columns_from_fields(
model: Type["Model"], fields: List, use_alias: bool = False,
) -> List[str]:
model_name = f"{model.get_name()}__"
columns = [
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
for name in fields
if f"{model.get_name()}__" in name
]
columns = [model.get_column_alias(k) if not use_alias else k for k in columns]
return columns
@staticmethod
def own_table_columns(
model: Type["Model"],
fields: List,
nested: bool = False,
use_alias: bool = False,
) -> List[str]:
column_names = [
model.get_column_name_from_alias(col.name) if use_alias else col.name
for col in model.Meta.table.columns
]
if not fields:
return column_names
if not nested:
columns = [
name for name in fields if "__" not in name and name in column_names
]
columns = ModelTableProxy._get_not_nested_columns_from_fields(
model, fields, column_names, use_alias
)
else:
model_name = f"{model.get_name()}__"
columns = [
name[(name.find(model_name) + len(model_name)) :] # noqa: E203
for name in fields
if f"{model.get_name()}__" in name
]
columns = ModelTableProxy._get_nested_columns_from_fields(
model, fields, use_alias
)
# if the model is in select and no columns in fields, all implied
if not columns:
columns = column_names
# always has to return pk column
if model.Meta.pkname not in columns:
columns.append(model.Meta.pkname)
pk_alias = (
model.get_column_alias(model.Meta.pkname)
if not use_alias
else model.Meta.pkname
)
if pk_alias not in columns:
columns.append(pk_alias)
return columns

View File

@ -134,8 +134,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
def _extract_related_model_instead_of_field(
self, item: str
) -> Optional[Union["Model", List["Model"]]]:
if item in self._orm:
return self._orm.get(item)
alias = self.get_column_alias(item)
if alias in self._orm:
return self._orm.get(alias)
return None
def __eq__(self, other: object) -> bool:

View File

@ -44,7 +44,7 @@ class QueryClause:
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
if kwargs.get("pk"):
pk_name = self.model_cls.Meta.pkname
pk_name = self.model_cls.get_column_alias(self.model_cls.Meta.pkname)
kwargs[pk_name] = kwargs.pop("pk")
filter_clauses, select_related = self._populate_filter_clauses(**kwargs)
@ -83,7 +83,7 @@ class QueryClause:
else:
op = "exact"
column = self.table.columns[key]
column = self.table.columns[self.model_cls.get_column_alias(key)]
table = self.table
clause = self._process_column_clause_for_operator_and_value(

View File

@ -106,9 +106,11 @@ class SqlJoin:
self.select_from = sqlalchemy.sql.outerjoin(
self.select_from, target_table, on_clause
)
self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}"))
pkname_alias = model_cls.get_column_alias(model_cls.Meta.pkname)
self.order_bys.append(text(f"{alias}_{to_table}.{pkname_alias}"))
self_related_fields = model_cls.own_table_columns(
model_cls, self.fields, nested=True
model_cls, self.fields, nested=True,
)
self.columns.extend(
self.relation_manager(model_cls).prefixed_columns(
@ -125,12 +127,13 @@ class SqlJoin:
part: str,
) -> Tuple[str, str]:
if join_params.prev_model.Meta.model_fields[part].virtual or is_multi:
to_field = model_cls.resolve_relation_field(
to_field = model_cls.resolve_relation_name(
model_cls, join_params.prev_model
)
to_key = to_field.name
from_key = model_cls.Meta.pkname
to_key = model_cls.get_column_alias(to_field)
from_key = join_params.prev_model.get_column_alias(model_cls.Meta.pkname)
else:
to_key = model_cls.Meta.pkname
from_key = part
to_key = model_cls.get_column_alias(model_cls.Meta.pkname)
from_key = join_params.prev_model.get_column_alias(part)
return to_key, from_key

View File

@ -40,7 +40,8 @@ class Query:
@property
def prefixed_pk_name(self) -> str:
return f"{self.table.name}.{self.model_cls.Meta.pkname}"
pkname_alias = self.model_cls.get_column_alias(self.model_cls.Meta.pkname)
return f"{self.table.name}.{pkname_alias}"
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
self_related_fields = self.model_cls.own_table_columns(

View File

@ -70,12 +70,35 @@ class QuerySet:
return self.model.merge_instances_list(result_rows) # type: ignore
return result_rows
def _prepare_model_to_save(self, new_kwargs: dict) -> dict:
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs)
new_kwargs = self._translate_columns_to_aliases(new_kwargs)
return new_kwargs
def _populate_default_values(self, new_kwargs: dict) -> dict:
for field_name, field in self.model_meta.model_fields.items():
if field_name not in new_kwargs and field.has_default():
new_kwargs[field_name] = field.get_default()
return new_kwargs
def _translate_columns_to_aliases(self, new_kwargs: dict) -> dict:
for field_name, field in self.model_meta.model_fields.items():
if (
field_name in new_kwargs
and field.name is not None
and field.name != field_name
):
new_kwargs[field.name] = new_kwargs.pop(field_name)
return new_kwargs
def _translate_aliases_to_columns(self, new_kwargs: dict) -> dict:
for field_name, field in self.model_meta.model_fields.items():
if field.name in new_kwargs and field.name != field_name:
new_kwargs[field_name] = new_kwargs.pop(field.name)
return new_kwargs
def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict:
pkname = self.model_meta.pkname
pk = self.model_meta.model_fields[pkname]
@ -184,6 +207,7 @@ class QuerySet:
async def update(self, each: bool = False, **kwargs: Any) -> int:
self_fields = self.model.extract_db_own_fields()
updates = {k: v for k, v in kwargs.items() if k in self_fields}
updates = self._translate_columns_to_aliases(updates)
if not each and not self.filter_clauses:
raise QueryDefinitionError(
"You cannot update without filtering the queryset first. "
@ -278,9 +302,7 @@ class QuerySet:
async def create(self, **kwargs: Any) -> "Model":
new_kwargs = dict(**kwargs)
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs)
new_kwargs = self._prepare_model_to_save(new_kwargs)
expr = self.table.insert()
expr = expr.values(**new_kwargs)
@ -288,7 +310,7 @@ class QuerySet:
instance = self.model(**kwargs)
pk = await self.database.execute(expr)
pk_name = self.model_meta.pkname
pk_name = self.model.get_column_alias(self.model_meta.pkname)
if pk_name not in kwargs and pk_name in new_kwargs:
instance.pk = new_kwargs[self.model_meta.pkname]
if pk and isinstance(pk, self.model.pk_type()):
@ -300,9 +322,7 @@ class QuerySet:
ready_objects = []
for objt in objects:
new_kwargs = objt.dict()
new_kwargs = self._remove_pk_from_kwargs(new_kwargs)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._populate_default_values(new_kwargs)
new_kwargs = self._prepare_model_to_save(new_kwargs)
ready_objects.append(new_kwargs)
expr = self.table.insert()
@ -323,6 +343,8 @@ class QuerySet:
if pk_name not in columns:
columns.append(pk_name)
columns = [self.model.get_column_alias(k) for k in columns]
for objt in objects:
new_kwargs = objt.dict()
if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None:
@ -331,15 +353,24 @@ class QuerySet:
f"{self.model.__name__} has to have {pk_name} filled."
)
new_kwargs = self.model.substitute_models_with_pks(new_kwargs)
new_kwargs = self._translate_columns_to_aliases(new_kwargs)
new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns}
ready_objects.append(new_kwargs)
pk_column = self.model_meta.table.c.get(pk_name)
expr = self.table.update().where(pk_column == bindparam("new_" + pk_name))
pk_column = self.model_meta.table.c.get(self.model.get_column_alias(pk_name))
pk_column_name = self.model.get_column_alias(pk_name)
table_columns = [c.name for c in self.model_meta.table.c]
expr = self.table.update().where(
pk_column == bindparam("new_" + pk_column_name)
)
expr = expr.values(
**{k: bindparam("new_" + k) for k in columns if k != pk_name}
**{
k: bindparam("new_" + k)
for k in columns
if k != pk_column_name and k in table_columns
}
)
# databases bind params only where query is passed as string
# otherwise it just pases all data to values and results in unconsumed columns
# otherwise it just passes all data to values and results in unconsumed columns
expr = str(expr)
await self.database.execute_many(expr, ready_objects)

View File

@ -39,7 +39,7 @@ class RelationProxy(list):
def _set_queryset(self) -> "QuerySet":
owner_table = self.relation._owner.Meta.tablename
pkname = self.relation._owner.Meta.pkname
pkname = self.relation._owner.get_column_alias(self.relation._owner.Meta.pkname)
pk_value = self.relation._owner.pk
if not pk_value:
raise RelationshipInstanceError(

157
tests/test_aliases.py Normal file
View File

@ -0,0 +1,157 @@
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Child(ormar.Model):
class Meta:
tablename = "children"
metadata = metadata
database = database
id: ormar.Integer(name='child_id', primary_key=True)
first_name: ormar.String(name='fname', max_length=100)
last_name: ormar.String(name='lname', max_length=100)
born_year: ormar.Integer(name='year_born', nullable=True)
class ArtistChildren(ormar.Model):
class Meta:
tablename = "children_x_artists"
metadata = metadata
database = database
class Artist(ormar.Model):
class Meta:
tablename = "artists"
metadata = metadata
database = database
id: ormar.Integer(name='artist_id', primary_key=True)
first_name: ormar.String(name='fname', max_length=100)
last_name: ormar.String(name='lname', max_length=100)
born_year: ormar.Integer(name='year')
children: ormar.ManyToMany(Child, through=ArtistChildren)
class Album(ormar.Model):
class Meta:
tablename = "music_albums"
metadata = metadata
database = database
id: ormar.Integer(name='album_id', primary_key=True)
name: ormar.String(name='album_name', max_length=100)
artist: ormar.ForeignKey(Artist, name='artist_id')
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
def test_table_structure():
assert 'album_id' in [x.name for x in Album.Meta.table.columns]
assert 'album_name' in [x.name for x in Album.Meta.table.columns]
assert 'fname' in [x.name for x in Artist.Meta.table.columns]
assert 'lname' in [x.name for x in Artist.Meta.table.columns]
assert 'year' in [x.name for x in Artist.Meta.table.columns]
@pytest.mark.asyncio
async def test_working_with_aliases():
async with database:
async with database.transaction(force_rollback=True):
artist = await Artist.objects.create(first_name='Ted', last_name='Mosbey', born_year=1975)
await Album.objects.create(name="Aunt Robin", artist=artist)
await artist.children.create(first_name='Son', last_name='1', born_year=1990)
await artist.children.create(first_name='Son', last_name='2', born_year=1995)
album = await Album.objects.select_related('artist').first()
assert album.artist.last_name == 'Mosbey'
assert album.artist.id is not None
assert album.artist.first_name == 'Ted'
assert album.artist.born_year == 1975
assert album.name == 'Aunt Robin'
artist = await Artist.objects.select_related('children').get()
assert len(artist.children) == 2
assert artist.children[0].first_name == 'Son'
assert artist.children[1].last_name == '2'
await artist.update(last_name='Bundy')
await Artist.objects.filter(pk=artist.pk).update(born_year=1974)
artist = await Artist.objects.select_related('children').get()
assert artist.last_name == 'Bundy'
assert artist.born_year == 1974
artist = await Artist.objects.select_related('children').fields(
['first_name', 'last_name', 'born_year', 'child__first_name', 'child__last_name']).get()
assert artist.children[0].born_year is None
@pytest.mark.asyncio
async def test_bulk_operations_and_fields():
async with database:
d1 = Child(first_name='Daughter', last_name='1', born_year=1990)
d2 = Child(first_name='Daughter', last_name='2', born_year=1991)
await Child.objects.bulk_create([d1, d2])
children = await Child.objects.filter(first_name='Daughter').all()
assert len(children) == 2
assert children[0].last_name == '1'
for child in children:
child.born_year = child.born_year - 100
await Child.objects.bulk_update(children)
children = await Child.objects.filter(first_name='Daughter').all()
assert len(children) == 2
assert children[0].born_year == 1890
children = await Child.objects.fields(['first_name', 'last_name']).all()
assert len(children) == 2
for child in children:
assert child.born_year is None
await children[0].load()
await children[0].delete()
children = await Child.objects.all()
assert len(children) == 1
@pytest.mark.asyncio
async def test_working_with_aliases_get_or_create():
async with database:
async with database.transaction(force_rollback=True):
artist = await Artist.objects.get_or_create(first_name='Teddy', last_name='Bear', born_year=2020)
assert artist.pk is not None
artist2 = await Artist.objects.get_or_create(first_name='Teddy', last_name='Bear', born_year=2020)
assert artist == artist2
art3 = artist2.dict()
art3['born_year'] = 2019
await Artist.objects.update_or_create(**art3)
artist3 = await Artist.objects.get(last_name='Bear')
assert artist3.born_year == 2019
artists = await Artist.objects.all()
assert len(artists) == 1