Merge pull request #3 from collerek/relations

Relations
This commit is contained in:
collerek
2020-08-28 21:24:06 +07:00
committed by GitHub
23 changed files with 698 additions and 392 deletions

BIN
.coverage

Binary file not shown.

View File

@ -76,16 +76,16 @@ Since you join two times to the same table (categories) it won't work by default
But don't worry - ormar can handle situations like this, as it uses the Relationship Manager which has it's aliases defined for all relationships.
Each class is registered with the same instance of the RelationshipManager that you can access like this:
Each class is registered with the same instance of the AliasManager that you can access like this:
```python
SchoolClass._orm_relationship_manager
SchoolClass.alias_manager
```
It's the same object for all `Models`
```python
print(Teacher._orm_relationship_manager == Student._orm_relationship_manager)
print(Teacher.alias_manager == Student.alias_manager)
# will produce: True
```
@ -94,11 +94,11 @@ print(Teacher._orm_relationship_manager == Student._orm_relationship_manager)
You can even preview the alias used for any relation by passing two tables names.
```python
print(Teacher._orm_relationship_manager.resolve_relation_join(
print(Teacher.alias_manager.resolve_relation_join(
'students', 'categories'))
# will produce: KId1c6 (sample value)
print(Teacher._orm_relationship_manager.resolve_relation_join(
print(Teacher.alias_manager.resolve_relation_join(
'categories', 'students'))
# will produce: EFccd5 (sample value)
```

View File

@ -36,6 +36,6 @@ print('department' in course.__dict__)
# False <- related model is not stored on Course instance
print(course.department)
# Department(id=None, name='Science') <- Department model
# returned from RelationshipManager
# returned from AliasManager
print(course.department.name)
# Science

View File

@ -64,6 +64,6 @@ class BaseField:
@classmethod
def expand_relationship(
cls, value: Any, child: Union["Model", "NewBaseModel"]
cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True
) -> Any:
return value

View File

