add fixes for fastapi model clones, add functionality to add and remove models to relation, add relation proxy, fix all tests, adding values also to pydantic model __dict__some refactors

This commit is contained in:
collerek
2020-08-26 22:24:25 +02:00
parent a9f88e8f8f
commit c5389023b8
17 changed files with 260 additions and 118 deletions

BIN
.coverage

Binary file not shown.

View File

@ -64,6 +64,6 @@ class BaseField:
@classmethod @classmethod
def expand_relationship( def expand_relationship(
cls, value: Any, child: Union["Model", "NewBaseModel"] cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True
) -> Any: ) -> Any:
return value return value

View File

@ -68,25 +68,33 @@ class ForeignKeyField(BaseField):
@classmethod @classmethod
def _extract_model_from_sequence( def _extract_model_from_sequence(
cls, value: List, child: "Model" cls, value: List, child: "Model", to_register: bool
) -> Union["Model", List["Model"]]: ) -> Union["Model", List["Model"]]:
return [cls.expand_relationship(val, child) for val in value] return [cls.expand_relationship(val, child, to_register) for val in value]
@classmethod @classmethod
def _register_existing_model(cls, value: "Model", child: "Model") -> "Model": def _register_existing_model(
cls, value: "Model", child: "Model", to_register: bool
) -> "Model":
if to_register:
cls.register_relation(value, child) cls.register_relation(value, child)
return value return value
@classmethod @classmethod
def _construct_model_from_dict(cls, value: dict, child: "Model") -> "Model": def _construct_model_from_dict(
cls, value: dict, child: "Model", to_register: bool
) -> "Model":
if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname: if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname:
value["__pk_only__"] = True value["__pk_only__"] = True
model = cls.to(**value) model = cls.to(**value)
if to_register:
cls.register_relation(model, child) cls.register_relation(model, child)
return model return model
@classmethod @classmethod
def _construct_model_from_pk(cls, value: Any, child: "Model") -> "Model": def _construct_model_from_pk(
cls, value: Any, child: "Model", to_register: bool
) -> "Model":
if not isinstance(value, cls.to.pk_type()): if not isinstance(value, cls.to.pk_type()):
raise RelationshipInstanceError( raise RelationshipInstanceError(
f"Relationship error - ForeignKey {cls.to.__name__} " f"Relationship error - ForeignKey {cls.to.__name__} "
@ -94,6 +102,7 @@ class ForeignKeyField(BaseField):
f"while {type(value)} passed as a parameter." f"while {type(value)} passed as a parameter."
) )
model = create_dummy_instance(fk=cls.to, pk=value) model = create_dummy_instance(fk=cls.to, pk=value)
if to_register:
cls.register_relation(model, child) cls.register_relation(model, child)
return model return model
@ -105,7 +114,7 @@ class ForeignKeyField(BaseField):
@classmethod @classmethod
def expand_relationship( def expand_relationship(
cls, value: Any, child: "Model" cls, value: Any, child: "Model", to_register: bool = True
) -> Optional[Union["Model", List["Model"]]]: ) -> Optional[Union["Model", List["Model"]]]:
if value is None: if value is None:
return None return None
@ -118,5 +127,5 @@ class ForeignKeyField(BaseField):
model = constructors.get( model = constructors.get(
value.__class__.__name__, cls._construct_model_from_pk value.__class__.__name__, cls._construct_model_from_pk
)(value, child) )(value, child, to_register)
return model return model

View File

@ -1,4 +1,5 @@
from ormar.models.newbasemodel import NewBaseModel from ormar.models.newbasemodel import NewBaseModel
from ormar.models.model import Model from ormar.models.model import Model
from ormar.models.metaclass import expand_reverse_relationships
__all__ = ["NewBaseModel", "Model"] __all__ = ["NewBaseModel", "Model", "expand_reverse_relationships"]

View File

