add fixes for fastapi model clones, add functionality to add and remove models to relation, add relation proxy, fix all tests, adding values also to pydantic model __dict__some refactors

This commit is contained in:
collerek
2020-08-26 22:24:25 +02:00
parent a9f88e8f8f
commit c5389023b8
17 changed files with 260 additions and 118 deletions

BIN
.coverage

Binary file not shown.

View File

@ -64,6 +64,6 @@ class BaseField:
@classmethod
def expand_relationship(
cls, value: Any, child: Union["Model", "NewBaseModel"]
cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True
) -> Any:
return value

View File

@ -68,25 +68,33 @@ class ForeignKeyField(BaseField):
@classmethod
def _extract_model_from_sequence(
cls, value: List, child: "Model"
cls, value: List, child: "Model", to_register: bool
) -> Union["Model", List["Model"]]:
return [cls.expand_relationship(val, child) for val in value]
return [cls.expand_relationship(val, child, to_register) for val in value]
@classmethod
def _register_existing_model(cls, value: "Model", child: "Model") -> "Model":
cls.register_relation(value, child)
def _register_existing_model(
cls, value: "Model", child: "Model", to_register: bool
) -> "Model":
if to_register:
cls.register_relation(value, child)
return value
@classmethod
def _construct_model_from_dict(cls, value: dict, child: "Model") -> "Model":
def _construct_model_from_dict(
cls, value: dict, child: "Model", to_register: bool
) -> "Model":
if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname:
value["__pk_only__"] = True
model = cls.to(**value)
cls.register_relation(model, child)
if to_register:
cls.register_relation(model, child)
return model
@classmethod
def _construct_model_from_pk(cls, value: Any, child: "Model") -> "Model":
def _construct_model_from_pk(
cls, value: Any, child: "Model", to_register: bool
) -> "Model":
if not isinstance(value, cls.to.pk_type()):
raise RelationshipInstanceError(
f"Relationship error - ForeignKey {cls.to.__name__} "
@ -94,7 +102,8 @@ class ForeignKeyField(BaseField):
f"while {type(value)} passed as a parameter."
)
model = create_dummy_instance(fk=cls.to, pk=value)
cls.register_relation(model, child)
if to_register:
cls.register_relation(model, child)
return model
@classmethod
@ -105,7 +114,7 @@ class ForeignKeyField(BaseField):
@classmethod
def expand_relationship(
cls, value: Any, child: "Model"
cls, value: Any, child: "Model", to_register: bool = True
) -> Optional[Union["Model", List["Model"]]]:
if value is None:
return None
@ -118,5 +127,5 @@ class ForeignKeyField(BaseField):
model = constructors.get(
value.__class__.__name__, cls._construct_model_from_pk
)(value, child)
)(value, child, to_register)
return model

View File

@ -1,4 +1,5 @@
from ormar.models.newbasemodel import NewBaseModel
from ormar.models.model import Model
from ormar.models.metaclass import expand_reverse_relationships
__all__ = ["NewBaseModel", "Model"]
__all__ = ["NewBaseModel", "Model", "expand_reverse_relationships"]

View File