@ -54,6 +54,7 @@ def ForeignKey(
class ForeignKeyField(BaseField):
to: Type["Model"]
name: str
related_name: str
virtual: bool
@ -65,36 +66,35 @@ class ForeignKeyField(BaseField):
def validate(cls, value: Any) -> Any:
return value
# @property
# def __type__(self) -> Type[BaseModel]:
# return self.to.__pydantic_model__
# @classmethod
# def get_column_type(cls) -> sqlalchemy.Column:
# to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname]
# return to_column.column_type
@classmethod
def _extract_model_from_sequence(
cls, value: List, child: "Model"
cls, value: List, child: "Model", to_register: bool
) -> Union["Model", List["Model"]]:
return [cls.expand_relationship(val, child) for val in value]
return [cls.expand_relationship(val, child, to_register) for val in value]
@classmethod
def _register_existing_model(cls, value: "Model", child: "Model") -> "Model":
cls.register_relation(value, child)
def _register_existing_model(
cls, value: "Model", child: "Model", to_register: bool
) -> "Model":
if to_register:
cls.register_relation(value, child)
return value
@classmethod
def _construct_model_from_dict(cls, value: dict, child: "Model") -> "Model":
def _construct_model_from_dict(
cls, value: dict, child: "Model", to_register: bool
) -> "Model":
if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname:
value["__pk_only__"] = True
model = cls.to(**value)
cls.register_relation(model, child)
if to_register:
cls.register_relation(model, child)
return model
@classmethod
def _construct_model_from_pk(cls, value: Any, child: "Model") -> "Model":
def _construct_model_from_pk(
cls, value: Any, child: "Model", to_register: bool
) -> "Model":
if not isinstance(value, cls.to.pk_type()):
raise RelationshipInstanceError(
f"Relationship error - ForeignKey {cls.to.__name__} "
@ -102,19 +102,19 @@ class ForeignKeyField(BaseField):
f"while {type(value)} passed as a parameter."
)
model = create_dummy_instance(fk=cls.to, pk=value)
cls.register_relation(model, child)
if to_register:
cls.register_relation(model, child)
return model
@classmethod
def register_relation(cls, model: "Model", child: "Model") -> None:
child_model_name = cls.related_name or child.get_name()
model.Meta._orm_relationship_manager.add_relation(
model, child, child_model_name, virtual=cls.virtual
model._orm.add(
parent=model, child=child, child_name=cls.related_name, virtual=cls.virtual
)
@classmethod
def expand_relationship(
cls, value: Any, child: "Model"
cls, value: Any, child: "Model", to_register: bool = True
) -> Optional[Union["Model", List["Model"]]]:
if value is None:
return None
@ -127,5 +127,5 @@ class ForeignKeyField(BaseField):
model = constructors.get(
value.__class__.__name__, cls._construct_model_from_pk
)(value, child)
)(value, child, to_register)
return model

View File

@ -57,7 +57,7 @@ class String(ModelFieldFactory):
_bases = (pydantic.ConstrainedStr, BaseField)
_type = str
def __new__(
def __new__( # noqa CFQ002
cls,
*,
allow_blank: bool = False,
@ -231,7 +231,7 @@ class Decimal(ModelFieldFactory):
_bases = (pydantic.ConstrainedDecimal, BaseField)
_type = decimal.Decimal
def __new__(
def __new__( # noqa CFQ002
cls,
*,
minimum: float = None,

View File

@ -1,4 +1,5 @@
from ormar.models.model import Model
from ormar.models.newbasemodel import NewBaseModel
from ormar.models.newbasemodel import NewBaseModel # noqa I100
from ormar.models.model import Model # noqa I100
from ormar.models.metaclass import expand_reverse_relationships # noqa I100
__all__ = ["NewBaseModel", "Model"]
__all__ = ["NewBaseModel", "Model", "expand_reverse_relationships"]

View File

@ -10,12 +10,12 @@ from ormar import ForeignKey, ModelDefinitionError # noqa I100
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.queryset import QuerySet
from ormar.relations import RelationshipManager
from ormar.relations import AliasManager
if TYPE_CHECKING: # pragma no cover
from ormar import Model
relationship_manager = RelationshipManager()
alias_manager = AliasManager()
class ModelMeta:
@ -26,19 +26,19 @@ class ModelMeta:
columns: List[sqlalchemy.Column]
pkname: str
model_fields: Dict[str, Union[BaseField, ForeignKey]]
_orm_relationship_manager: RelationshipManager
alias_manager: AliasManager
def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None:
child_relation_name = (
field.to.get_name(title=True)
+ "_"
+ (field.related_name or (name.lower() + "s"))
)
reverse_name = child_relation_name
relation_name = name.lower().title() + "_" + field.to.get_name()
relationship_manager.add_relation_type(
relation_name, reverse_name, field, table_name
def register_relation_on_build(table_name: str, field: ForeignKey) -> None:
alias_manager.add_relation_type(field, table_name)
def reverse_field_not_already_registered(
child: Type["Model"], child_model_name: str, parent_model: Type["Model"]
) -> bool:
return (
child_model_name not in parent_model.__fields__
and child.get_name() not in parent_model.__fields__
)
@ -48,9 +48,8 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
child_model_name = model_field.related_name or model.get_name() + "s"
parent_model = model_field.to
child = model
if (
child_model_name not in parent_model.__fields__
and child.get_name() not in parent_model.__fields__
if reverse_field_not_already_registered(
child, child_model_name, parent_model
):
register_reverse_model_fields(parent_model, child, child_model_name)
@ -63,29 +62,42 @@ def register_reverse_model_fields(
)
def check_pk_column_validity(
field_name: str, field: BaseField, pkname: str
) -> Optional[str]:
if pkname is not None:
raise ModelDefinitionError("Only one primary key column is allowed.")
if field.pydantic_only:
raise ModelDefinitionError("Primary key column cannot be pydantic only")
return field_name
def sqlalchemy_columns_from_model_fields(
name: str, object_dict: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
model_fields: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
columns = []
pkname = None
model_fields = {
field_name: field
for field_name, field in object_dict["__annotations__"].items()
if issubclass(field, BaseField)
}
for field_name, field in model_fields.items():
if field.primary_key:
if pkname is not None:
raise ModelDefinitionError("Only one primary key column is allowed.")
if field.pydantic_only:
raise ModelDefinitionError("Primary key column cannot be pydantic only")
pkname = field_name
pkname = check_pk_column_validity(field_name, field, pkname)
if not field.pydantic_only:
columns.append(field.get_column(field_name))
if issubclass(field, ForeignKeyField):
register_relation_on_build(table_name, field, name)
register_relation_on_build(table_name, field)
return pkname, columns, model_fields
return pkname, columns
def populate_default_pydantic_field_value(
type_: Type[BaseField], field: str, attrs: dict
) -> dict:
def_value = type_.default_value()
curr_def_value = attrs.get(field, "NONE")
if curr_def_value == "NONE" and isinstance(def_value, FieldInfo):
attrs[field] = def_value
elif curr_def_value == "NONE" and type_.nullable:
attrs[field] = FieldInfo(default=None)
return attrs
def populate_pydantic_default_values(attrs: Dict) -> Dict:
@ -93,20 +105,70 @@ def populate_pydantic_default_values(attrs: Dict) -> Dict:
if issubclass(type_, BaseField):
if type_.name is None:
type_.name = field
def_value = type_.default_value()
curr_def_value = attrs.get(field, "NONE")
if curr_def_value == "NONE" and isinstance(def_value, FieldInfo):
attrs[field] = def_value
elif curr_def_value == "NONE" and type_.nullable:
attrs[field] = FieldInfo(default=None)
attrs = populate_default_pydantic_field_value(type_, field, attrs)
return attrs
def extract_annotations_and_module(
attrs: dict, new_model: "ModelMetaclass", bases: Tuple
) -> dict:
annotations = attrs.get("__annotations__") or new_model.__annotations__
attrs["__annotations__"] = annotations
attrs = populate_pydantic_default_values(attrs)
attrs["__module__"] = attrs["__module__"] or bases[0].__module__
attrs["__annotations__"] = attrs["__annotations__"] or bases[0].__annotations__
return attrs
def populate_meta_orm_model_fields(
attrs: dict, new_model: Type["Model"]
) -> Type["Model"]:
model_fields = {
field_name: field
for field_name, field in attrs["__annotations__"].items()
if issubclass(field, BaseField)
}
new_model.Meta.model_fields = model_fields
return new_model
def populate_meta_tablename_columns_and_pk(
name: str, new_model: Type["Model"]
) -> Type["Model"]:
tablename = name.lower() + "s"
new_model.Meta.tablename = new_model.Meta.tablename or tablename
if hasattr(new_model.Meta, "columns"):
columns = new_model.Meta.table.columns
pkname = new_model.Meta.pkname
else:
pkname, columns = sqlalchemy_columns_from_model_fields(
new_model.Meta.model_fields, new_model.Meta.tablename
)
new_model.Meta.columns = columns
new_model.Meta.pkname = pkname
if not new_model.Meta.pkname:
raise ModelDefinitionError("Table has to have a primary key.")
return new_model
def populate_meta_sqlalchemy_table_if_required(
new_model: Type["Model"],
) -> Type["Model"]:
if not hasattr(new_model.Meta, "table"):
new_model.Meta.table = sqlalchemy.Table(
new_model.Meta.tablename, new_model.Meta.metadata, *new_model.Meta.columns
)
return new_model
def get_pydantic_base_orm_config() -> Type[BaseConfig]:
class Config(BaseConfig):
orm_mode = True
arbitrary_types_allowed = True
# extra = Extra.allow
return Config
@ -121,44 +183,17 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
if hasattr(new_model, "Meta"):
annotations = attrs.get("__annotations__") or new_model.__annotations__
attrs["__annotations__"] = annotations
attrs = populate_pydantic_default_values(attrs)
attrs = extract_annotations_and_module(attrs, new_model, bases)
new_model = populate_meta_orm_model_fields(attrs, new_model)
new_model = populate_meta_tablename_columns_and_pk(name, new_model)
new_model = populate_meta_sqlalchemy_table_if_required(new_model)
expand_reverse_relationships(new_model)
tablename = name.lower() + "s"
new_model.Meta.tablename = new_model.Meta.tablename or tablename
# sqlalchemy table creation
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
name, attrs, new_model.Meta.tablename
)
if hasattr(new_model.Meta, "model_fields") and not pkname:
model_fields = new_model.Meta.model_fields
for fieldname, field in new_model.Meta.model_fields.items():
if field.primary_key:
pkname = fieldname
columns = new_model.Meta.table.columns
if not hasattr(new_model.Meta, "table"):
new_model.Meta.table = sqlalchemy.Table(
new_model.Meta.tablename, new_model.Meta.metadata, *columns
)
new_model.Meta.columns = columns
new_model.Meta.pkname = pkname
if not pkname:
raise ModelDefinitionError("Table has to have a primary key.")
new_model.Meta.model_fields = model_fields
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs
)
expand_reverse_relationships(new_model)
new_model.Meta._orm_relationship_manager = relationship_manager
new_model.Meta.alias_manager = alias_manager
new_model.objects = QuerySet(new_model)
return new_model

View File

@ -1,4 +1,5 @@
from typing import Any, List
import itertools
from typing import Any, List, Tuple, Union
import sqlalchemy
@ -6,6 +7,21 @@ import ormar.queryset # noqa I100
from ormar.models import NewBaseModel # noqa I100
def group_related_list(list_: List) -> dict:
test_dict = dict()
grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0])
for key, group in grouped:
group_list = list(group)
new = [
"__".join(x.split("__")[1:]) for x in group_list if len(x.split("__")) > 1
]
if any("__" in x for x in new):
test_dict[key] = group_related_list(new)
else:
test_dict[key] = new
return test_dict
class Model(NewBaseModel):
__abstract__ = False
@ -14,22 +30,44 @@ class Model(NewBaseModel):
cls,
row: sqlalchemy.engine.ResultProxy,
select_related: List = None,
related_models: Any = None,
previous_table: str = None,
) -> "Model":
) -> Union["Model", Tuple["Model", dict]]:
item = {}
select_related = select_related or []
related_models = related_models or []
if select_related:
related_models = group_related_list(select_related)
table_prefix = cls.Meta._orm_relationship_manager.resolve_relation_join(
table_prefix = cls.Meta.alias_manager.resolve_relation_join(
previous_table, cls.Meta.table.name
)
previous_table = cls.Meta.table.name
for related in select_related:
if "__" in related:
first_part, remainder = related.split("__", 1)
item = cls.populate_nested_models_from_row(
item, row, related_models, previous_table
)
item = cls.extract_prefixed_table_columns(item, row, table_prefix)
instance = cls(**item) if item.get(cls.Meta.pkname, None) is not None else None
return instance
@classmethod
def populate_nested_models_from_row(
cls,
item: dict,
row: sqlalchemy.engine.ResultProxy,
related_models: Any,
previous_table: sqlalchemy.Table,
) -> dict:
for related in related_models:
if isinstance(related_models, dict) and related_models[related]:
first_part, remainder = related, related_models[related]
model_cls = cls.Meta.model_fields[first_part].to
child = model_cls.from_row(
row, select_related=[remainder], previous_table=previous_table
row, related_models=remainder, previous_table=previous_table
)
item[first_part] = child
else:
@ -37,17 +75,23 @@ class Model(NewBaseModel):
child = model_cls.from_row(row, previous_table=previous_table)
item[related] = child
return item
@classmethod
def extract_prefixed_table_columns(
cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str
) -> dict:
for column in cls.Meta.table.columns:
if column.name not in item:
item[column.name] = row[
f'{table_prefix + "_" if table_prefix else ""}{column.name}'
]
return cls(**item)
return item
async def save(self) -> "Model":
self_fields = self._extract_model_db_fields()
if self.Meta.model_fields.get(self.Meta.pkname).autoincrement:
if not self.pk and self.Meta.model_fields.get(self.Meta.pkname).autoincrement:
self_fields.pop(self.Meta.pkname, None)
expr = self.Meta.table.insert()
expr = expr.values(**self_fields)
@ -55,20 +99,18 @@ class Model(NewBaseModel):
setattr(self, self.Meta.pkname, item_id)
return self
async def update(self, **kwargs: Any) -> int:
async def update(self, **kwargs: Any) -> "Model":
if kwargs:
new_values = {**self.dict(), **kwargs}
self.from_dict(new_values)
self_fields = self._extract_model_db_fields()
self_fields.pop(self.Meta.pkname)
expr = (
self.Meta.table.update()
.values(**self_fields)
.where(self.pk_column == getattr(self, self.Meta.pkname))
)
result = await self.Meta.database.execute(expr)
return result
expr = self.Meta.table.update().values(**self_fields)
expr = expr.where(self.pk_column == getattr(self, self.Meta.pkname))
await self.Meta.database.execute(expr)
return self
async def delete(self) -> int:
expr = self.Meta.table.delete()

View File

@ -1,6 +1,5 @@
import copy
import inspect
from typing import List, Set, TYPE_CHECKING
from typing import List, Optional, Set, TYPE_CHECKING
import ormar
from ormar.fields.foreign_key import ForeignKeyField
@ -24,15 +23,15 @@ class ModelTableProxy:
@classmethod
def substitute_models_with_pks(cls, model_dict: dict) -> dict:
model_dict = copy.deepcopy(model_dict)
for field in cls._extract_related_names():
if field in model_dict and model_dict.get(field) is not None:
field_value = model_dict.get(field, None)
if field_value is not None:
target_field = cls.Meta.model_fields[field]
target_pkname = target_field.to.Meta.pkname
if isinstance(model_dict.get(field), ormar.Model):
model_dict[field] = getattr(model_dict.get(field), target_pkname)
if isinstance(field_value, ormar.Model):
model_dict[field] = getattr(field_value, target_pkname)
else:
model_dict[field] = model_dict.get(field).get(target_pkname)
model_dict[field] = field_value.get(target_pkname)
return model_dict
@classmethod
@ -43,6 +42,18 @@ class ModelTableProxy:
related_names.add(name)
return related_names
@classmethod
def _extract_db_related_names(cls) -> Set:
related_names = set()
for name, field in cls.Meta.model_fields.items():
if (
inspect.isclass(field)
and issubclass(field, ForeignKeyField)
and not field.virtual
):
related_names.add(name)
return related_names
@classmethod
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
if nested:
@ -62,18 +73,28 @@ class ModelTableProxy:
self_fields = {
k: v for k, v in self_fields.items() if k in self.Meta.table.columns
}
for field in self._extract_related_names():
for field in self._extract_db_related_names():
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
if getattr(self, field) is not None:
self_fields[field] = getattr(getattr(self, field), target_pk_name)
target_field = getattr(self, field)
self_fields[field] = getattr(target_field, target_pk_name, None)
return self_fields
@staticmethod
def resolve_relation_name(item: "Model", related: "Model") -> Optional[str]:
for name, field in item.Meta.model_fields.items():
if issubclass(field, ForeignKeyField):
# fastapi is creating clones of response model
# that's why it can be a subclass of the original model
# so we need to compare Meta too as this one is copied as is
if field.to == related.__class__ or field.to.Meta == related.Meta:
return name
@classmethod
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
merged_rows = []
for index, model in enumerate(result_rows):
if index > 0 and model.pk == result_rows[index - 1].pk:
result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
if index > 0 and model.pk == merged_rows[-1].pk:
merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
else:
merged_rows.append(model)
return merged_rows

View File

@ -20,9 +20,10 @@ from pydantic import BaseModel
import ormar # noqa I100
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta, ModelMetaclass
from ormar.models.modelproxy import ModelTableProxy
from ormar.relations import RelationshipManager
from ormar.relations import AliasManager, RelationsManager
if TYPE_CHECKING: # pragma no cover
from ormar.models.model import Model
@ -34,7 +35,7 @@ if TYPE_CHECKING: # pragma no cover
class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass):
__slots__ = ("_orm_id", "_orm_saved")
__slots__ = ("_orm_id", "_orm_saved", "_orm")
if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, TypeVar[BaseField]]
@ -45,7 +46,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
__tablename__: str
__metadata__: sqlalchemy.MetaData
__database__: databases.Database
_orm_relationship_manager: RelationshipManager
_orm_relationship_manager: AliasManager
_orm: RelationsManager
Meta: ModelMeta
# noinspection PyMissingConstructor
@ -53,13 +55,30 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
object.__setattr__(self, "_orm_id", uuid.uuid4().hex)
object.__setattr__(self, "_orm_saved", False)
object.__setattr__(
self,
"_orm",
RelationsManager(
related_fields=[
field
for name, field in self.Meta.model_fields.items()
if issubclass(field, ForeignKeyField)
],
owner=self,
),
)
pk_only = kwargs.pop("__pk_only__", False)
if "pk" in kwargs:
kwargs[self.Meta.pkname] = kwargs.pop("pk")
# build the models to set them and validate but don't register
kwargs = {
k: self._convert_json(
k, self.Meta.model_fields[k].expand_relationship(v, self), "dumps"
k,
self.Meta.model_fields[k].expand_relationship(
v, self, to_register=False
),
"dumps",
)
for k, v in kwargs.items()
}
@ -71,17 +90,20 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
object.__setattr__(self, "__dict__", values)
object.__setattr__(self, "__fields_set__", fields_set)
def __del__(self) -> None:
self.Meta._orm_relationship_manager.deregister(self)
# register the related models after initialization
for related in self._extract_related_names():
self.Meta.model_fields[related].expand_relationship(
kwargs.get(related), self, to_register=True
)
def __setattr__(self, name: str, value: Any) -> None:
relation_key = self.get_name(title=True) + "_" + name
if name in self.__slots__:
object.__setattr__(self, name, value)
elif name == "pk":
object.__setattr__(self, self.Meta.pkname, value)
elif self.Meta._orm_relationship_manager.contains(relation_key, self):
self.Meta.model_fields[name].expand_relationship(value, self)
elif name in self._orm:
model = self.Meta.model_fields[name].expand_relationship(value, self)
self.__dict__[name] = model
else:
value = (
self._convert_json(name, value, "dumps")
@ -91,28 +113,25 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
super().__setattr__(name, value)
def __getattribute__(self, item: str) -> Any:
if item != "__fields__" and item in self.__fields__:
related = self._extract_related_model_instead_of_field(item)
if related:
return related
value = object.__getattribute__(self, item)
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"):
return object.__getattribute__(self, item)
elif item != "_extract_related_names" and item in self._extract_related_names():
return self._extract_related_model_instead_of_field(item)
elif item == "pk":
return self.__dict__.get(self.Meta.pkname, None)
elif item != "__fields__" and item in self.__fields__:
value = self.__dict__.get(item, None)
value = self._convert_json(item, value, "loads")
return value
return super().__getattribute__(item)
def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]:
return self._extract_related_model_instead_of_field(item)
def _extract_related_model_instead_of_field(
self, item: str
) -> Optional[Union["Model", List["Model"]]]:
relation_key = self.get_name(title=True) + "_" + item
if self.Meta._orm_relationship_manager.contains(relation_key, self):
return self.Meta._orm_relationship_manager.get(relation_key, self)
if item in self._orm:
return self._orm.get(item)
def __same__(self, other: "Model") -> bool:
if self.__class__ != other.__class__: # pragma no cover
return False
return (
self._orm_id == other._orm_id
or self.__dict__ == other.__dict__
@ -124,14 +143,8 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
name = cls.__name__
if lower:
name = name.lower()
if title:
name = name.title()
return name
@property
def pk(self) -> Any:
return getattr(self, self.Meta.pkname)
@property
def pk_column(self) -> sqlalchemy.Column:
return self.Meta.table.primary_key.columns.values()[0]
@ -140,6 +153,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
def pk_type(cls) -> Any:
return cls.Meta.model_fields[cls.Meta.pkname].__type__
def remove(self, name: "Model") -> None:
self._orm.remove_parent(self, name)
def dict( # noqa A003
self,
*,
@ -167,17 +183,25 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
if self.Meta.model_fields[field].virtual and nested:
continue
if isinstance(nested_model, list):
dict_instance[field] = [x.dict(nested=True) for x in nested_model]
result = []
for model in nested_model:
try:
result.append(model.dict(nested=True))
except ReferenceError: # pragma no cover
continue
dict_instance[field] = result
elif nested_model is not None:
dict_instance[field] = nested_model.dict(nested=True)
else:
dict_instance[field] = None
return dict_instance
def from_dict(self, value_dict: Dict) -> None:
def from_dict(self, value_dict: Dict) -> "Model":
for key, value in value_dict.items():
setattr(self, key, value)
return self
def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]:
if not self._is_conversion_to_json_needed(column_name):
return value

View File

@ -37,13 +37,21 @@ class QueryClause:
def filter( # noqa: A003
self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
filter_clauses = self.filter_clauses
select_related = list(self._select_related)
if kwargs.get("pk"):
pk_name = self.model_cls.Meta.pkname
kwargs[pk_name] = kwargs.pop("pk")
filter_clauses, select_related = self._populate_filter_clauses(**kwargs)
return filter_clauses, select_related
def _populate_filter_clauses(
self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
filter_clauses = self.filter_clauses
select_related = list(self._select_related)
for key, value in kwargs.items():
table_prefix = ""
if "__" in key:
@ -73,24 +81,36 @@ class QueryClause:
column = self.table.columns[key]
table = self.table
value, has_escaped_character = self._escape_characters_in_clause(op, value)
if isinstance(value, ormar.Model):
value = value.pk
op_attr = FILTER_OPERATORS[op]
clause = getattr(column, op_attr)(value)
clause = self._compile_clause(
clause,
column,
table,
table_prefix,
modifiers={"escape": "\\" if has_escaped_character else None},
clause = self._process_column_clause_for_operator_and_value(
value, op, column, table, table_prefix
)
filter_clauses.append(clause)
return filter_clauses, select_related
def _process_column_clause_for_operator_and_value(
self,
value: Any,
op: str,
column: sqlalchemy.Column,
table: sqlalchemy.Table,
table_prefix: str,
) -> sqlalchemy.sql.expression.TextClause:
value, has_escaped_character = self._escape_characters_in_clause(op, value)
if isinstance(value, ormar.Model):
value = value.pk
op_attr = FILTER_OPERATORS[op]
clause = getattr(column, op_attr)(value)
clause = self._compile_clause(
clause,
column,
table,
table_prefix,
modifiers={"escape": "\\" if has_escaped_character else None},
)
return clause
def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str]
) -> Tuple[List[str], str, "Model"]:
@ -109,7 +129,7 @@ class QueryClause:
previous_table = model_cls.Meta.tablename
for part in related_parts:
current_table = model_cls.Meta.model_fields[part].to.Meta.tablename
manager = model_cls.Meta._orm_relationship_manager
manager = model_cls.Meta.alias_manager
table_prefix = manager.resolve_relation_join(previous_table, current_table)
model_cls = model_cls.Meta.model_fields[part].to
previous_table = current_table

View File

@ -5,8 +5,7 @@ from sqlalchemy import text
import ormar # noqa I100
from ormar.fields.foreign_key import ForeignKeyField
from ormar.queryset.relationship_crawler import RelationshipCrawler
from ormar.relations import RelationshipManager
from ormar.relations import AliasManager
if TYPE_CHECKING: # pragma no cover
from ormar import Model
@ -44,22 +43,19 @@ class Query:
self.order_bys = None
@property
def relation_manager(self) -> RelationshipManager:
return self.model_cls.Meta._orm_relationship_manager
def relation_manager(self) -> AliasManager:
return self.model_cls.Meta.alias_manager
@property
def prefixed_pk_name(self) -> str:
return f"{self.table.name}.{self.model_cls.Meta.pkname}"
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
self.columns = list(self.table.columns)
self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")]
self.order_bys = [text(self.prefixed_pk_name)]
self.select_from = self.table
start_params = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls
)
self._select_related = RelationshipCrawler().discover_relations(
self._select_related, prev_model=start_params.prev_model
)
self._select_related.sort(key=lambda item: (-len(item), item))
self._select_related.sort(key=lambda item: (item, -len(item)))
for item in self._select_related:
join_parameters = JoinParameters(
@ -77,10 +73,11 @@ class Query:
# print(expr.compile(compile_kwargs={"literal_binds": True}))
self._reset_query_parameters()
return expr, self._select_related
return expr
@staticmethod
def on_clause(
self, previous_alias: str, alias: str, from_clause: str, to_clause: str,
previous_alias: str, alias: str, from_clause: str, to_clause: str,
) -> text:
left_part = f"{alias}_{to_clause}"
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
@ -92,7 +89,7 @@ class Query:
model_cls = join_params.model_cls.Meta.model_fields[part].to
to_table = model_cls.Meta.table.name
alias = model_cls.Meta._orm_relationship_manager.resolve_relation_join(
alias = model_cls.Meta.alias_manager.resolve_relation_join(
join_params.from_table, to_table
)
if alias not in self.used_aliases:

View File

@ -47,7 +47,7 @@ class QuerySet:
offset=self.query_offset,
limit_count=self.limit_count,
)
exp, self._select_related = qry.build_select_expression()
exp = qry.build_select_expression()
return exp
def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003
@ -118,15 +118,25 @@ class QuerySet:
async def get(self, **kwargs: Any) -> "Model":
if kwargs:
return await self.filter(**kwargs).get()
else:
if not self.filter_clauses:
expr = self.build_select_expression().limit(2)
else:
expr = self.build_select_expression()
expr = self.build_select_expression().limit(2)
rows = await self.database.fetch_all(expr)
result_rows = [
self.model_cls.from_row(row, select_related=self._select_related)
for row in rows
]
rows = self.model_cls.merge_instances_list(result_rows)
if not rows:
raise NoMatch()
if len(rows) > 1:
raise MultipleMatches()
return self.model_cls.from_row(rows[0], select_related=self._select_related)
return rows[0]
async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003
if kwargs:
@ -138,7 +148,6 @@ class QuerySet:
self.model_cls.from_row(row, select_related=self._select_related)
for row in rows
]
result_rows = self.model_cls.merge_instances_list(result_rows)
return result_rows

View File

@ -1,87 +0,0 @@
from typing import List, TYPE_CHECKING, Type
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
if TYPE_CHECKING: # pragma no cover
from ormar import Model
class RelationshipCrawler:
def __init__(self) -> None:
self._select_related = []
self.auto_related = []
self.already_checked = []
def discover_relations(
self, select_related: List, prev_model: Type["Model"]
) -> List[str]:
self._select_related = select_related
self._extract_auto_required_relations(prev_model=prev_model)
self._include_auto_related_models()
return self._select_related
@staticmethod
def _field_is_a_foreign_key_and_no_circular_reference(
field: Type[BaseField], field_name: str, rel_part: str
) -> bool:
return issubclass(field, ForeignKeyField) and field_name not in rel_part
def _field_qualifies_to_deeper_search(
self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str
) -> bool:
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
partial_match = any(
[x.startswith(prev_part_of_related) for x in self._select_related]
)
already_checked = any(
[x.startswith(rel_part) for x in (self.auto_related + self.already_checked)]
)
return (
(field.virtual and parent_virtual)
or (partial_match and not already_checked)
) or not nested
def _extract_auto_required_relations(
self,
prev_model: Type["Model"],
rel_part: str = "",
nested: bool = False,
parent_virtual: bool = False,
) -> None:
for field_name, field in prev_model.Meta.model_fields.items():
if self._field_is_a_foreign_key_and_no_circular_reference(
field, field_name, rel_part
):
rel_part = field_name if not rel_part else rel_part + "__" + field_name
if not field.nullable:
if rel_part not in self._select_related:
split_tables = rel_part.split("__")
new_related = (
"__".join(split_tables[:-1])
if len(split_tables) > 1
else rel_part
)
self.auto_related.append(new_related)
rel_part = ""
elif self._field_qualifies_to_deeper_search(
field, parent_virtual, nested, rel_part
):
self._extract_auto_required_relations(
prev_model=field.to,
rel_part=rel_part,
nested=True,
parent_virtual=field.virtual,
)
else:
self.already_checked.append(rel_part)
rel_part = ""
def _include_auto_related_models(self) -> None:
if self.auto_related:
new_joins = []
for join in self._select_related:
if not any([x.startswith(join) for x in self.auto_related]):
new_joins.append(join)
self._select_related = new_joins + self.auto_related

View File

@ -1,26 +1,33 @@
import pprint
import string
import uuid
from enum import Enum
from random import choices
from typing import List, TYPE_CHECKING, Union
from typing import List, Optional, TYPE_CHECKING, Type, Union
from weakref import proxy
import sqlalchemy
from sqlalchemy import text
import ormar # noqa I100
from ormar.exceptions import RelationshipInstanceError # noqa I100
from ormar.fields.foreign_key import ForeignKeyField # noqa I100
if TYPE_CHECKING: # pragma no cover
from ormar.models import NewBaseModel, Model
from ormar.models import Model
def get_table_alias() -> str:
return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
class RelationshipManager:
class RelationType(Enum):
PRIMARY = 1
REVERSE = 2
class AliasManager:
def __init__(self) -> None:
self._relations = dict()
self._aliases = dict()
@staticmethod
@ -34,86 +41,158 @@ class RelationshipManager:
def prefixed_table_name(alias: str, name: str) -> text:
return text(f"{name} {alias}_{name}")
def add_relation_type(
self,
relations_key: str,
reverse_key: str,
field: ForeignKeyField,
table_name: str,
) -> None:
if relations_key not in self._relations:
self._relations[relations_key] = {"type": "primary"}
def add_relation_type(self, field: ForeignKeyField, table_name: str,) -> None:
if f"{table_name}_{field.to.Meta.tablename}" not in self._aliases:
self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias()
if reverse_key not in self._relations:
self._relations[reverse_key] = {"type": "reverse"}
if f"{field.to.Meta.tablename}_{table_name}" not in self._aliases:
self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias()
def deregister(self, model: "NewBaseModel") -> None:
for rel_type in self._relations.keys():
if model.get_name() in rel_type.lower():
if model._orm_id in self._relations[rel_type]:
del self._relations[rel_type][model._orm_id]
def add_relation(
self,
parent: "NewBaseModel",
child: "NewBaseModel",
child_model_name: str,
virtual: bool = False,
) -> None:
parent_id, child_id = parent._orm_id, child._orm_id
parent_name = parent.get_name(title=True)
child_name = (
child_model_name
if child.get_name() != child_model_name
else child.get_name() + "s"
)
if virtual:
child_name, parent_name = parent_name, child.get_name()
child_id, parent_id = parent_id, child_id
child, parent = parent, proxy(child)
child_name = child_name.lower() + "s"
else:
child = proxy(child)
parent_relation_name = parent_name.title() + "_" + child_name
parents_list = self._relations[parent_relation_name].setdefault(parent_id, [])
self.append_related_model(parents_list, child)
child_relation_name = child.get_name(title=True) + "_" + parent_name.lower()
children_list = self._relations[child_relation_name].setdefault(child_id, [])
self.append_related_model(children_list, parent)
@staticmethod
def append_related_model(relations_list: List["Model"], model: "Model") -> None:
for relation_child in relations_list:
try:
if relation_child.__same__(model):
return
except ReferenceError:
continue
relations_list.append(model)
def contains(self, relations_key: str, instance: "NewBaseModel") -> bool:
if relations_key in self._relations:
return instance._orm_id in self._relations[relations_key]
return False
def get(
self, relations_key: str, instance: "NewBaseModel"
) -> Union["Model", List["Model"]]:
if relations_key in self._relations:
if instance._orm_id in self._relations[relations_key]:
if self._relations[relations_key]["type"] == "primary":
return self._relations[relations_key][instance._orm_id][0]
return self._relations[relations_key][instance._orm_id]
def resolve_relation_join(self, from_table: str, to_table: str) -> str:
return self._aliases.get(f"{from_table}_{to_table}", "")
def __str__(self) -> str: # pragma no cover
return pprint.pformat(self._relations, indent=4, width=1)
class RelationProxy(list):
def __init__(self, relation: "Relation") -> None:
super(RelationProxy, self).__init__()
self.relation = relation
self._owner = self.relation.manager.owner
def remove(self, item: "Model") -> None:
super().remove(item)
rel_name = item.resolve_relation_name(item, self._owner)
item._orm._get(rel_name).remove(self._owner)
def append(self, item: "Model") -> None:
super().append(item)
def add(self, item: "Model") -> None:
rel_name = item.resolve_relation_name(item, self._owner)
setattr(item, rel_name, self._owner)
class Relation:
def __init__(self, manager: "RelationsManager", type_: RelationType) -> None:
self.manager = manager
self._owner = manager.owner
self._type = type_
self.related_models = (
RelationProxy(relation=self) if type_ == RelationType.REVERSE else None
)
def _find_existing(self, child: "Model") -> Optional[int]:
for ind, relation_child in enumerate(self.related_models[:]):
try:
if relation_child.__same__(child):
return ind
except ReferenceError: # pragma no cover
self.related_models.pop(ind)
return None
def add(self, child: "Model") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY:
self.related_models = child
self._owner.__dict__[relation_name] = child
else:
if self._find_existing(child) is None:
self.related_models.append(child)
rel = self._owner.__dict__.get(relation_name, [])
rel.append(child)
self._owner.__dict__[relation_name] = rel
def remove(self, child: "Model") -> None:
relation_name = self._owner.resolve_relation_name(self._owner, child)
if self._type == RelationType.PRIMARY:
if self.related_models.__same__(child):
self.related_models = None
del self._owner.__dict__[relation_name]
else:
position = self._find_existing(child)
if position is not None:
self.related_models.pop(position)
del self._owner.__dict__[relation_name][position]
def get(self) -> Union[List["Model"], "Model"]:
return self.related_models
def __repr__(self) -> str: # pragma no cover
return self.__str__()
return str(self.related_models)
class RelationsManager:
def __init__(
self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None
) -> None:
self.owner = owner
self._related_fields = related_fields or []
self._related_names = [field.name for field in self._related_fields]
self._relations = dict()
for field in self._related_fields:
self._add_relation(field)
def _add_relation(self, field: Type[ForeignKeyField]) -> None:
self._relations[field.name] = Relation(
manager=self,
type_=RelationType.PRIMARY if not field.virtual else RelationType.REVERSE,
)
def __contains__(self, item: str) -> bool:
return item in self._related_names
def get(self, name: str) -> Optional[Union[List["Model"], "Model"]]:
relation = self._relations.get(name, None)
if relation:
return relation.get()
def _get(self, name: str) -> Optional[Relation]:
relation = self._relations.get(name, None)
if relation:
return relation
@staticmethod
def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None:
to_field = next(
(
field
for field in child._orm._related_fields
if field.to == parent.__class__ or field.to.Meta == parent.Meta
),
None,
)
if not to_field: # pragma no cover
raise RelationshipInstanceError(
f"Model {child.__class__} does not have "
f"reference to model {parent.__class__}"
)
to_name = to_field.name
if virtual:
child_name, to_name = to_name, child_name or child.get_name()
child, parent = parent, proxy(child)
else:
child_name = child_name or child.get_name() + "s"
child = proxy(child)
parent_relation = parent._orm._get(child_name)
if not parent_relation:
ormar.models.expand_reverse_relationships(child.__class__)
name = parent.resolve_relation_name(parent, child)
field = parent.Meta.model_fields[name]
parent._orm._add_relation(field)
parent_relation = parent._orm._get(child_name)
parent_relation.add(child)
child._orm._get(to_name).add(parent)
def remove(self, name: str, child: "Model") -> None:
relation = self._get(name)
relation.remove(child)
@staticmethod
def remove_parent(item: "Model", name: Union[str, "Model"]) -> None:
related_model = name
name = item.resolve_relation_name(item, related_model)
if name in item._orm:
relation_name = item.resolve_relation_name(related_model, item)
item._orm.remove(name, related_model)
related_model._orm.remove(relation_name, item)

View File

@ -22,7 +22,7 @@ class Example(ormar.Model):
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=200, default='aaa')
name: ormar.String(max_length=200, default="aaa")
created: ormar.DateTime(default=datetime.datetime.now)
created_day: ormar.Date(default=datetime.date.today)
created_time: ormar.Time(default=time)

View File

@ -1,11 +1,11 @@
import gc
import databases
import pytest
import sqlalchemy
from pydantic import ValidationError
import ormar
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
from ormar.fields.foreign_key import ForeignKeyField
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
@ -131,9 +131,11 @@ async def test_model_crud():
album1 = await Album.objects.get(name="Malibu")
assert album1.pk == 1
assert album1.tracks is None
assert album1.tracks == []
await Track.objects.create(album={"id": track.album.pk}, title="The Bird2", position=4)
await Track.objects.create(
album={"id": track.album.pk}, title="The Bird2", position=4
)
@pytest.mark.asyncio
@ -164,6 +166,47 @@ async def test_select_related():
assert len(tracks) == 6
@pytest.mark.asyncio
async def test_model_removal_from_relations():
async with database:
album = Album(name="Chichi")
await album.save()
track1 = Track(album=album, title="The Birdman", position=1)
track2 = Track(album=album, title="Superman", position=2)
track3 = Track(album=album, title="Wonder Woman", position=3)
await track1.save()
await track2.save()
await track3.save()
assert len(album.tracks) == 3
album.tracks.remove(track1)
assert len(album.tracks) == 2
assert track1.album is None
await track1.update()
track1 = await Track.objects.get(title="The Birdman")
assert track1.album is None
album.tracks.add(track1)
assert len(album.tracks) == 3
assert track1.album == album
await track1.update()
track1 = await Track.objects.select_related("album__tracks").get(
title="The Birdman"
)
album = await Album.objects.select_related("tracks").get(name="Chichi")
assert track1.album == album
track1.remove(album)
assert track1.album is None
assert len(album.tracks) == 2
track2.remove(album)
assert track2.album is None
assert len(album.tracks) == 1
@pytest.mark.asyncio
async def test_fk_filter():
async with database:
@ -182,8 +225,8 @@ async def test_fk_filter():
tracks = (
await Track.objects.select_related("album")
.filter(album__name="Fantasies")
.all()
.filter(album__name="Fantasies")
.all()
)
assert len(tracks) == 3
for track in tracks:
@ -191,8 +234,8 @@ async def test_fk_filter():
tracks = (
await Track.objects.select_related("album")
.filter(album__name__icontains="fan")
.all()
.filter(album__name__icontains="fan")
.all()
)
assert len(tracks) == 3
for track in tracks:
@ -234,8 +277,8 @@ async def test_multiple_fk():
members = (
await Member.objects.select_related("team__org")
.filter(team__org__ident="ACME Ltd")
.all()
.filter(team__org__ident="ACME Ltd")
.all()
)
assert len(members) == 4
for member in members:
@ -254,8 +297,8 @@ async def test_pk_filter():
tracks = (
await Track.objects.select_related("album")
.filter(position=2, album__name="Test")
.all()
.filter(position=2, album__name="Test")
.all()
)
assert len(tracks) == 1

View File

@ -54,7 +54,9 @@ class ExampleModel2(Model):
@pytest.fixture()
def example():
return ExampleModel(pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5))
return ExampleModel(
pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5)
)
def test_not_nullable_field_is_required():
@ -110,6 +112,7 @@ def test_sqlalchemy_table_is_created(example):
def test_no_pk_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
class Meta:
tablename = "example3"
@ -120,6 +123,7 @@ def test_no_pk_in_model_definition():
def test_two_pks_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
class Meta:
tablename = "example3"
@ -131,6 +135,7 @@ def test_two_pks_in_model_definition():
def test_setting_pk_column_as_pydantic_only_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
class Meta:
tablename = "example4"
@ -141,6 +146,7 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition():
def test_decimal_error_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
class Meta:
tablename = "example5"
@ -151,6 +157,7 @@ def test_decimal_error_in_model_definition():
def test_string_error_in_model_definition():
with pytest.raises(ModelDefinitionError):
class ExampleModel2(Model):
class Meta:
tablename = "example6"