@ -29,17 +29,8 @@ class ModelMeta:
alias_manager: AliasManager alias_manager: AliasManager
def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: def register_relation_on_build(table_name: str, field: ForeignKey) -> None:
child_relation_name = ( relationship_manager.add_relation_type(field, table_name)
field.to.get_name(title=True)
+ "_"
+ (field.related_name or (name.lower() + "s"))
)
reverse_name = child_relation_name
relation_name = name.lower().title() + "_" + field.to.get_name()
relationship_manager.add_relation_type(
relation_name, reverse_name, field, table_name
)
def expand_reverse_relationships(model: Type["Model"]) -> None: def expand_reverse_relationships(model: Type["Model"]) -> None:
@ -64,15 +55,10 @@ def register_reverse_model_fields(
def sqlalchemy_columns_from_model_fields( def sqlalchemy_columns_from_model_fields(
name: str, object_dict: Dict, table_name: str model_fields: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: ) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
columns = [] columns = []
pkname = None pkname = None
model_fields = {
field_name: field
for field_name, field in object_dict["__annotations__"].items()
if issubclass(field, BaseField)
}
for field_name, field in model_fields.items(): for field_name, field in model_fields.items():
if field.primary_key: if field.primary_key:
if pkname is not None: if pkname is not None:
@ -83,9 +69,9 @@ def sqlalchemy_columns_from_model_fields(
if not field.pydantic_only: if not field.pydantic_only:
columns.append(field.get_column(field_name)) columns.append(field.get_column(field_name))
if issubclass(field, ForeignKeyField): if issubclass(field, ForeignKeyField):
register_relation_on_build(table_name, field, name) register_relation_on_build(table_name, field)
return pkname, columns, model_fields return pkname, columns
def populate_pydantic_default_values(attrs: Dict) -> Dict: def populate_pydantic_default_values(attrs: Dict) -> Dict:
@ -125,21 +111,29 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
attrs["__annotations__"] = annotations attrs["__annotations__"] = annotations
attrs = populate_pydantic_default_values(attrs) attrs = populate_pydantic_default_values(attrs)
attrs["__module__"] = attrs["__module__"] or bases[0].__module__
attrs["__annotations__"] = (
attrs["__annotations__"] or bases[0].__annotations__
)
tablename = name.lower() + "s" tablename = name.lower() + "s"
new_model.Meta.tablename = new_model.Meta.tablename or tablename new_model.Meta.tablename = new_model.Meta.tablename or tablename
# sqlalchemy table creation # sqlalchemy table creation
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( model_fields = {
name, attrs, new_model.Meta.tablename field_name: field
) for field_name, field in attrs["__annotations__"].items()
if issubclass(field, BaseField)
}
if hasattr(new_model.Meta, "model_fields") and not pkname: if hasattr(new_model.Meta, "columns"):
model_fields = new_model.Meta.model_fields
for fieldname, field in new_model.Meta.model_fields.items():
if field.primary_key:
pkname = fieldname
columns = new_model.Meta.table.columns columns = new_model.Meta.table.columns
pkname = new_model.Meta.pkname
else:
pkname, columns = sqlalchemy_columns_from_model_fields(
model_fields, new_model.Meta.tablename
)
if not hasattr(new_model.Meta, "table"): if not hasattr(new_model.Meta, "table"):
new_model.Meta.table = sqlalchemy.Table( new_model.Meta.table = sqlalchemy.Table(
@ -153,10 +147,11 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
raise ModelDefinitionError("Table has to have a primary key.") raise ModelDefinitionError("Table has to have a primary key.")
new_model.Meta.model_fields = model_fields new_model.Meta.model_fields = model_fields
expand_reverse_relationships(new_model)
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
) )
expand_reverse_relationships(new_model)
new_model.Meta.alias_manager = relationship_manager new_model.Meta.alias_manager = relationship_manager
new_model.objects = QuerySet(new_model) new_model.objects = QuerySet(new_model)

View File

@ -69,7 +69,8 @@ class Model(NewBaseModel):
async def save(self) -> "Model": async def save(self) -> "Model":
self_fields = self._extract_model_db_fields() self_fields = self._extract_model_db_fields()
if self.Meta.model_fields.get(self.Meta.pkname).autoincrement:
if not self.pk and self.Meta.model_fields.get(self.Meta.pkname).autoincrement:
self_fields.pop(self.Meta.pkname, None) self_fields.pop(self.Meta.pkname, None)
expr = self.Meta.table.insert() expr = self.Meta.table.insert()
expr = expr.values(**self_fields) expr = expr.values(**self_fields)
@ -77,7 +78,7 @@ class Model(NewBaseModel):
setattr(self, self.Meta.pkname, item_id) setattr(self, self.Meta.pkname, item_id)
return self return self
async def update(self, **kwargs: Any) -> int: async def update(self, **kwargs: Any) -> "Model":
if kwargs: if kwargs:
new_values = {**self.dict(), **kwargs} new_values = {**self.dict(), **kwargs}
self.from_dict(new_values) self.from_dict(new_values)
@ -89,8 +90,8 @@ class Model(NewBaseModel):
.values(**self_fields) .values(**self_fields)
.where(self.pk_column == getattr(self, self.Meta.pkname)) .where(self.pk_column == getattr(self, self.Meta.pkname))
) )
result = await self.Meta.database.execute(expr) await self.Meta.database.execute(expr)
return result return self
async def delete(self) -> int: async def delete(self) -> int:
expr = self.Meta.table.delete() expr = self.Meta.table.delete()

View File

@ -24,7 +24,6 @@ class ModelTableProxy:
@classmethod @classmethod
def substitute_models_with_pks(cls, model_dict: dict) -> dict: def substitute_models_with_pks(cls, model_dict: dict) -> dict:
model_dict = copy.deepcopy(model_dict)
for field in cls._extract_related_names(): for field in cls._extract_related_names():
if field in model_dict and model_dict.get(field) is not None: if field in model_dict and model_dict.get(field) is not None:
target_field = cls.Meta.model_fields[field] target_field = cls.Meta.model_fields[field]
@ -76,10 +75,19 @@ class ModelTableProxy:
} }
for field in self._extract_db_related_names(): for field in self._extract_db_related_names():
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
if getattr(self, field) is not None: target_field = getattr(self, field)
self_fields[field] = getattr(getattr(self, field), target_pk_name) self_fields[field] = getattr(target_field, target_pk_name, None)
return self_fields return self_fields
@staticmethod
def resolve_relation_name(item: "Model", related: "Model"):
for name, field in item.Meta.model_fields.items():
if issubclass(field, ForeignKeyField):
# fastapi is creating clones of response model that's why it can be a subclass
# of the original one so we need to compare Meta too
if field.to == related.__class__ or field.to.Meta == related.Meta:
return name
@classmethod @classmethod
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
merged_rows = [] merged_rows = []