@ -29,17 +29,8 @@ class ModelMeta:
alias_manager: AliasManager
def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None:
child_relation_name = (
field.to.get_name(title=True)
+ "_"
+ (field.related_name or (name.lower() + "s"))
)
reverse_name = child_relation_name
relation_name = name.lower().title() + "_" + field.to.get_name()
relationship_manager.add_relation_type(
relation_name, reverse_name, field, table_name
)
def register_relation_on_build(table_name: str, field: ForeignKey) -> None:
relationship_manager.add_relation_type(field, table_name)
def expand_reverse_relationships(model: Type["Model"]) -> None:
@ -64,15 +55,10 @@ def register_reverse_model_fields(
def sqlalchemy_columns_from_model_fields(
name: str, object_dict: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
model_fields: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
columns = []
pkname = None
model_fields = {
field_name: field
for field_name, field in object_dict["__annotations__"].items()
if issubclass(field, BaseField)
}
for field_name, field in model_fields.items():
if field.primary_key:
if pkname is not None:
@ -83,9 +69,9 @@ def sqlalchemy_columns_from_model_fields(
if not field.pydantic_only:
columns.append(field.get_column(field_name))
if issubclass(field, ForeignKeyField):
register_relation_on_build(table_name, field, name)
register_relation_on_build(table_name, field)
return pkname, columns, model_fields
return pkname, columns
def populate_pydantic_default_values(attrs: Dict) -> Dict:
@ -125,21 +111,29 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
attrs["__annotations__"] = annotations
attrs = populate_pydantic_default_values(attrs)
attrs["__module__"] = attrs["__module__"] or bases[0].__module__
attrs["__annotations__"] = (
attrs["__annotations__"] or bases[0].__annotations__
)
tablename = name.lower() + "s"
new_model.Meta.tablename = new_model.Meta.tablename or tablename
# sqlalchemy table creation
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
name, attrs, new_model.Meta.tablename
)
model_fields = {
field_name: field
for field_name, field in attrs["__annotations__"].items()
if issubclass(field, BaseField)
}
if hasattr(new_model.Meta, "model_fields") and not pkname:
model_fields = new_model.Meta.model_fields
for fieldname, field in new_model.Meta.model_fields.items():
if field.primary_key:
pkname = fieldname
if hasattr(new_model.Meta, "columns"):
columns = new_model.Meta.table.columns
pkname = new_model.Meta.pkname
else:
pkname, columns = sqlalchemy_columns_from_model_fields(
model_fields, new_model.Meta.tablename
)
if not hasattr(new_model.Meta, "table"):
new_model.Meta.table = sqlalchemy.Table(
@ -153,10 +147,11 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
raise ModelDefinitionError("Table has to have a primary key.")
new_model.Meta.model_fields = model_fields
expand_reverse_relationships(new_model)
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs
)
expand_reverse_relationships(new_model)
new_model.Meta.alias_manager = relationship_manager
new_model.objects = QuerySet(new_model)

View File

@ -69,7 +69,8 @@ class Model(NewBaseModel):
async def save(self) -> "Model":
self_fields = self._extract_model_db_fields()
if self.Meta.model_fields.get(self.Meta.pkname).autoincrement:
if not self.pk and self.Meta.model_fields.get(self.Meta.pkname).autoincrement:
self_fields.pop(self.Meta.pkname, None)
expr = self.Meta.table.insert()
expr = expr.values(**self_fields)
@ -77,7 +78,7 @@ class Model(NewBaseModel):
setattr(self, self.Meta.pkname, item_id)
return self
async def update(self, **kwargs: Any) -> int:
async def update(self, **kwargs: Any) -> "Model":
if kwargs:
new_values = {**self.dict(), **kwargs}
self.from_dict(new_values)
@ -89,8 +90,8 @@ class Model(NewBaseModel):
.values(**self_fields)
.where(self.pk_column == getattr(self, self.Meta.pkname))
)
result = await self.Meta.database.execute(expr)
return result
await self.Meta.database.execute(expr)
return self
async def delete(self) -> int:
expr = self.Meta.table.delete()

View File

@ -24,7 +24,6 @@ class ModelTableProxy:
@classmethod
def substitute_models_with_pks(cls, model_dict: dict) -> dict:
model_dict = copy.deepcopy(model_dict)
for field in cls._extract_related_names():
if field in model_dict and model_dict.get(field) is not None:
target_field = cls.Meta.model_fields[field]
@ -76,10 +75,19 @@ class ModelTableProxy:
}
for field in self._extract_db_related_names():
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
if getattr(self, field) is not None:
self_fields[field] = getattr(getattr(self, field), target_pk_name)
target_field = getattr(self, field)
self_fields[field] = getattr(target_field, target_pk_name, None)
return self_fields
@staticmethod
def resolve_relation_name(item: "Model", related: "Model"):
for name, field in item.Meta.model_fields.items():
if issubclass(field, ForeignKeyField):
# fastapi is creating clones of response model that's why it can be a subclass
# of the original one so we need to compare Meta too
if field.to == related.__class__ or field.to.Meta == related.Meta:
return name
@classmethod
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
merged_rows = []