View File

@ -28,7 +28,7 @@ class User(ormar.Model):
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100, default='')
name: ormar.String(max_length=100, default="")
class Product(ormar.Model):

View File

@ -79,7 +79,7 @@ async def create_category(category: Category):
@app.put("/items/{item_id}")
async def get_item(item_id: int, item: Item):
item_db = await Item.objects.get(pk=item_id)
return {"updated_rows": await item_db.update(**item.dict())}
return await item_db.update(**item.dict())
@app.delete("/items/{item_id}")
@ -105,7 +105,7 @@ def test_all_endpoints():
item.name = "New name"
response = client.put(f"/items/{item.pk}", json=item.dict())
assert response.json().get("updated_rows") == 1
assert response.json() == item.dict()
response = client.get("/items/")
items = [Item(**item) for item in response.json()]

View File

@ -0,0 +1,110 @@
import asyncio
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class Department(ormar.Model):
class Meta:
tablename = "departments"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True, autoincrement=False)
name: ormar.String(max_length=100)
class SchoolClass(ormar.Model):
class Meta:
tablename = "schoolclasses"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
class Category(ormar.Model):
class Meta:
tablename = "categories"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
department: ormar.ForeignKey(Department, nullable=False)
class Student(ormar.Model):
class Meta:
tablename = "students"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
schoolclass: ormar.ForeignKey(SchoolClass)
category: ormar.ForeignKey(Category, nullable=True)
class Teacher(ormar.Model):
class Meta:
tablename = "teachers"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
name: ormar.String(max_length=100)
schoolclass: ormar.ForeignKey(SchoolClass)
category: ormar.ForeignKey(Category, nullable=True)
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()
@pytest.fixture(autouse=True, scope="module")
async def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
department = await Department.objects.create(id=1, name="Math Department")
department2 = await Department.objects.create(id=2, name="Law Department")
class1 = await SchoolClass.objects.create(name="Math")
class2 = await SchoolClass.objects.create(name="Logic")
category = await Category.objects.create(name="Foreign", department=department)
category2 = await Category.objects.create(name="Domestic", department=department2)
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
await Student.objects.create(name="Judy", category=category2, schoolclass=class1)
await Student.objects.create(name="Jack", category=category2, schoolclass=class2)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema():
async with database:
classes = await SchoolClass.objects.select_related(
["teachers__category__department", "students"]
).all()
assert classes[0].name == "Math"
assert classes[0].students[0].name == "Jane"
assert len(classes[0].dict().get("students")) == 2
assert classes[0].teachers[0].category.department.name == "Law Department"
assert classes[0].students[0].category.pk is not None
assert classes[0].students[0].category.name is None
await classes[0].students[0].category.load()
await classes[0].students[0].category.department.load()
assert classes[0].students[0].category.department.name == "Math Department"