View File

@ -71,9 +71,14 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
pk_only = kwargs.pop("__pk_only__", False) pk_only = kwargs.pop("__pk_only__", False)
if "pk" in kwargs: if "pk" in kwargs:
kwargs[self.Meta.pkname] = kwargs.pop("pk") kwargs[self.Meta.pkname] = kwargs.pop("pk")
# build the models to set them and validate but don't register
kwargs = { kwargs = {
k: self._convert_json( k: self._convert_json(
k, self.Meta.model_fields[k].expand_relationship(v, self), "dumps" k,
self.Meta.model_fields[k].expand_relationship(
v, self, to_register=False
),
"dumps",
) )
for k, v in kwargs.items() for k, v in kwargs.items()
} }
@ -85,13 +90,20 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
object.__setattr__(self, "__dict__", values) object.__setattr__(self, "__dict__", values)
object.__setattr__(self, "__fields_set__", fields_set) object.__setattr__(self, "__fields_set__", fields_set)
# register the related models after initialization
for related in self._extract_related_names():
self.Meta.model_fields[related].expand_relationship(
kwargs.get(related), self, to_register=True
)
def __setattr__(self, name: str, value: Any) -> None: def __setattr__(self, name: str, value: Any) -> None:
if name in self.__slots__: if name in self.__slots__:
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
elif name == "pk": elif name == "pk":
object.__setattr__(self, self.Meta.pkname, value) object.__setattr__(self, self.Meta.pkname, value)
elif name in self._orm: elif name in self._orm:
self.Meta.model_fields[name].expand_relationship(value, self) model = self.Meta.model_fields[name].expand_relationship(value, self)
self.__dict__[name] = model
else: else:
value = ( value = (
self._convert_json(name, value, "dumps") self._convert_json(name, value, "dumps")
@ -113,19 +125,13 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
return value return value
return super().__getattribute__(item) return super().__getattribute__(item)
# def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]:
# return self._extract_related_model_instead_of_field(item)
def _extract_related_model_instead_of_field( def _extract_related_model_instead_of_field(
self, item: str self, item: str
) -> Optional[Union["Model", List["Model"]]]: ) -> Optional[Union["Model", List["Model"]]]:
# relation_key = self.get_name(title=True) + "_" + item
if item in self._orm: if item in self._orm:
return self._orm.get(item) return self._orm.get(item)
def __same__(self, other: "Model") -> bool: def __same__(self, other: "Model") -> bool:
if self.__class__ != other.__class__: # pragma no cover
return False
return ( return (
self._orm_id == other._orm_id self._orm_id == other._orm_id
or self.__dict__ == other.__dict__ or self.__dict__ == other.__dict__
@ -137,8 +143,6 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
name = cls.__name__ name = cls.__name__
if lower: if lower:
name = name.lower() name = name.lower()
if title:
name = name.title()
return name return name
@property @property
@ -149,6 +153,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
def pk_type(cls) -> Any: def pk_type(cls) -> Any:
return cls.Meta.model_fields[cls.Meta.pkname].__type__ return cls.Meta.model_fields[cls.Meta.pkname].__type__
def remove(self, name: "Model"):
self._orm.remove_parent(self, name)
def dict( # noqa A003 def dict( # noqa A003
self, self,
*, *,
@ -176,14 +183,23 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if self.Meta.model_fields[field].virtual and nested: if self.Meta.model_fields[field].virtual and nested:
continue continue
if isinstance(nested_model, list): if isinstance(nested_model, list):
dict_instance[field] = [x.dict(nested=True) for x in nested_model] result = []
for model in nested_model:
try:
result.append(model.dict(nested=True))
except ReferenceError: # pragma no cover
continue
dict_instance[field] = result
elif nested_model is not None: elif nested_model is not None:
dict_instance[field] = nested_model.dict(nested=True) dict_instance[field] = nested_model.dict(nested=True)
else:
dict_instance[field] = None
return dict_instance return dict_instance
def from_dict(self, value_dict: Dict) -> None: def from_dict(self, value_dict: Dict) -> "Model":
for key, value in value_dict.items(): for key, value in value_dict.items():
setattr(self, key, value) setattr(self, key, value)
return self
def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]: def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]:
if not self._is_conversion_to_json_needed(column_name): if not self._is_conversion_to_json_needed(column_name):

View File

@ -69,10 +69,11 @@ class Query:
# print(expr.compile(compile_kwargs={"literal_binds": True})) # print(expr.compile(compile_kwargs={"literal_binds": True}))
self._reset_query_parameters() self._reset_query_parameters()
return expr, self._select_related return expr
@staticmethod
def on_clause( def on_clause(
self, previous_alias: str, alias: str, from_clause: str, to_clause: str, previous_alias: str, alias: str, from_clause: str, to_clause: str,
) -> text: ) -> text:
left_part = f"{alias}_{to_clause}" left_part = f"{alias}_{to_clause}"
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"

View File

@ -47,7 +47,7 @@ class QuerySet:
offset=self.query_offset, offset=self.query_offset,
limit_count=self.limit_count, limit_count=self.limit_count,
) )
exp, self._select_related = qry.build_select_expression() exp = qry.build_select_expression()
return exp return exp
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
@ -118,15 +118,25 @@ class QuerySet:
async def get(self, **kwargs: Any) -> "Model": async def get(self, **kwargs: Any) -> "Model":
if kwargs: if kwargs:
return await self.filter(**kwargs).get() return await self.filter(**kwargs).get()
else:
if not self.filter_clauses:
expr = self.build_select_expression().limit(2) expr = self.build_select_expression().limit(2)
else:
expr = self.build_select_expression()
rows = await self.database.fetch_all(expr) rows = await self.database.fetch_all(expr)
result_rows = [
self.model_cls.from_row(row, select_related=self._select_related)
for row in rows
]
rows = self.model_cls.merge_instances_list(result_rows)
if not rows: if not rows:
raise NoMatch() raise NoMatch()
if len(rows) > 1: if len(rows) > 1:
raise MultipleMatches() raise MultipleMatches()
return self.model_cls.from_row(rows[0], select_related=self._select_related) return rows[0]
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
if kwargs: if kwargs:

View File

@ -2,12 +2,13 @@ import string
import uuid import uuid
from enum import Enum from enum import Enum
from random import choices from random import choices
from typing import List, TYPE_CHECKING, Type from typing import List, TYPE_CHECKING, Type, Union, Optional
from weakref import proxy from weakref import proxy
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
import ormar
from ormar.exceptions import RelationshipInstanceError from ormar.exceptions import RelationshipInstanceError
from ormar.fields.foreign_key import ForeignKeyField # noqa I100 from ormar.fields.foreign_key import ForeignKeyField # noqa I100
@ -26,7 +27,6 @@ class RelationType(Enum):
class AliasManager: class AliasManager:
def __init__(self) -> None: def __init__(self) -> None:
self._relations = dict()
self._aliases = dict() self._aliases = dict()
@staticmethod @staticmethod
@ -40,54 +40,83 @@ class AliasManager:
def prefixed_table_name(alias: str, name: str) -> text: def prefixed_table_name(alias: str, name: str) -> text:
return text(f"{name} {alias}_{name}") return text(f"{name} {alias}_{name}")
def add_relation_type( def add_relation_type(self, field: ForeignKeyField, table_name: str,) -> None:
self, if f"{table_name}_{field.to.Meta.tablename}" not in self._aliases:
relations_key: str,
reverse_key: str,
field: ForeignKeyField,
table_name: str,
) -> None:
if relations_key not in self._relations:
self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias() self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias()
if reverse_key not in self._relations: if f"{field.to.Meta.tablename}_{table_name}" not in self._aliases:
self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias() self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias()
def resolve_relation_join(self, from_table: str, to_table: str) -> str: def resolve_relation_join(self, from_table: str, to_table: str) -> str:
return self._aliases.get(f"{from_table}_{to_table}", "") return self._aliases.get(f"{from_table}_{to_table}", "")
class Relation: class RelationProxy(list):
def __init__(self, type_: RelationType) -> None: def __init__(self, relation: "Relation"):
self._type = type_ super(RelationProxy, self).__init__()
self.related_models = [] if type_ == RelationType.REVERSE else None self.relation = relation
self._owner = self.relation.manager.owner
def _find_existing(self, child): def remove(self, item: "Model"):
for ind, relation_child in enumerate(self.related_models): super().remove(item)
rel_name = item.resolve_relation_name(item, self._owner)
item._orm._get(rel_name).remove(self._owner)
def append(self, item: "Model"):
super().append(item)
def add(self, item):
rel_name = item.resolve_relation_name(item, self._owner)
setattr(item, rel_name, self._owner)
class Relation:
def __init__(self, manager: "RelationsManager", type_: RelationType) -> None:
self.manager = manager
self._owner = manager.owner
self._type = type_
self.related_models = (
RelationProxy(relation=self) if type_ == RelationType.REVERSE else None
)
def _find_existing(self, child) -> Optional[int]:
for ind, relation_child in enumerate(self.related_models[:]):
try: try:
if relation_child.__same__(child): if relation_child.__same__(child):
return ind return ind
except ReferenceError: # pragma no cover except ReferenceError: # pragma no cover
continue self.related_models.pop(ind)
return None return None
def add(self, child: "Model") -> None: def add(self, child: "Model") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY: if self._type == RelationType.PRIMARY:
self.related_models = child self.related_models = child
self._owner.__dict__[relation_name] = child
else: else:
if self._find_existing(child) is None: if self._find_existing(child) is None:
self.related_models.append(child) self.related_models.append(child)
rel = self._owner.__dict__.get(relation_name, [])
rel.append(child)
self._owner.__dict__[relation_name] = rel
# def remove(self, child: "Model") -> None: def remove(self, child: "Model") -> None:
# if self._type == RelationType.PRIMARY: relation_name = self._owner.resolve_relation_name(self._owner, child)
# self.related_models = None if self._type == RelationType.PRIMARY:
# else: if self.related_models.__same__(child):
# position = self._find_existing(child) self.related_models = None
# if position is not None: del self._owner.__dict__[relation_name]
# self.related_models.pop(position) else:
position = self._find_existing(child)
if position is not None:
self.related_models.pop(position)
del self._owner.__dict__[relation_name][position]
def get(self): def get(self) -> Union[List["Model"], "Model"]:
return self.related_models return self.related_models
def __repr__(self): # pragma no cover
return str(self.related_models)
class RelationsManager: class RelationsManager:
def __init__( def __init__(
@ -98,21 +127,23 @@ class RelationsManager:
self._related_names = [field.name for field in self._related_fields] self._related_names = [field.name for field in self._related_fields]
self._relations = dict() self._relations = dict()
for field in self._related_fields: for field in self._related_fields:
self._add_relation(field)
def _add_relation(self, field):
self._relations[field.name] = Relation( self._relations[field.name] = Relation(
type_=RelationType.PRIMARY manager=self,
if not field.virtual type_=RelationType.PRIMARY if not field.virtual else RelationType.REVERSE,
else RelationType.REVERSE
) )
def __contains__(self, item): def __contains__(self, item):
return item in self._related_names return item in self._related_names
def get(self, name): def get(self, name) -> Optional[Union[List["Model"], "Model"]]:
relation = self._relations.get(name, None) relation = self._relations.get(name, None)
if relation: if relation:
return relation.get() return relation.get()
def _get(self, name): def _get(self, name) -> Optional[Relation]:
relation = self._relations.get(name, None) relation = self._relations.get(name, None)
if relation: if relation:
return relation return relation
@ -122,7 +153,7 @@ class RelationsManager:
( (
field field
for field in child._orm._related_fields for field in child._orm._related_fields
if field.to == parent.__class__ if field.to == parent.__class__ or field.to.Meta == parent.Meta
), ),
None, None,
) )
@ -140,5 +171,25 @@ class RelationsManager:
child_name = child_name or child.get_name() + "s" child_name = child_name or child.get_name() + "s"
child = proxy(child) child = proxy(child)
parent._orm._get(child_name).add(child) parent_relation = parent._orm._get(child_name)
if not parent_relation:
ormar.models.expand_reverse_relationships(child.__class__)
name = parent.resolve_relation_name(parent, child)
field = parent.Meta.model_fields[name]
parent._orm._add_relation(field)
parent_relation = parent._orm._get(child_name)
parent_relation.add(child)
child._orm._get(to_name).add(parent) child._orm._get(to_name).add(parent)
def remove(self, name: str, child: "Model"):
relation = self._get(name)
relation.remove(child)
@staticmethod
def remove_parent(item: "Model", name: Union[str, "Model"]):
related_model = name
name = item.resolve_relation_name(item, related_model)
if name in item._orm:
relation_name = item.resolve_relation_name(related_model, item)
item._orm.remove(name, related_model)
related_model._orm.remove(relation_name, item)

View File

@ -22,7 +22,7 @@ class Example(ormar.Model):
database = database database = database
id: ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=200, default='aaa') name: ormar.String(max_length=200, default="aaa")
created: ormar.DateTime(default=datetime.datetime.now) created: ormar.DateTime(default=datetime.datetime.now)
created_day: ormar.Date(default=datetime.date.today) created_day: ormar.Date(default=datetime.date.today)
created_time: ormar.Time(default=time) created_time: ormar.Time(default=time)