View File

@ -71,9 +71,14 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
pk_only = kwargs.pop("__pk_only__", False)
if "pk" in kwargs:
kwargs[self.Meta.pkname] = kwargs.pop("pk")
# build the models to set them and validate but don't register
kwargs = {
k: self._convert_json(
k, self.Meta.model_fields[k].expand_relationship(v, self), "dumps"
k,
self.Meta.model_fields[k].expand_relationship(
v, self, to_register=False
),
"dumps",
)
for k, v in kwargs.items()
}
@ -85,13 +90,20 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
object.__setattr__(self, "__dict__", values)
object.__setattr__(self, "__fields_set__", fields_set)
# register the related models after initialization
for related in self._extract_related_names():
self.Meta.model_fields[related].expand_relationship(
kwargs.get(related), self, to_register=True
)
def __setattr__(self, name: str, value: Any) -> None:
if name in self.__slots__:
object.__setattr__(self, name, value)
elif name == "pk":
object.__setattr__(self, self.Meta.pkname, value)
elif name in self._orm:
self.Meta.model_fields[name].expand_relationship(value, self)
model = self.Meta.model_fields[name].expand_relationship(value, self)
self.__dict__[name] = model
else:
value = (
self._convert_json(name, value, "dumps")
@ -113,19 +125,13 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
return value
return super().__getattribute__(item)
# def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]:
# return self._extract_related_model_instead_of_field(item)
def _extract_related_model_instead_of_field(
self, item: str
) -> Optional[Union["Model", List["Model"]]]:
# relation_key = self.get_name(title=True) + "_" + item
if item in self._orm:
return self._orm.get(item)
def __same__(self, other: "Model") -> bool:
if self.__class__ != other.__class__: # pragma no cover
return False
return (
self._orm_id == other._orm_id
or self.__dict__ == other.__dict__
@ -137,8 +143,6 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
name = cls.__name__
if lower:
name = name.lower()
if title:
name = name.title()
return name
@property
@ -149,6 +153,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
def pk_type(cls) -> Any:
return cls.Meta.model_fields[cls.Meta.pkname].__type__
def remove(self, name: "Model"):
self._orm.remove_parent(self, name)
def dict( # noqa A003
self,
*,
@ -176,14 +183,23 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if self.Meta.model_fields[field].virtual and nested:
continue
if isinstance(nested_model, list):
dict_instance[field] = [x.dict(nested=True) for x in nested_model]
result = []
for model in nested_model:
try:
result.append(model.dict(nested=True))
except ReferenceError: # pragma no cover
continue
dict_instance[field] = result
elif nested_model is not None:
dict_instance[field] = nested_model.dict(nested=True)
else:
dict_instance[field] = None
return dict_instance
def from_dict(self, value_dict: Dict) -> None:
def from_dict(self, value_dict: Dict) -> "Model":
for key, value in value_dict.items():
setattr(self, key, value)
return self
def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]:
if not self._is_conversion_to_json_needed(column_name):

View File

@ -69,10 +69,11 @@ class Query:
# print(expr.compile(compile_kwargs={"literal_binds": True}))
self._reset_query_parameters()
return expr, self._select_related
return expr
@staticmethod
def on_clause(
self, previous_alias: str, alias: str, from_clause: str, to_clause: str,
previous_alias: str, alias: str, from_clause: str, to_clause: str,
) -> text:
left_part = f"{alias}_{to_clause}"
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"

View File

@ -47,7 +47,7 @@ class QuerySet:
offset=self.query_offset,
limit_count=self.limit_count,
)
exp, self._select_related = qry.build_select_expression()
exp = qry.build_select_expression()
return exp
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
@ -118,15 +118,25 @@ class QuerySet:
async def get(self, **kwargs: Any) -> "Model":
if kwargs:
return await self.filter(**kwargs).get()
else:
if not self.filter_clauses:
expr = self.build_select_expression().limit(2)
else:
expr = self.build_select_expression()
expr = self.build_select_expression().limit(2)
rows = await self.database.fetch_all(expr)
result_rows = [
self.model_cls.from_row(row, select_related=self._select_related)
for row in rows
]
rows = self.model_cls.merge_instances_list(result_rows)
if not rows:
raise NoMatch()
if len(rows) > 1:
raise MultipleMatches()
return self.model_cls.from_row(rows[0], select_related=self._select_related)
return rows[0]
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
if kwargs:

View File

@ -2,12 +2,13 @@ import string
import uuid
from enum import Enum
from random import choices
from typing import List, TYPE_CHECKING, Type
from typing import List, TYPE_CHECKING, Type, Union, Optional
from weakref import proxy
import sqlalchemy
from sqlalchemy import text
import ormar
from ormar.exceptions import RelationshipInstanceError
from ormar.fields.foreign_key import ForeignKeyField # noqa I100
@ -26,7 +27,6 @@ class RelationType(Enum):
class AliasManager:
def __init__(self) -> None:
self._relations = dict()
self._aliases = dict()
@staticmethod
@ -40,54 +40,83 @@ class AliasManager:
def prefixed_table_name(alias: str, name: str) -> text:
return text(f"{name} {alias}_{name}")
def add_relation_type(
self,
relations_key: str,
reverse_key: str,
field: ForeignKeyField,
table_name: str,
) -> None:
if relations_key not in self._relations:
def add_relation_type(self, field: ForeignKeyField, table_name: str,) -> None:
if f"{table_name}_{field.to.Meta.tablename}" not in self._aliases:
self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias()
if reverse_key not in self._relations:
if f"{field.to.Meta.tablename}_{table_name}" not in self._aliases:
self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias()
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
return self._aliases.get(f"{from_table}_{to_table}", "")
class Relation:
def __init__(self, type_: RelationType) -> None:
self._type = type_
self.related_models = [] if type_ == RelationType.REVERSE else None
class RelationProxy(list):
def __init__(self, relation: "Relation"):
super(RelationProxy, self).__init__()
self.relation = relation
self._owner = self.relation.manager.owner
def _find_existing(self, child):
for ind, relation_child in enumerate(self.related_models):
def remove(self, item: "Model"):
super().remove(item)
rel_name = item.resolve_relation_name(item, self._owner)
item._orm._get(rel_name).remove(self._owner)
def append(self, item: "Model"):
super().append(item)
def add(self, item):
rel_name = item.resolve_relation_name(item, self._owner)
setattr(item, rel_name, self._owner)
class Relation:
def __init__(self, manager: "RelationsManager", type_: RelationType) -> None:
self.manager = manager
self._owner = manager.owner
self._type = type_
self.related_models = (
RelationProxy(relation=self) if type_ == RelationType.REVERSE else None
)
def _find_existing(self, child) -> Optional[int]:
for ind, relation_child in enumerate(self.related_models[:]):
try:
if relation_child.__same__(child):
return ind
except ReferenceError: # pragma no cover
continue
self.related_models.pop(ind)
return None
def add(self, child: "Model") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY:
self.related_models = child
self._owner.__dict__[relation_name] = child
else:
if self._find_existing(child) is None:
self.related_models.append(child)
rel = self._owner.__dict__.get(relation_name, [])
rel.append(child)
self._owner.__dict__[relation_name] = rel
# def remove(self, child: "Model") -> None:
# if self._type == RelationType.PRIMARY:
# self.related_models = None
# else:
# position = self._find_existing(child)
# if position is not None:
# self.related_models.pop(position)
def remove(self, child: "Model") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY:
if self.related_models.__same__(child):
self.related_models = None
del self._owner.__dict__[relation_name]
else:
position = self._find_existing(child)
if position is not None:
self.related_models.pop(position)
del self._owner.__dict__[relation_name][position]
def get(self):
def get(self) -> Union[List["Model"], "Model"]:
return self.related_models
def __repr__(self): # pragma no cover
return str(self.related_models)
class RelationsManager:
def __init__(
@ -98,21 +127,23 @@ class RelationsManager:
self._related_names = [field.name for field in self._related_fields]
self._relations = dict()
for field in self._related_fields:
self._relations[field.name] = Relation(
type_=RelationType.PRIMARY
if not field.virtual
else RelationType.REVERSE
)
self._add_relation(field)
def _add_relation(self, field):
self._relations[field.name] = Relation(
manager=self,
type_=RelationType.PRIMARY if not field.virtual else RelationType.REVERSE,
)
def __contains__(self, item):
return item in self._related_names
def get(self, name):
def get(self, name) -> Optional[Union[List["Model"], "Model"]]:
relation = self._relations.get(name, None)
if relation:
return relation.get()
def _get(self, name):
def _get(self, name) -> Optional[Relation]:
relation = self._relations.get(name, None)
if relation:
return relation
@ -122,7 +153,7 @@ class RelationsManager:
(
field
for field in child._orm._related_fields
if field.to == parent.__class__
if field.to == parent.__class__ or field.to.Meta == parent.Meta
),
None,
)
@ -140,5 +171,25 @@ class RelationsManager:
child_name = child_name or child.get_name() + "s"
child = proxy(child)
parent._orm._get(child_name).add(child)
parent_relation = parent._orm._get(child_name)
if not parent_relation:
ormar.models.expand_reverse_relationships(child.__class__)
name = parent.resolve_relation_name(parent, child)
field = parent.Meta.model_fields[name]
parent._orm._add_relation(field)
parent_relation = parent._orm._get(child_name)
parent_relation.add(child)
child._orm._get(to_name).add(parent)
def remove(self, name: str, child: "Model"):
relation = self._get(name)
relation.remove(child)
@staticmethod
def remove_parent(item: "Model", name: Union[str, "Model"]):
related_model = name
name = item.resolve_relation_name(item, related_model)
if name in item._orm:
relation_name = item.resolve_relation_name(related_model, item)
item._orm.remove(name, related_model)
related_model._orm.remove(relation_name, item)

View File

@ -22,7 +22,7 @@ class Example(ormar.Model):
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=200, default='aaa')
name: ormar.String(max_length=200, default="aaa")
created: ormar.DateTime(default=datetime.datetime.now)
created_day: ormar.Date(default=datetime.date.today)
created_time: ormar.Time(default=time)