View File

@ -79,11 +79,14 @@ async def create_test_database():
metadata.drop_all(engine)
metadata.create_all(engine)
department = await Department.objects.create(id=1, name="Math Department")
department2 = await Department.objects.create(id=2, name="Law Department")
class1 = await SchoolClass.objects.create(name="Math", department=department)
class2 = await SchoolClass.objects.create(name="Logic", department=department2)
category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic")
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
await Student.objects.create(name="Jack", category=category2, schoolclass=class1)
await Student.objects.create(name="Judy", category=category2, schoolclass=class1)
await Student.objects.create(name="Jack", category=category2, schoolclass=class2)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
yield
metadata.drop_all(engine)
@ -100,15 +103,15 @@ async def test_model_multiple_instances_of_same_table_in_schema():
assert len(classes[0].dict().get("students")) == 2
# related fields of main model are only populated by pk
# unless there is a required foreign key somewhere along the way
# since department is required for schoolclass it was pre loaded (again)
# but you can load them anytime
# since it's going from schoolclass => teacher => schoolclass (same class) department is already populated
assert classes[0].students[0].schoolclass.name == "Math"
assert classes[0].students[0].schoolclass.department.name is None
await classes[0].students[0].schoolclass.department.load()
assert classes[0].students[0].schoolclass.department.name == "Math Department"
await classes[1].students[0].schoolclass.department.load()
assert classes[1].students[0].schoolclass.department.name == "Law Department"
@pytest.mark.asyncio
async def test_right_tables_join():
@ -130,5 +133,7 @@ async def test_multiple_reverse_related_objects():
["teachers__category", "students__category"]
).all()
assert classes[0].name == "Math"
assert classes[0].students[1].name == "Jack"
assert classes[0].students[1].name == "Judy"
assert classes[0].students[0].category.name == "Foreign"
assert classes[0].students[1].category.name == "Domestic"
assert classes[0].teachers[0].category.name == "Domestic"