View File

@ -1,11 +1,11 @@
import gc
import databases import databases
import pytest import pytest
import sqlalchemy import sqlalchemy
from pydantic import ValidationError
import ormar import ormar
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
from ormar.fields.foreign_key import ForeignKeyField
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
@ -133,7 +133,9 @@ async def test_model_crud():
assert album1.pk == 1 assert album1.pk == 1
assert album1.tracks == [] assert album1.tracks == []
await Track.objects.create(album={"id": track.album.pk}, title="The Bird2", position=4) await Track.objects.create(
album={"id": track.album.pk}, title="The Bird2", position=4
)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -164,6 +166,47 @@ async def test_select_related():
assert len(tracks) == 6 assert len(tracks) == 6
@pytest.mark.asyncio
async def test_model_removal_from_relations():
async with database:
album = Album(name="Chichi")
await album.save()
track1 = Track(album=album, title="The Birdman", position=1)
track2 = Track(album=album, title="Superman", position=2)
track3 = Track(album=album, title="Wonder Woman", position=3)
await track1.save()
await track2.save()
await track3.save()
assert len(album.tracks) == 3
album.tracks.remove(track1)
assert len(album.tracks) == 2
assert track1.album is None
await track1.update()
track1 = await Track.objects.get(title="The Birdman")
assert track1.album is None
album.tracks.add(track1)
assert len(album.tracks) == 3
assert track1.album == album
await track1.update()
track1 = await Track.objects.select_related("album__tracks").get(
title="The Birdman"
)
album = await Album.objects.select_related("tracks").get(name="Chichi")
assert track1.album == album
track1.remove(album)
assert track1.album is None
assert len(album.tracks) == 2
track2.remove(album)
assert track2.album is None
assert len(album.tracks) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fk_filter(): async def test_fk_filter():
async with database: async with database:

View File

@ -54,7 +54,9 @@ class ExampleModel2(Model):
@pytest.fixture() @pytest.fixture()
def example(): def example():
return ExampleModel(pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5)) return ExampleModel(
pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5)
)
def test_not_nullable_field_is_required(): def test_not_nullable_field_is_required():
@ -110,6 +112,7 @@ def test_sqlalchemy_table_is_created(example):
def test_no_pk_in_model_definition(): def test_no_pk_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example3" tablename = "example3"
@ -120,6 +123,7 @@ def test_no_pk_in_model_definition():
def test_two_pks_in_model_definition(): def test_two_pks_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example3" tablename = "example3"
@ -131,6 +135,7 @@ def test_two_pks_in_model_definition():
def test_setting_pk_column_as_pydantic_only_in_model_definition(): def test_setting_pk_column_as_pydantic_only_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example4" tablename = "example4"
@ -141,6 +146,7 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition():
def test_decimal_error_in_model_definition(): def test_decimal_error_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example5" tablename = "example5"
@ -151,6 +157,7 @@ def test_decimal_error_in_model_definition():
def test_string_error_in_model_definition(): def test_string_error_in_model_definition():
with pytest.raises(ModelDefinitionError): with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model): class ExampleModel2(Model):
class Meta: class Meta:
tablename = "example6" tablename = "example6"