View File

@ -1,11 +1,11 @@
import gc
import databases
import pytest
import sqlalchemy
from pydantic import ValidationError
import ormar
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
from ormar.fields.foreign_key import ForeignKeyField
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
@ -133,7 +133,9 @@ async def test_model_crud():
assert album1.pk == 1
assert album1.tracks == []
await Track.objects.create(album={"id": track.album.pk}, title="The Bird2", position=4)
await Track.objects.create(
album={"id": track.album.pk}, title="The Bird2", position=4
)
@pytest.mark.asyncio
@ -164,6 +166,47 @@ async def test_select_related():
assert len(tracks) == 6
@pytest.mark.asyncio
async def test_model_removal_from_relations():
async with database:
album = Album(name="Chichi")
await album.save()
track1 = Track(album=album, title="The Birdman", position=1)
track2 = Track(album=album, title="Superman", position=2)
track3 = Track(album=album, title="Wonder Woman", position=3)
await track1.save()
await track2.save()
await track3.save()
assert len(album.tracks) == 3
album.tracks.remove(track1)
assert len(album.tracks) == 2
assert track1.album is None
await track1.update()
track1 = await Track.objects.get(title="The Birdman")
assert track1.album is None
album.tracks.add(track1)
assert len(album.tracks) == 3
assert track1.album == album
await track1.update()
track1 = await Track.objects.select_related("album__tracks").get(
title="The Birdman"
)
album = await Album.objects.select_related("tracks").get(name="Chichi")
assert track1.album == album
track1.remove(album)
assert track1.album is None
assert len(album.tracks) == 2
track2.remove(album)
assert track2.album is None
assert len(album.tracks) == 1
@pytest.mark.asyncio
async def test_fk_filter():
async with database:
@ -182,8 +225,8 @@ async def test_fk_filter():
tracks = (
await Track.objects.select_related("album")
.filter(album__name="Fantasies")
.all()
.filter(album__name="Fantasies")
.all()
)
assert len(tracks) == 3
for track in tracks:
@ -191,8 +234,8 @@ async def test_fk_filter():
tracks = (
await Track.objects.select_related("album")
.filter(album__name__icontains="fan")
.all()
.filter(album__name__icontains="fan")
.all()
)
assert len(tracks) == 3
for track in tracks:
@ -234,8 +277,8 @@ async def test_multiple_fk():
members = (
await Member.objects.select_related("team__org")
.filter(team__org__ident="ACME Ltd")
.all()
.filter(team__org__ident="ACME Ltd")
.all()
)
assert len(members) == 4
for member in members:
@ -254,8 +297,8 @@ async def test_pk_filter():
tracks = (
await Track.objects.select_related("album")
.filter(position=2, album__name="Test")
.all()
.filter(position=2, album__name="Test")
.all()
)
assert len(tracks) == 1

