merge from master, simplify props in meta inheritance

This commit is contained in:
collerek
2020-12-14 20:56:58 +01:00
23 changed files with 540 additions and 296 deletions

View File

@ -1,3 +1,11 @@
# 0.7.4
* Allow multiple relations to the same related model/table.
* Fix for wrong relation column used in many_to_many relation joins (fix [#73][#73])
* Fix for wrong relation population for m2m relations when also fk relation present for same model.
* Add check if user provide related_name if there are multiple relations to same table on one model.
* More eager cleaning of the dead weak proxy models.
# 0.7.3 # 0.7.3
* Fix for setting fetching related model with UUDI pk, which is a string in raw (fix [#71][#71]) * Fix for setting fetching related model with UUDI pk, which is a string in raw (fix [#71][#71])
@ -193,4 +201,5 @@ Add queryset level methods
[#60]: https://github.com/collerek/ormar/issues/60 [#60]: https://github.com/collerek/ormar/issues/60
[#68]: https://github.com/collerek/ormar/issues/68 [#68]: https://github.com/collerek/ormar/issues/68
[#70]: https://github.com/collerek/ormar/issues/70 [#70]: https://github.com/collerek/ormar/issues/70
[#71]: https://github.com/collerek/ormar/issues/71 [#71]: https://github.com/collerek/ormar/issues/71
[#73]: https://github.com/collerek/ormar/issues/73

View File

@ -25,6 +25,7 @@ class BaseField(FieldInfo):
""" """
__type__ = None __type__ = None
related_name = None
column_type: sqlalchemy.Column column_type: sqlalchemy.Column
constraints: List = [] constraints: List = []
@ -222,7 +223,11 @@ class BaseField(FieldInfo):
@classmethod @classmethod
def expand_relationship( def expand_relationship(
cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True cls,
value: Any,
child: Union["Model", "NewBaseModel"],
to_register: bool = True,
relation_name: str = None,
) -> Any: ) -> Any:
""" """
Function overwritten for relations, in basic field the value is returned as is. Function overwritten for relations, in basic field the value is returned as is.

View File

@ -162,7 +162,7 @@ class ForeignKeyField(BaseField):
@classmethod @classmethod
def _extract_model_from_sequence( def _extract_model_from_sequence(
cls, value: List, child: "Model", to_register: bool cls, value: List, child: "Model", to_register: bool, relation_name: str
) -> List["Model"]: ) -> List["Model"]:
""" """
Takes a list of Models and registers them on parent. Takes a list of Models and registers them on parent.
@ -180,13 +180,18 @@ class ForeignKeyField(BaseField):
:rtype: List["Model"] :rtype: List["Model"]
""" """
return [ return [
cls.expand_relationship(val, child, to_register) # type: ignore cls.expand_relationship( # type: ignore
value=val,
child=child,
to_register=to_register,
relation_name=relation_name,
)
for val in value for val in value
] ]
@classmethod @classmethod
def _register_existing_model( def _register_existing_model(
cls, value: "Model", child: "Model", to_register: bool cls, value: "Model", child: "Model", to_register: bool, relation_name: str
) -> "Model": ) -> "Model":
""" """
Takes already created instance and registers it for parent. Takes already created instance and registers it for parent.
@ -204,12 +209,12 @@ class ForeignKeyField(BaseField):
:rtype: Model :rtype: Model
""" """
if to_register: if to_register:
cls.register_relation(value, child) cls.register_relation(model=value, child=child, relation_name=relation_name)
return value return value
@classmethod @classmethod
def _construct_model_from_dict( def _construct_model_from_dict(
cls, value: dict, child: "Model", to_register: bool cls, value: dict, child: "Model", to_register: bool, relation_name: str
) -> "Model": ) -> "Model":
""" """
Takes a dictionary, creates a instance and registers it for parent. Takes a dictionary, creates a instance and registers it for parent.
@ -231,12 +236,12 @@ class ForeignKeyField(BaseField):
value["__pk_only__"] = True value["__pk_only__"] = True
model = cls.to(**value) model = cls.to(**value)
if to_register: if to_register:
cls.register_relation(model, child) cls.register_relation(model=model, child=child, relation_name=relation_name)
return model return model
@classmethod @classmethod
def _construct_model_from_pk( def _construct_model_from_pk(
cls, value: Any, child: "Model", to_register: bool cls, value: Any, child: "Model", to_register: bool, relation_name: str
) -> "Model": ) -> "Model":
""" """
Takes a pk value, creates a dummy instance and registers it for parent. Takes a pk value, creates a dummy instance and registers it for parent.
@ -263,11 +268,13 @@ class ForeignKeyField(BaseField):
) )
model = create_dummy_instance(fk=cls.to, pk=value) model = create_dummy_instance(fk=cls.to, pk=value)
if to_register: if to_register:
cls.register_relation(model, child) cls.register_relation(model=model, child=child, relation_name=relation_name)
return model return model
@classmethod @classmethod
def register_relation(cls, model: "Model", child: "Model") -> None: def register_relation(
cls, model: "Model", child: "Model", relation_name: str
) -> None:
""" """
Registers relation between parent and child in relation manager. Registers relation between parent and child in relation manager.
Relation manager is kep on each model (different instance). Relation manager is kep on each model (different instance).
@ -281,12 +288,20 @@ class ForeignKeyField(BaseField):
:type child: Model class :type child: Model class
""" """
model._orm.add( model._orm.add(
parent=model, child=child, child_name=cls.related_name, virtual=cls.virtual parent=model,
child=child,
child_name=cls.related_name or child.get_name() + "s",
virtual=cls.virtual,
relation_name=relation_name,
) )
@classmethod @classmethod
def expand_relationship( def expand_relationship(
cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True cls,
value: Any,
child: Union["Model", "NewBaseModel"],
to_register: bool = True,
relation_name: str = None,
) -> Optional[Union["Model", List["Model"]]]: ) -> Optional[Union["Model", List["Model"]]]:
""" """
For relations the child model is first constructed (if needed), For relations the child model is first constructed (if needed),
@ -316,5 +331,5 @@ class ForeignKeyField(BaseField):
model = constructors.get( # type: ignore model = constructors.get( # type: ignore
value.__class__.__name__, cls._construct_model_from_pk value.__class__.__name__, cls._construct_model_from_pk
)(value, child, to_register) )(value, child, to_register, relation_name)
return model return model

View File

@ -10,7 +10,6 @@ from typing import (
Tuple, Tuple,
Type, Type,
Union, Union,
cast,
) )
import databases import databases
@ -56,25 +55,27 @@ class ModelMeta:
abstract: bool abstract: bool
def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None: def register_relation_on_build_new(new_model: Type["Model"], field_name: str) -> None:
alias_manager.add_relation_type(field.to.Meta.tablename, table_name) alias_manager.add_relation_type_new(new_model, field_name)
def register_many_to_many_relation_on_build( def register_many_to_many_relation_on_build_new(
table_name: str, field: Type[ManyToManyField] new_model: Type["Model"], field: Type[ManyToManyField]
) -> None: ) -> None:
alias_manager.add_relation_type(field.through.Meta.tablename, table_name) alias_manager.add_relation_type_new(
alias_manager.add_relation_type( field.through, new_model.get_name(), is_multi=True
field.through.Meta.tablename, field.to.Meta.tablename )
alias_manager.add_relation_type_new(
field.through, field.to.get_name(), is_multi=True
) )
def reverse_field_not_already_registered( def reverse_field_not_already_registered(
child: Type["Model"], child_model_name: str, parent_model: Type["Model"] child: Type["Model"], child_model_name: str, parent_model: Type["Model"]
) -> bool: ) -> bool:
return ( return (
child_model_name not in parent_model.__fields__ child_model_name not in parent_model.__fields__
and child.get_name() not in parent_model.__fields__ and child.get_name() not in parent_model.__fields__
) )
@ -85,7 +86,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
parent_model = model_field.to parent_model = model_field.to
child = model child = model
if reverse_field_not_already_registered( if reverse_field_not_already_registered(
child, child_model_name, parent_model child, child_model_name, parent_model
): ):
register_reverse_model_fields( register_reverse_model_fields(
parent_model, child, child_model_name, model_field parent_model, child, child_model_name, model_field
@ -93,10 +94,10 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
def register_reverse_model_fields( def register_reverse_model_fields(
model: Type["Model"], model: Type["Model"],
child: Type["Model"], child: Type["Model"],
child_model_name: str, child_model_name: str,
model_field: Type["ForeignKeyField"], model_field: Type["ForeignKeyField"],
) -> None: ) -> None:
if issubclass(model_field, ManyToManyField): if issubclass(model_field, ManyToManyField):
model.Meta.model_fields[child_model_name] = ManyToMany( model.Meta.model_fields[child_model_name] = ManyToMany(
@ -111,7 +112,7 @@ def register_reverse_model_fields(
def adjust_through_many_to_many_model( def adjust_through_many_to_many_model(
model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField]
) -> None: ) -> None:
model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( model_field.through.Meta.model_fields[model.get_name()] = ForeignKey(
model, real_name=model.get_name(), ondelete="CASCADE" model, real_name=model.get_name(), ondelete="CASCADE"
@ -128,7 +129,7 @@ def adjust_through_many_to_many_model(
def create_pydantic_field( def create_pydantic_field(
field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] field_name: str, model: Type["Model"], model_field: Type[ManyToManyField]
) -> None: ) -> None:
model_field.through.__fields__[field_name] = ModelField( model_field.through.__fields__[field_name] = ModelField(
name=field_name, name=field_name,
@ -150,7 +151,7 @@ def get_pydantic_field(field_name: str, model: Type["Model"]) -> "ModelField":
def create_and_append_m2m_fk( def create_and_append_m2m_fk(
model: Type["Model"], model_field: Type[ManyToManyField] model: Type["Model"], model_field: Type[ManyToManyField]
) -> None: ) -> None:
column = sqlalchemy.Column( column = sqlalchemy.Column(
model.get_name(), model.get_name(),
@ -166,7 +167,7 @@ def create_and_append_m2m_fk(
def check_pk_column_validity( def check_pk_column_validity(
field_name: str, field: BaseField, pkname: Optional[str] field_name: str, field: BaseField, pkname: Optional[str]
) -> Optional[str]: ) -> Optional[str]:
if pkname is not None: if pkname is not None:
raise ModelDefinitionError("Only one primary key column is allowed.") raise ModelDefinitionError("Only one primary key column is allowed.")
@ -175,8 +176,27 @@ def check_pk_column_validity(
return field_name return field_name
def validate_related_names_in_relations(
model_fields: Dict, new_model: Type["Model"]
) -> None:
already_registered: Dict[str, List[Optional[str]]] = dict()
for field in model_fields.values():
if issubclass(field, ForeignKeyField):
previous_related_names = already_registered.setdefault(field.to, [])
if field.related_name in previous_related_names:
raise ModelDefinitionError(
f"Multiple fields declared on {new_model.get_name(lower=False)} "
f"model leading to {field.to.get_name(lower=False)} model without "
f"related_name property set. \nThere can be only one relation with "
f"default/empty name: '{new_model.get_name() + 's'}'"
f"\nTip: provide different related_name for FK and/or M2M fields"
)
else:
previous_related_names.append(field.related_name)
def sqlalchemy_columns_from_model_fields( def sqlalchemy_columns_from_model_fields(
model_fields: Dict, table_name: str model_fields: Dict, new_model: Type["Model"]
) -> Tuple[Optional[str], List[sqlalchemy.Column]]: ) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
columns = [] columns = []
pkname = None pkname = None
@ -186,30 +206,30 @@ def sqlalchemy_columns_from_model_fields(
"Table {table_name} had no fields so auto " "Table {table_name} had no fields so auto "
"Integer primary key named `id` created." "Integer primary key named `id` created."
) )
validate_related_names_in_relations(model_fields, new_model)
for field_name, field in model_fields.items(): for field_name, field in model_fields.items():
if field.primary_key: if field.primary_key:
pkname = check_pk_column_validity(field_name, field, pkname) pkname = check_pk_column_validity(field_name, field, pkname)
if ( if (
not field.pydantic_only not field.pydantic_only
and not field.virtual and not field.virtual
and not issubclass(field, ManyToManyField) and not issubclass(field, ManyToManyField)
): ):
columns.append(field.get_column(field.get_alias())) columns.append(field.get_column(field.get_alias()))
register_relation_in_alias_manager(table_name, field)
return pkname, columns return pkname, columns
def register_relation_in_alias_manager( def register_relation_in_alias_manager_new(
table_name: str, field: Type[ForeignKeyField] new_model: Type["Model"], field: Type[ForeignKeyField], field_name: str
) -> None: ) -> None:
if issubclass(field, ManyToManyField): if issubclass(field, ManyToManyField):
register_many_to_many_relation_on_build(table_name, field) register_many_to_many_relation_on_build_new(new_model=new_model, field=field)
elif issubclass(field, ForeignKeyField): elif issubclass(field, ForeignKeyField):
register_relation_on_build(table_name, field) register_relation_on_build_new(new_model=new_model, field_name=field_name)
def populate_default_pydantic_field_value( def populate_default_pydantic_field_value(
ormar_field: Type[BaseField], field_name: str, attrs: dict ormar_field: Type[BaseField], field_name: str, attrs: dict
) -> dict: ) -> dict:
curr_def_value = attrs.get(field_name, ormar.Undefined) curr_def_value = attrs.get(field_name, ormar.Undefined)
if lenient_issubclass(curr_def_value, ormar.fields.BaseField): if lenient_issubclass(curr_def_value, ormar.fields.BaseField):
@ -254,7 +274,7 @@ def extract_annotations_and_default_vals(attrs: dict) -> Tuple[Dict, Dict]:
def populate_meta_tablename_columns_and_pk( def populate_meta_tablename_columns_and_pk(
name: str, new_model: Type["Model"] name: str, new_model: Type["Model"]
) -> Type["Model"]: ) -> Type["Model"]:
tablename = name.lower() + "s" tablename = name.lower() + "s"
new_model.Meta.tablename = ( new_model.Meta.tablename = (
@ -267,7 +287,7 @@ def populate_meta_tablename_columns_and_pk(
pkname = new_model.Meta.pkname pkname = new_model.Meta.pkname
else: else:
pkname, columns = sqlalchemy_columns_from_model_fields( pkname, columns = sqlalchemy_columns_from_model_fields(
new_model.Meta.model_fields, new_model.Meta.tablename new_model.Meta.model_fields, new_model
) )
if pkname is None: if pkname is None:
@ -275,12 +295,11 @@ def populate_meta_tablename_columns_and_pk(
new_model.Meta.columns = columns new_model.Meta.columns = columns
new_model.Meta.pkname = pkname new_model.Meta.pkname = pkname
return new_model return new_model
def populate_meta_sqlalchemy_table_if_required( def populate_meta_sqlalchemy_table_if_required(
new_model: Type["Model"], new_model: Type["Model"],
) -> Type["Model"]: ) -> Type["Model"]:
""" """
Constructs sqlalchemy table out of columns and parameters set on Meta class. Constructs sqlalchemy table out of columns and parameters set on Meta class.
@ -371,7 +390,7 @@ def populate_choices_validators(model: Type["Model"]) -> None: # noqa CCR001
def populate_default_options_values( def populate_default_options_values(
new_model: Type["Model"], model_fields: Dict new_model: Type["Model"], model_fields: Dict
) -> None: ) -> None:
""" """
Sets all optional Meta values to it's defaults Sets all optional Meta values to it's defaults
@ -445,15 +464,18 @@ def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa:
:param attrs: :param attrs:
:type attrs: Dict[str, str] :type attrs: Dict[str, str]
""" """
props = set()
for var_name, value in attrs.items():
if isinstance(value, property):
value = value.fget
field_config = getattr(value, "__property_field__", None)
if field_config:
props.add(var_name)
if meta_field_not_set(model=new_model, field_name="property_fields"): if meta_field_not_set(model=new_model, field_name="property_fields"):
props = set()
for var_name, value in attrs.items():
if isinstance(value, property):
value = value.fget
field_config = getattr(value, "__property_field__", None)
if field_config:
props.add(var_name)
new_model.Meta.property_fields = props new_model.Meta.property_fields = props
else:
new_model.Meta.property_fields = new_model.Meta.property_fields.union(props)
def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001 def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001
@ -490,11 +512,11 @@ def get_potential_fields(attrs: Dict) -> Dict:
def check_conflicting_fields( def check_conflicting_fields(
new_fields: Set, new_fields: Set,
attrs: Dict, attrs: Dict,
base_class: type, base_class: type,
curr_class: type, curr_class: type,
previous_fields: Set = None, previous_fields: Set = None,
) -> None: ) -> None:
""" """
You cannot redefine fields with same names in inherited classes. You cannot redefine fields with same names in inherited classes.
@ -524,11 +546,11 @@ def check_conflicting_fields(
def update_attrs_and_fields( def update_attrs_and_fields(
attrs: Dict, attrs: Dict,
new_attrs: Dict, new_attrs: Dict,
model_fields: Dict, model_fields: Dict,
new_model_fields: Dict, new_model_fields: Dict,
new_fields: Set, new_fields: Set,
) -> None: ) -> None:
""" """
Updates __annotations__, values of model fields (so pydantic FieldInfos) Updates __annotations__, values of model fields (so pydantic FieldInfos)
@ -551,7 +573,7 @@ def update_attrs_and_fields(
model_fields.update(new_model_fields) model_fields.update(new_model_fields)
def update_attrs_from_base_meta(base_class: "Model", attrs: Dict,) -> None: def update_attrs_from_base_meta(base_class: "Model", attrs: Dict, ) -> None:
""" """
Updates Meta parameters in child from parent if needed. Updates Meta parameters in child from parent if needed.
@ -560,33 +582,24 @@ def update_attrs_from_base_meta(base_class: "Model", attrs: Dict,) -> None:
:param attrs: new namespace for class being constructed :param attrs: new namespace for class being constructed
:type attrs: Dict :type attrs: Dict
""" """
params_to_update = ["metadata", "database", "constraints", "property_fields"] params_to_update = ["metadata", "database", "constraints"]
for param in params_to_update: for param in params_to_update:
if hasattr(base_class.Meta, param): current_value = attrs.get('Meta', {}).__dict__.get(param, ormar.Undefined)
if hasattr(attrs["Meta"], param): parent_value = base_class.Meta.__dict__.get(param) if hasattr(base_class, 'Meta') else None
curr_value = getattr(attrs["Meta"], param) if parent_value:
if isinstance(curr_value, list): if isinstance(current_value, list):
curr_value.extend(getattr(base_class.Meta, param)) current_value.extend(parent_value)
elif isinstance(curr_value, dict): # pragma: no cover
curr_value.update(getattr(base_class.Meta, param))
elif isinstance(curr_value, Set):
curr_value.union(getattr(base_class.Meta, param))
else:
# overwrite with child value if both set and its param / object
setattr(
attrs["Meta"], param, getattr(base_class.Meta, param)
) # pragma: no cover
else: else:
setattr(attrs["Meta"], param, getattr(base_class.Meta, param)) setattr(attrs["Meta"], param, parent_value)
def extract_mixin_fields_from_dict( def extract_from_parents_definition(
base_class: type, base_class: type,
curr_class: type, curr_class: type,
attrs: Dict, attrs: Dict,
model_fields: Dict[ model_fields: Dict[
str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]]
], ],
) -> Tuple[Dict, Dict]: ) -> Tuple[Dict, Dict]:
""" """
Extracts fields from base classes if they have valid oramr fields. Extracts fields from base classes if they have valid oramr fields.
@ -629,8 +642,11 @@ def extract_mixin_fields_from_dict(
f"{curr_class.__name__} cannot inherit " f"{curr_class.__name__} cannot inherit "
f"from non abstract class {base_class.__name__}" f"from non abstract class {base_class.__name__}"
) )
update_attrs_from_base_meta(base_class=base_class, attrs=attrs) # type: ignore update_attrs_from_base_meta(
model_fields.update(base_class.Meta.model_fields) base_class=base_class, # type: ignore
attrs=attrs,
)
model_fields.update(base_class.Meta.model_fields) # type: ignore
return attrs, model_fields return attrs, model_fields
key = "__annotations__" key = "__annotations__"
@ -687,14 +703,14 @@ def extract_mixin_fields_from_dict(
class ModelMetaclass(pydantic.main.ModelMetaclass): class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__( # type: ignore def __new__( # type: ignore # noqa: CCR001
mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict
) -> "ModelMetaclass": ) -> "ModelMetaclass":
attrs["Config"] = get_pydantic_base_orm_config() attrs["Config"] = get_pydantic_base_orm_config()
attrs["__name__"] = name attrs["__name__"] = name
attrs, model_fields = extract_annotations_and_default_vals(attrs) attrs, model_fields = extract_annotations_and_default_vals(attrs)
for ind, base in enumerate(reversed(bases)): for base in reversed(bases):
attrs, model_fields = extract_mixin_fields_from_dict( attrs, model_fields = extract_from_parents_definition(
base_class=base, curr_class=mcs, attrs=attrs, model_fields=model_fields base_class=base, curr_class=mcs, attrs=attrs, model_fields=model_fields
) )
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
@ -713,6 +729,8 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
new_model = populate_meta_tablename_columns_and_pk(name, new_model) new_model = populate_meta_tablename_columns_and_pk(name, new_model)
new_model = populate_meta_sqlalchemy_table_if_required(new_model) new_model = populate_meta_sqlalchemy_table_if_required(new_model)
expand_reverse_relationships(new_model) expand_reverse_relationships(new_model)
for field_name, field in new_model.Meta.model_fields.items():
register_relation_in_alias_manager_new(new_model, field, field_name)
if new_model.Meta.pkname not in attrs["__annotations__"]: if new_model.Meta.pkname not in attrs["__annotations__"]:
field_name = new_model.Meta.pkname field_name = new_model.Meta.pkname

View File

@ -58,7 +58,8 @@ class Model(NewBaseModel):
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
select_related: List = None, select_related: List = None,
related_models: Any = None, related_models: Any = None,
previous_table: str = None, previous_model: Type[T] = None,
related_name: str = None,
fields: Optional[Union[Dict, Set]] = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: Optional[Union[Dict, Set]] = None, exclude_fields: Optional[Union[Dict, Set]] = None,
) -> Optional[T]: ) -> Optional[T]:
@ -69,28 +70,32 @@ class Model(NewBaseModel):
if select_related: if select_related:
related_models = group_related_list(select_related) related_models = group_related_list(select_related)
if ( rel_name2 = related_name
previous_table
and previous_table in cls.Meta.model_fields
and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField)
):
previous_table = cls.Meta.model_fields[
previous_table
].through.Meta.tablename
if previous_table: if (
table_prefix = cls.Meta.alias_manager.resolve_relation_join( previous_model
previous_table, cls.Meta.table.name and related_name
and issubclass(
previous_model.Meta.model_fields[related_name], ManyToManyField
)
):
through_field = previous_model.Meta.model_fields[related_name]
rel_name2 = previous_model.resolve_relation_name(
through_field.through, through_field.to, explicit_multi=True
)
previous_model = through_field.through # type: ignore
if previous_model and rel_name2:
table_prefix = cls.Meta.alias_manager.resolve_relation_join_new(
previous_model, rel_name2
) )
else: else:
table_prefix = "" table_prefix = ""
previous_table = cls.Meta.table.name
item = cls.populate_nested_models_from_row( item = cls.populate_nested_models_from_row(
item=item, item=item,
row=row, row=row,
related_models=related_models, related_models=related_models,
previous_table=previous_table,
fields=fields, fields=fields,
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
) )
@ -111,7 +116,6 @@ class Model(NewBaseModel):
instance.set_save_status(True) instance.set_save_status(True)
else: else:
instance = None instance = None
return instance return instance
@classmethod @classmethod
@ -120,7 +124,6 @@ class Model(NewBaseModel):
item: dict, item: dict,
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
related_models: Any, related_models: Any,
previous_table: sqlalchemy.Table,
fields: Optional[Union[Dict, Set]] = None, fields: Optional[Union[Dict, Set]] = None,
exclude_fields: Optional[Union[Dict, Set]] = None, exclude_fields: Optional[Union[Dict, Set]] = None,
) -> dict: ) -> dict:
@ -135,7 +138,8 @@ class Model(NewBaseModel):
child = model_cls.from_row( child = model_cls.from_row(
row, row,
related_models=remainder, related_models=remainder,
previous_table=previous_table, previous_model=cls,
related_name=related,
fields=fields, fields=fields,
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
) )
@ -146,7 +150,8 @@ class Model(NewBaseModel):
exclude_fields = cls.get_excluded(exclude_fields, related) exclude_fields = cls.get_excluded(exclude_fields, related)
child = model_cls.from_row( child = model_cls.from_row(
row, row,
previous_table=previous_table, previous_model=cls,
related_name=related,
fields=fields, fields=fields,
exclude_fields=exclude_fields, exclude_fields=exclude_fields,
) )

View File

@ -21,7 +21,7 @@ from ormar.exceptions import ModelPersistenceError, RelationshipInstanceError
from ormar.queryset.utils import translate_list_to_dict, update from ormar.queryset.utils import translate_list_to_dict, update
import ormar # noqa: I100 import ormar # noqa: I100
from ormar.fields import BaseField from ormar.fields import BaseField, ManyToManyField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta from ormar.models.metaclass import ModelMeta
@ -291,12 +291,21 @@ class ModelTableProxy:
"ModelTableProxy", "ModelTableProxy",
Type["ModelTableProxy"], Type["ModelTableProxy"],
], ],
explicit_multi: bool = False,
) -> str: ) -> str:
for name, field in item.Meta.model_fields.items(): for name, field in item.Meta.model_fields.items():
if issubclass(field, ForeignKeyField): # fastapi is creating clones of response model
# fastapi is creating clones of response model # that's why it can be a subclass of the original model
# that's why it can be a subclass of the original model # so we need to compare Meta too as this one is copied as is
# so we need to compare Meta too as this one is copied as is if issubclass(field, ManyToManyField):
attrib = "to" if not explicit_multi else "through"
if (
getattr(field, attrib) == related.__class__
or getattr(field, attrib).Meta == related.Meta
):
return name
elif issubclass(field, ForeignKeyField):
if field.to == related.__class__ or field.to.Meta == related.Meta: if field.to == related.__class__ or field.to.Meta == related.Meta:
return name return name

View File

@ -96,7 +96,7 @@ class NewBaseModel(
k: self._convert_json( k: self._convert_json(
k, k,
self.Meta.model_fields[k].expand_relationship( self.Meta.model_fields[k].expand_relationship(
v, self, to_register=False v, self, to_register=False, relation_name=k
), ),
"dumps", "dumps",
) )
@ -125,7 +125,7 @@ class NewBaseModel(
# register the columns models after initialization # register the columns models after initialization
for related in self.extract_related_names(): for related in self.extract_related_names():
self.Meta.model_fields[related].expand_relationship( self.Meta.model_fields[related].expand_relationship(
new_kwargs.get(related), self, to_register=True new_kwargs.get(related), self, to_register=True, relation_name=related
) )
def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001
@ -135,7 +135,9 @@ class NewBaseModel(
object.__setattr__(self, self.Meta.pkname, value) object.__setattr__(self, self.Meta.pkname, value)
self.set_save_status(False) self.set_save_status(False)
elif name in self._orm: elif name in self._orm:
model = self.Meta.model_fields[name].expand_relationship(value, self) model = self.Meta.model_fields[name].expand_relationship(
value=value, child=self, relation_name=name
)
if isinstance(self.__dict__.get(name), list): if isinstance(self.__dict__.get(name), list):
# virtual foreign key or many to many # virtual foreign key or many to many
self.__dict__[name].append(model) self.__dict__[name].append(model)

View File

@ -131,17 +131,19 @@ class QueryClause:
# Walk the relationships to the actual model class # Walk the relationships to the actual model class
# against which the comparison is being made. # against which the comparison is being made.
previous_table = model_cls.Meta.tablename previous_model = model_cls
for part in related_parts: for part in related_parts:
part2 = part
if issubclass(model_cls.Meta.model_fields[part], ManyToManyField): if issubclass(model_cls.Meta.model_fields[part], ManyToManyField):
previous_table = model_cls.Meta.model_fields[ through_field = model_cls.Meta.model_fields[part]
part previous_model = through_field.through
].through.Meta.tablename part2 = model_cls.resolve_relation_name(
current_table = model_cls.Meta.model_fields[part].to.Meta.tablename through_field.through, through_field.to, explicit_multi=True
)
manager = model_cls.Meta.alias_manager manager = model_cls.Meta.alias_manager
table_prefix = manager.resolve_relation_join(previous_table, current_table) table_prefix = manager.resolve_relation_join_new(previous_model, part2)
model_cls = model_cls.Meta.model_fields[part].to model_cls = model_cls.Meta.model_fields[part].to
previous_table = current_table previous_model = model_cls
return select_related, table_prefix, model_cls return select_related, table_prefix, model_cls
def _compile_clause( def _compile_clause(

View File

@ -135,8 +135,8 @@ class SqlJoin:
model_cls = join_params.model_cls.Meta.model_fields[part].to model_cls = join_params.model_cls.Meta.model_fields[part].to
to_table = model_cls.Meta.table.name to_table = model_cls.Meta.table.name
alias = model_cls.Meta.alias_manager.resolve_relation_join( alias = model_cls.Meta.alias_manager.resolve_relation_join_new(
join_params.from_table, to_table join_params.prev_model, part
) )
if alias not in self.used_aliases: if alias not in self.used_aliases:
self._process_join( self._process_join(
@ -267,7 +267,9 @@ class SqlJoin:
model_cls, join_params.prev_model model_cls, join_params.prev_model
) )
to_key = model_cls.get_column_alias(to_field) to_key = model_cls.get_column_alias(to_field)
from_key = join_params.prev_model.get_column_alias(model_cls.Meta.pkname) from_key = join_params.prev_model.get_column_alias(
join_params.prev_model.Meta.pkname
)
else: else:
to_key = model_cls.get_column_alias(model_cls.Meta.pkname) to_key = model_cls.get_column_alias(model_cls.Meta.pkname)
from_key = join_params.prev_model.get_column_alias(part) from_key = join_params.prev_model.get_column_alias(part)

View File

@ -318,9 +318,8 @@ class PrefetchQuery:
if issubclass(target_field, ManyToManyField): if issubclass(target_field, ManyToManyField):
query_target = target_field.through query_target = target_field.through
select_related = [target_name] select_related = [target_name]
table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join_new(
from_table=query_target.Meta.tablename, query_target, target_name
to_table=target_field.to.Meta.tablename,
) )
self.already_extracted.setdefault(target_name, {})["prefix"] = table_prefix self.already_extracted.setdefault(target_name, {})["prefix"] = table_prefix

View File

@ -1,11 +1,14 @@
import string import string
import uuid import uuid
from random import choices from random import choices
from typing import Dict, List from typing import Dict, List, TYPE_CHECKING, Type
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
if TYPE_CHECKING: # pragma: no cover
from ormar import Model
def get_table_alias() -> str: def get_table_alias() -> str:
alias = "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] alias = "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4]
@ -15,6 +18,7 @@ def get_table_alias() -> str:
class AliasManager: class AliasManager:
def __init__(self) -> None: def __init__(self) -> None:
self._aliases: Dict[str, str] = dict() self._aliases: Dict[str, str] = dict()
self._aliases_new: Dict[str, str] = dict()
@staticmethod @staticmethod
def prefixed_columns( def prefixed_columns(
@ -35,11 +39,25 @@ class AliasManager:
def prefixed_table_name(alias: str, name: str) -> text: def prefixed_table_name(alias: str, name: str) -> text:
return text(f"{name} {alias}_{name}") return text(f"{name} {alias}_{name}")
def add_relation_type(self, to_table_name: str, table_name: str,) -> None: def add_relation_type_new(
if f"{table_name}_{to_table_name}" not in self._aliases: self, source_model: Type["Model"], relation_name: str, is_multi: bool = False
self._aliases[f"{table_name}_{to_table_name}"] = get_table_alias() ) -> None:
if f"{to_table_name}_{table_name}" not in self._aliases: parent_key = f"{source_model.get_name()}_{relation_name}"
self._aliases[f"{to_table_name}_{table_name}"] = get_table_alias() if parent_key not in self._aliases_new:
self._aliases_new[parent_key] = get_table_alias()
to_field = source_model.Meta.model_fields[relation_name]
child_model = to_field.to
related_name = to_field.related_name
if not related_name:
related_name = child_model.resolve_relation_name(
child_model, source_model, explicit_multi=is_multi
)
child_key = f"{child_model.get_name()}_{related_name}"
if child_key not in self._aliases_new:
self._aliases_new[child_key] = get_table_alias()
def resolve_relation_join(self, from_table: str, to_table: str) -> str: def resolve_relation_join_new(
return self._aliases.get(f"{from_table}_{to_table}", "") self, from_model: Type["Model"], relation_name: str
) -> str:
alias = self._aliases_new.get(f"{from_model.get_name()}_{relation_name}", "")
return alias

View File

@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import List, Optional, TYPE_CHECKING, Type, TypeVar, Union from typing import List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union
import ormar # noqa I100 import ormar # noqa I100
from ormar.exceptions import RelationshipInstanceError # noqa I100 from ormar.exceptions import RelationshipInstanceError # noqa I100
@ -31,6 +31,7 @@ class Relation:
self.manager = manager self.manager = manager
self._owner: "Model" = manager.owner self._owner: "Model" = manager.owner
self._type: RelationType = type_ self._type: RelationType = type_
self._to_remove: Set = set()
self.to: Type["T"] = to self.to: Type["T"] = to
self.through: Optional[Type["T"]] = through self.through: Optional[Type["T"]] = through
self.related_models: Optional[Union[RelationProxy, "T"]] = ( self.related_models: Optional[Union[RelationProxy, "T"]] = (
@ -39,17 +40,32 @@ class Relation:
else None else None
) )
def _clean_related(self) -> None:
cleaned_data = [
x
for i, x in enumerate(self.related_models) # type: ignore
if i not in self._to_remove
]
self.related_models = RelationProxy(
relation=self, type_=self._type, data_=cleaned_data
)
relation_name = self._owner.resolve_relation_name(self._owner, self.to)
self._owner.__dict__[relation_name] = cleaned_data
self._to_remove = set()
def _find_existing( def _find_existing(
self, child: Union["NewBaseModel", Type["NewBaseModel"]] self, child: Union["NewBaseModel", Type["NewBaseModel"]]
) -> Optional[int]: ) -> Optional[int]:
if not isinstance(self.related_models, RelationProxy): # pragma nocover if not isinstance(self.related_models, RelationProxy): # pragma nocover
raise ValueError("Cannot find existing models in parent relation type") raise ValueError("Cannot find existing models in parent relation type")
if self._to_remove:
self._clean_related()
for ind, relation_child in enumerate(self.related_models[:]): for ind, relation_child in enumerate(self.related_models[:]):
try: try:
if relation_child == child: if relation_child == child:
return ind return ind
except ReferenceError: # pragma no cover except ReferenceError: # pragma no cover
self.related_models.pop(ind) self._to_remove.add(ind)
return None return None
def add(self, child: "T") -> None: def add(self, child: "T") -> None:
@ -83,4 +99,6 @@ class Relation:
return self.related_models return self.related_models
def __repr__(self) -> str: # pragma no cover def __repr__(self) -> str: # pragma no cover
if self._to_remove:
self._clean_related()
return str(self.related_models) return str(self.related_models)

View File

@ -56,8 +56,14 @@ class RelationsManager:
return None return None
@staticmethod @staticmethod
def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None: def add(
to_field: Type[BaseField] = child.resolve_relation_field(child, parent) parent: "Model",
child: "Model",
child_name: str,
virtual: bool,
relation_name: str,
) -> None:
to_field: Type[BaseField] = child.Meta.model_fields[relation_name]
(parent, child, child_name, to_name,) = get_relations_sides_and_names( (parent, child, child_name, to_name,) = get_relations_sides_and_names(
to_field, parent, child, child_name, virtual to_field, parent, child, child_name, virtual

View File

@ -11,8 +11,10 @@ if TYPE_CHECKING: # pragma no cover
class RelationProxy(list): class RelationProxy(list):
def __init__(self, relation: "Relation", type_: "RelationType") -> None: def __init__(
super().__init__() self, relation: "Relation", type_: "RelationType", data_: Any = None
) -> None:
super().__init__(data_ or ())
self.relation: "Relation" = relation self.relation: "Relation" = relation
self.type_: "RelationType" = type_ self.type_: "RelationType" = type_
self._owner: "Model" = self.relation.manager.owner self._owner: "Model" = self.relation.manager.owner

View File

@ -18,8 +18,11 @@ def get_relations_sides_and_names(
to_name = to_field.name to_name = to_field.name
if issubclass(to_field, ManyToManyField): if issubclass(to_field, ManyToManyField):
child_name, to_name = ( child_name, to_name = (
child.resolve_relation_name(parent, child), to_field.related_name
child.resolve_relation_name(child, parent), or child.resolve_relation_name(
parent, to_field.through, explicit_multi=True
),
to_name,
) )
child = proxy(child) child = proxy(child)
elif virtual: elif virtual:

View File

@ -8,7 +8,7 @@ import sqlalchemy as sa
from sqlalchemy import create_engine from sqlalchemy import create_engine
import ormar import ormar
from ormar import ModelDefinitionError from ormar import ModelDefinitionError, property_field
from ormar.exceptions import ModelError from ormar.exceptions import ModelError
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -24,6 +24,10 @@ class AuditModel(ormar.Model):
created_by: str = ormar.String(max_length=100) created_by: str = ormar.String(max_length=100)
updated_by: str = ormar.String(max_length=100, default="Sam") updated_by: str = ormar.String(max_length=100, default="Sam")
@property_field
def audit(self): # pragma: no cover
return f"{self.created_by} {self.updated_by}"
class DateFieldsModelNoSubclass(ormar.Model): class DateFieldsModelNoSubclass(ormar.Model):
class Meta: class Meta:
@ -41,6 +45,7 @@ class DateFieldsModel(ormar.Model):
abstract = True abstract = True
metadata = metadata metadata = metadata
database = db database = db
constraints = [ormar.UniqueColumns("created_date", "updated_date")]
created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now)
updated_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) updated_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now)
@ -49,11 +54,20 @@ class DateFieldsModel(ormar.Model):
class Category(DateFieldsModel, AuditModel): class Category(DateFieldsModel, AuditModel):
class Meta(ormar.ModelMeta): class Meta(ormar.ModelMeta):
tablename = "categories" tablename = "categories"
constraints = [ormar.UniqueColumns("name", "code")]
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=50, unique=True, index=True) name: str = ormar.String(max_length=50, unique=True, index=True)
code: int = ormar.Integer() code: int = ormar.Integer()
@property_field
def code_name(self):
return f"{self.code}:{self.name}"
@property_field
def audit(self):
return f"{self.created_by} {self.updated_by}"
class Subject(DateFieldsModel): class Subject(DateFieldsModel):
class Meta(ormar.ModelMeta): class Meta(ormar.ModelMeta):
@ -99,6 +113,13 @@ def test_model_subclassing_non_abstract_raises_error():
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
def test_params_are_inherited():
assert Category.Meta.metadata == metadata
assert Category.Meta.database == db
assert len(Category.Meta.constraints) == 2
assert len(Category.Meta.property_fields) == 2
def round_date_to_seconds( def round_date_to_seconds(
date: datetime.datetime, date: datetime.datetime,
) -> datetime.datetime: # pragma: no cover ) -> datetime.datetime: # pragma: no cover
@ -132,7 +153,9 @@ async def test_fields_inherited_from_mixin():
inspector = sa.inspect(engine) inspector = sa.inspect(engine)
assert "categories" in inspector.get_table_names() assert "categories" in inspector.get_table_names()
table_columns = [x.get("name") for x in inspector.get_columns("categories")] table_columns = [x.get("name") for x in inspector.get_columns("categories")]
assert all(col in table_columns for col in mixin_columns) # + mixin2_columns) assert all(
col in table_columns for col in mixin_columns
) # + mixin2_columns)
assert "subjects" in inspector.get_table_names() assert "subjects" in inspector.get_table_names()
table_columns = [x.get("name") for x in inspector.get_columns("subjects")] table_columns = [x.get("name") for x in inspector.get_columns("subjects")]

View File

@ -7,7 +7,7 @@ from fastapi import FastAPI
from starlette.testclient import TestClient from starlette.testclient import TestClient
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
from tests.test_inheritance_concrete import Category, Subject, metadata from tests.test_inheritance_concrete import Category, Subject, metadata # type: ignore
app = FastAPI() app = FastAPI()
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
@ -53,25 +53,25 @@ def test_read_main():
test_category = dict(name="Foo", code=123, created_by="Sam", updated_by="Max") test_category = dict(name="Foo", code=123, created_by="Sam", updated_by="Max")
test_subject = dict(name="Bar") test_subject = dict(name="Bar")
response = client.post( response = client.post("/categories/", json=test_category)
"/categories/", json=test_category
)
assert response.status_code == 200 assert response.status_code == 200
cat = Category(**response.json()) cat = Category(**response.json())
assert cat.name == 'Foo' assert cat.name == "Foo"
assert cat.created_by == 'Sam' assert cat.created_by == "Sam"
assert cat.created_date is not None assert cat.created_date is not None
assert cat.id == 1 assert cat.id == 1
cat_dict = cat.dict() cat_dict = cat.dict()
cat_dict['updated_date'] = cat_dict['updated_date'].strftime("%Y-%m-%d %H:%M:%S.%f") cat_dict["updated_date"] = cat_dict["updated_date"].strftime(
cat_dict['created_date'] = cat_dict['created_date'].strftime("%Y-%m-%d %H:%M:%S.%f") "%Y-%m-%d %H:%M:%S.%f"
test_subject['category'] = cat_dict
response = client.post(
"/subjects/", json=test_subject
) )
cat_dict["created_date"] = cat_dict["created_date"].strftime(
"%Y-%m-%d %H:%M:%S.%f"
)
test_subject["category"] = cat_dict
response = client.post("/subjects/", json=test_subject)
assert response.status_code == 200 assert response.status_code == 200
sub = Subject(**response.json()) sub = Subject(**response.json())
assert sub.name == 'Bar' assert sub.name == "Bar"
assert sub.category.pk == cat.pk assert sub.category.pk == cat.pk
assert isinstance(sub.updated_date, datetime.datetime) assert isinstance(sub.updated_date, datetime.datetime)

View File

@ -7,7 +7,7 @@ from fastapi import FastAPI
from starlette.testclient import TestClient from starlette.testclient import TestClient
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
from tests.test_inheritance_mixins import Category, Subject, metadata from tests.test_inheritance_mixins import Category, Subject, metadata # type: ignore
app = FastAPI() app = FastAPI()
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
@ -53,25 +53,25 @@ def test_read_main():
test_category = dict(name="Foo", code=123, created_by="Sam", updated_by="Max") test_category = dict(name="Foo", code=123, created_by="Sam", updated_by="Max")
test_subject = dict(name="Bar") test_subject = dict(name="Bar")
response = client.post( response = client.post("/categories/", json=test_category)
"/categories/", json=test_category
)
assert response.status_code == 200 assert response.status_code == 200
cat = Category(**response.json()) cat = Category(**response.json())
assert cat.name == 'Foo' assert cat.name == "Foo"
assert cat.created_by == 'Sam' assert cat.created_by == "Sam"
assert cat.created_date is not None assert cat.created_date is not None
assert cat.id == 1 assert cat.id == 1
cat_dict = cat.dict() cat_dict = cat.dict()
cat_dict['updated_date'] = cat_dict['updated_date'].strftime("%Y-%m-%d %H:%M:%S.%f") cat_dict["updated_date"] = cat_dict["updated_date"].strftime(
cat_dict['created_date'] = cat_dict['created_date'].strftime("%Y-%m-%d %H:%M:%S.%f") "%Y-%m-%d %H:%M:%S.%f"
test_subject['category'] = cat_dict
response = client.post(
"/subjects/", json=test_subject
) )
cat_dict["created_date"] = cat_dict["created_date"].strftime(
"%Y-%m-%d %H:%M:%S.%f"
)
test_subject["category"] = cat_dict
response = client.post("/subjects/", json=test_subject)
assert response.status_code == 200 assert response.status_code == 200
sub = Subject(**response.json()) sub = Subject(**response.json())
assert sub.name == 'Bar' assert sub.name == "Bar"
assert sub.category.pk == cat.pk assert sub.category.pk == cat.pk
assert isinstance(sub.updated_date, datetime.datetime) assert isinstance(sub.updated_date, datetime.datetime)

View File

@ -1,115 +0,0 @@
import asyncio
from datetime import date
from typing import List, Optional, Union
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
class MainMeta(ormar.ModelMeta):
metadata = metadata
database = database
class Role(ormar.Model):
class Meta(MainMeta):
pass
name: str = ormar.Text(primary_key=True)
order: int = ormar.Integer(default=0)
description: str = ormar.Text()
class Company(ormar.Model):
class Meta(MainMeta):
pass
name: str = ormar.Text(primary_key=True)
class UserRoleCompany(ormar.Model):
class Meta(MainMeta):
pass
class User(ormar.Model):
class Meta(MainMeta):
pass
registrationnumber: str = ormar.Text(primary_key=True)
company: Company = ormar.ForeignKey(Company)
name: str = ormar.Text()
role: Optional[Role] = ormar.ForeignKey(Role)
roleforcompanies: Optional[Union[Company, List[Company]]] = ormar.ManyToMany(Company, through=UserRoleCompany)
lastupdate: date = ormar.DateTime(server_default=sqlalchemy.func.now())
@pytest.mark.asyncio
async def test_create_primary_models():
async with database:
print("adding role")
role_0 = await Role.objects.create(name="user", order=0, description="no administration right")
role_1 = await Role.objects.create(name="admin", order=1, description="standard administration right")
role_2 = await Role.objects.create(name="super_admin", order=2, description="super administration right")
assert await Role.objects.count() == 3
print("adding company")
company_0 = await Company.objects.create(name="Company")
company_1 = await Company.objects.create(name="Subsidiary Company 1")
company_2 = await Company.objects.create(name="Subsidiary Company 2")
company_3 = await Company.objects.create(name="Subsidiary Company 3")
assert await Company.objects.count() == 4
print("adding user")
user = await User.objects.create(registrationnumber="00-00000", company=company_0, name="admin", role=role_1)
assert await User.objects.count() == 1
print("removing user")
await user.delete()
assert await User.objects.count() == 0
print("adding user with company-role")
companies: List[Company] = [company_1, company_2]
# user = await User.objects.create(registrationnumber="00-00000", company=company_0, name="admin", role=role_1, roleforcompanies=companies)
user = await User.objects.create(registrationnumber="00-00000", company=company_0, name="admin", role=role_1)
# print(User.__fields__)
await user.roleforcompanies.add(company_1)
await user.roleforcompanies.add(company_2)
users = await User.objects.select_related("roleforcompanies").all()
# print(jsonpickle.encode(jsonable_encoder(users), unpicklable=False, keys=True))
"""
This is the request generated:
'SELECT
users.registrationnumber as registrationnumber,
users.company as company,
users.name as name, users.role as role,
users.lastupdate as lastupdate,
cy24b4_userrolecompanys.id as cy24b4_id,
cy24b4_userrolecompanys.company as cy24b4_company,
cy24b4_userrolecompanys.user as cy24b4_user,
jn50a4_companys.name as jn50a4_name \n
FROM users
LEFT OUTER JOIN userrolecompanys cy24b4_userrolecompanys ON cy24b4_userrolecompanys.user=users.id
LEFT OUTER JOIN companys jn50a4_companys ON jn50a4_companys.name=cy24b4_userrolecompanys.company
ORDER BY users.registrationnumber, jn50a4_companys.name'
There is an error in the First LEFT OUTER JOIN generated:
... companys.user=users.id
should be:
... companys.user=users.registrationnumber
There is also a \n in the midle of the string...
The execution produce the error: column users.id does not exist
"""

View File

@ -102,7 +102,7 @@ async def test_model_multiple_instances_of_same_table_in_schema():
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
await create_data() await create_data()
classes = await SchoolClass.objects.select_related( classes = await SchoolClass.objects.select_related(
["teachers__category", "students"] ["teachers__category", "students__schoolclass"]
).all() ).all()
assert classes[0].name == "Math" assert classes[0].name == "Math"
assert classes[0].students[0].name == "Jane" assert classes[0].students[0].name == "Jane"

View File

@ -56,8 +56,7 @@ class SecondaryModel(ormar.Model):
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100) name: str = ormar.String(max_length=100)
primary_model: PrimaryModel = ormar.ForeignKey( primary_model: PrimaryModel = ormar.ForeignKey(
PrimaryModel, PrimaryModel, related_name="secondary_models",
related_name="secondary_models",
) )
@ -74,7 +73,8 @@ async def test_create_primary_models():
("Primary 7", "Some text 7", "Some other text 7"), ("Primary 7", "Some text 7", "Some other text 7"),
("Primary 8", "Some text 8", "Some other text 8"), ("Primary 8", "Some text 8", "Some other text 8"),
("Primary 9", "Some text 9", "Some other text 9"), ("Primary 9", "Some text 9", "Some other text 9"),
("Primary 10", "Some text 10", "Some other text 10")]: ("Primary 10", "Some text 10", "Some other text 10"),
]:
await PrimaryModel( await PrimaryModel(
name=name, some_text=some_text, some_other_text=some_other_text name=name, some_text=some_text, some_other_text=some_other_text
).save() ).save()

View File

@ -0,0 +1,130 @@
# type: ignore
from datetime import date
from typing import List, Optional, Union
import databases
import pytest
import sqlalchemy
from sqlalchemy import create_engine
import ormar
from ormar import ModelDefinitionError
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
class MainMeta(ormar.ModelMeta):
metadata = metadata
database = database
class Role(ormar.Model):
class Meta(MainMeta):
pass
name: str = ormar.String(primary_key=True, max_length=1000)
order: int = ormar.Integer(default=0, name="sort_order")
description: str = ormar.Text()
class Company(ormar.Model):
class Meta(MainMeta):
pass
name: str = ormar.String(primary_key=True, max_length=1000)
class UserRoleCompany(ormar.Model):
class Meta(MainMeta):
pass
class User(ormar.Model):
class Meta(MainMeta):
pass
registrationnumber: str = ormar.String(primary_key=True, max_length=1000)
company: Company = ormar.ForeignKey(Company)
company2: Company = ormar.ForeignKey(Company, related_name="secondary_users")
name: str = ormar.Text()
role: Optional[Role] = ormar.ForeignKey(Role)
roleforcompanies: Optional[Union[Company, List[Company]]] = ormar.ManyToMany(
Company, through=UserRoleCompany, related_name="role_users"
)
lastupdate: date = ormar.DateTime(server_default=sqlalchemy.func.now())
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
def test_wrong_model():
with pytest.raises(ModelDefinitionError):
class User(ormar.Model):
class Meta(MainMeta):
pass
registrationnumber: str = ormar.Text(primary_key=True)
company: Company = ormar.ForeignKey(Company)
company2: Company = ormar.ForeignKey(Company)
@pytest.mark.asyncio
async def test_create_primary_models():
async with database:
await Role.objects.create(
name="user", order=0, description="no administration right"
)
role_1 = await Role.objects.create(
name="admin", order=1, description="standard administration right"
)
await Role.objects.create(
name="super_admin", order=2, description="super administration right"
)
assert await Role.objects.count() == 3
company_0 = await Company.objects.create(name="Company")
company_1 = await Company.objects.create(name="Subsidiary Company 1")
company_2 = await Company.objects.create(name="Subsidiary Company 2")
company_3 = await Company.objects.create(name="Subsidiary Company 3")
assert await Company.objects.count() == 4
user = await User.objects.create(
registrationnumber="00-00000", company=company_0, name="admin", role=role_1
)
assert await User.objects.count() == 1
await user.delete()
assert await User.objects.count() == 0
user = await User.objects.create(
registrationnumber="00-00000",
company=company_0,
company2=company_3,
name="admin",
role=role_1,
)
await user.roleforcompanies.add(company_1)
await user.roleforcompanies.add(company_2)
users = await User.objects.select_related(
["company", "company2", "roleforcompanies"]
).all()
assert len(users) == 1
assert len(users[0].roleforcompanies) == 2
assert len(users[0].roleforcompanies[0].role_users) == 1
assert users[0].company.name == "Company"
assert len(users[0].company.users) == 1
assert users[0].company2.name == "Subsidiary Company 3"
assert len(users[0].company2.secondary_users) == 1
users = await User.objects.select_related("roleforcompanies").all()
assert len(users) == 1
assert len(users[0].roleforcompanies) == 2

View File

@ -0,0 +1,93 @@
from typing import List, Optional
import databases
import pytest
import sqlalchemy
from sqlalchemy import create_engine
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
class User(ormar.Model):
class Meta:
metadata = metadata
database = database
tablename = "test_users"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=50)
class Signup(ormar.Model):
class Meta:
metadata = metadata
database = database
tablename = "test_signup"
id: int = ormar.Integer(primary_key=True)
class Session(ormar.Model):
class Meta:
metadata = metadata
database = database
tablename = "test_sessions"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=255, index=True)
some_text: str = ormar.Text()
some_other_text: Optional[str] = ormar.Text(nullable=True)
students: Optional[List[User]] = ormar.ManyToMany(User, through=Signup)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_list_sessions_for_user():
async with database:
for user_id in [1, 2, 3, 4, 5]:
await User.objects.create(name=f"User {user_id}")
for name, some_text, some_other_text in [
("Session 1", "Some text 1", "Some other text 1"),
("Session 2", "Some text 2", "Some other text 2"),
("Session 3", "Some text 3", "Some other text 3"),
("Session 4", "Some text 4", "Some other text 4"),
("Session 5", "Some text 5", "Some other text 5"),
]:
await Session(
name=name, some_text=some_text, some_other_text=some_other_text
).save()
s1 = await Session.objects.get(pk=1)
s2 = await Session.objects.get(pk=2)
users = {}
for i in range(1, 6):
user = await User.objects.get(pk=i)
users[f"user_{i}"] = user
if i % 2 == 0:
await s1.students.add(user)
else:
await s2.students.add(user)
assert len(s1.students) == 2
assert len(s2.students) == 3
assert [x.pk for x in s1.students] == [2, 4]
assert [x.pk for x in s2.students] == [1, 3, 5]
user = await User.objects.select_related("sessions").get(pk=1)
assert user.sessions is not None
assert len(user.sessions) > 0