View File

@ -28,7 +28,7 @@ class User(ormar.Model):
database = database database = database
id: ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100, default='') name: ormar.String(max_length=100, default="")
class Product(ormar.Model): class Product(ormar.Model):

View File

@ -79,7 +79,7 @@ async def create_category(category: Category):
@app.put("/items/{item_id}") @app.put("/items/{item_id}")
async def get_item(item_id: int, item: Item): async def get_item(item_id: int, item: Item):
item_db = await Item.objects.get(pk=item_id) item_db = await Item.objects.get(pk=item_id)
return {"updated_rows": await item_db.update(**item.dict())} return await item_db.update(**item.dict())
@app.delete("/items/{item_id}") @app.delete("/items/{item_id}")
@ -105,7 +105,7 @@ def test_all_endpoints():
item.name = "New name" item.name = "New name"
response = client.put(f"/items/{item.pk}", json=item.dict()) response = client.put(f"/items/{item.pk}", json=item.dict())
assert response.json().get("updated_rows") == 1 assert response.json() == item.dict()
response = client.get("/items/") response = client.get("/items/")
items = [Item(**item) for item in response.json()] items = [Item(**item) for item in response.json()]

View File

@ -101,10 +101,10 @@ async def test_model_multiple_instances_of_same_table_in_schema():
assert classes[0].name == "Math" assert classes[0].name == "Math"
assert classes[0].students[0].name == "Jane" assert classes[0].students[0].name == "Jane"
assert len(classes[0].dict().get("students")) == 2 assert len(classes[0].dict().get("students")) == 2
assert classes[0].teachers[0].category.department.name == 'Law Department' assert classes[0].teachers[0].category.department.name == "Law Department"
assert classes[0].students[0].category.pk is not None assert classes[0].students[0].category.pk is not None
assert classes[0].students[0].category.name is None assert classes[0].students[0].category.name is None
await classes[0].students[0].category.load() await classes[0].students[0].category.load()
await classes[0].students[0].category.department.load() await classes[0].students[0].category.department.load()
assert classes[0].students[0].category.department.name == 'Math Department' assert classes[0].students[0].category.department.name == "Math Department"