View File

@ -54,7 +54,9 @@ class ExampleModel2(Model):
@pytest.fixture()
def example():
return ExampleModel(pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5))
return ExampleModel(
pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5)
)
def test_not_nullable_field_is_required():
@ -110,6 +112,7 @@ def test_sqlalchemy_table_is_created(example):
def test_no_pk_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
class Meta:
tablename = "example3"
@ -120,6 +123,7 @@ def test_no_pk_in_model_definition():
def test_two_pks_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
class Meta:
tablename = "example3"
@ -131,6 +135,7 @@ def test_two_pks_in_model_definition():
def test_setting_pk_column_as_pydantic_only_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
class Meta:
tablename = "example4"
@ -141,6 +146,7 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition():
def test_decimal_error_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
class Meta:
tablename = "example5"
@ -151,6 +157,7 @@ def test_decimal_error_in_model_definition():
def test_string_error_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
class Meta:
tablename = "example6"

View File

@ -28,7 +28,7 @@ class User(ormar.Model):
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100, default='')
name: ormar.String(max_length=100, default="")
class Product(ormar.Model):

View File

@ -79,7 +79,7 @@ async def create_category(category: Category):
@app.put("/items/{item_id}")
async def get_item(item_id: int, item: Item):
item_db = await Item.objects.get(pk=item_id)
return {"updated_rows": await item_db.update(**item.dict())}
return await item_db.update(**item.dict())
@app.delete("/items/{item_id}")
@ -105,7 +105,7 @@ def test_all_endpoints():
item.name = "New name"
response = client.put(f"/items/{item.pk}", json=item.dict())
assert response.json().get("updated_rows") == 1
assert response.json() == item.dict()
response = client.get("/items/")
items = [Item(**item) for item in response.json()]

View File

@ -101,10 +101,10 @@ async def test_model_multiple_instances_of_same_table_in_schema():
assert classes[0].name == "Math"
assert classes[0].students[0].name == "Jane"
assert len(classes[0].dict().get("students")) == 2
assert classes[0].teachers[0].category.department.name == 'Law Department'
assert classes[0].teachers[0].category.department.name == "Law Department"
assert classes[0].students[0].category.pk is not None
assert classes[0].students[0].category.name is None
await classes[0].students[0].category.load()
await classes[0].students[0].category.department.load()
assert classes[0].students[0].category.department.name == 'Math Department'
assert classes[0].students[0].category.department.name == "Math Department"