diff --git a/docs/releases.md b/docs/releases.md index b566795..97cbf8e 100644 --- a/docs/releases.md +++ b/docs/releases.md @@ -1,3 +1,11 @@ +# 0.7.4 + +* Allow multiple relations to the same related model/table. +* Fix for wrong relation column used in many_to_many relation joins (fix [#73][#73]) +* Fix for wrong relation population for m2m relations when also fk relation present for same model. +* Add check if user provide related_name if there are multiple relations to same table on one model. +* More eager cleaning of the dead weak proxy models. + # 0.7.3 * Fix for setting fetching related model with UUDI pk, which is a string in raw (fix [#71][#71]) @@ -193,4 +201,5 @@ Add queryset level methods [#60]: https://github.com/collerek/ormar/issues/60 [#68]: https://github.com/collerek/ormar/issues/68 [#70]: https://github.com/collerek/ormar/issues/70 -[#71]: https://github.com/collerek/ormar/issues/71 \ No newline at end of file +[#71]: https://github.com/collerek/ormar/issues/71 +[#73]: https://github.com/collerek/ormar/issues/73 \ No newline at end of file diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 343e7f9..db9ae57 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -25,6 +25,7 @@ class BaseField(FieldInfo): """ __type__ = None + related_name = None column_type: sqlalchemy.Column constraints: List = [] @@ -222,7 +223,11 @@ class BaseField(FieldInfo): @classmethod def expand_relationship( - cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True + cls, + value: Any, + child: Union["Model", "NewBaseModel"], + to_register: bool = True, + relation_name: str = None, ) -> Any: """ Function overwritten for relations, in basic field the value is returned as is. diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index ade32f4..e7c27c7 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -162,7 +162,7 @@ class ForeignKeyField(BaseField): @classmethod def _extract_model_from_sequence( - cls, value: List, child: "Model", to_register: bool + cls, value: List, child: "Model", to_register: bool, relation_name: str ) -> List["Model"]: """ Takes a list of Models and registers them on parent. @@ -180,13 +180,18 @@ class ForeignKeyField(BaseField): :rtype: List["Model"] """ return [ - cls.expand_relationship(val, child, to_register) # type: ignore + cls.expand_relationship( # type: ignore + value=val, + child=child, + to_register=to_register, + relation_name=relation_name, + ) for val in value ] @classmethod def _register_existing_model( - cls, value: "Model", child: "Model", to_register: bool + cls, value: "Model", child: "Model", to_register: bool, relation_name: str ) -> "Model": """ Takes already created instance and registers it for parent. @@ -204,12 +209,12 @@ class ForeignKeyField(BaseField): :rtype: Model """ if to_register: - cls.register_relation(value, child) + cls.register_relation(model=value, child=child, relation_name=relation_name) return value @classmethod def _construct_model_from_dict( - cls, value: dict, child: "Model", to_register: bool + cls, value: dict, child: "Model", to_register: bool, relation_name: str ) -> "Model": """ Takes a dictionary, creates a instance and registers it for parent. @@ -231,12 +236,12 @@ class ForeignKeyField(BaseField): value["__pk_only__"] = True model = cls.to(**value) if to_register: - cls.register_relation(model, child) + cls.register_relation(model=model, child=child, relation_name=relation_name) return model @classmethod def _construct_model_from_pk( - cls, value: Any, child: "Model", to_register: bool + cls, value: Any, child: "Model", to_register: bool, relation_name: str ) -> "Model": """ Takes a pk value, creates a dummy instance and registers it for parent. @@ -263,11 +268,13 @@ class ForeignKeyField(BaseField): ) model = create_dummy_instance(fk=cls.to, pk=value) if to_register: - cls.register_relation(model, child) + cls.register_relation(model=model, child=child, relation_name=relation_name) return model @classmethod - def register_relation(cls, model: "Model", child: "Model") -> None: + def register_relation( + cls, model: "Model", child: "Model", relation_name: str + ) -> None: """ Registers relation between parent and child in relation manager. Relation manager is kep on each model (different instance). @@ -281,12 +288,20 @@ class ForeignKeyField(BaseField): :type child: Model class """ model._orm.add( - parent=model, child=child, child_name=cls.related_name, virtual=cls.virtual + parent=model, + child=child, + child_name=cls.related_name or child.get_name() + "s", + virtual=cls.virtual, + relation_name=relation_name, ) @classmethod def expand_relationship( - cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True + cls, + value: Any, + child: Union["Model", "NewBaseModel"], + to_register: bool = True, + relation_name: str = None, ) -> Optional[Union["Model", List["Model"]]]: """ For relations the child model is first constructed (if needed), @@ -316,5 +331,5 @@ class ForeignKeyField(BaseField): model = constructors.get( # type: ignore value.__class__.__name__, cls._construct_model_from_pk - )(value, child, to_register) + )(value, child, to_register, relation_name) return model diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index d6fded6..eac5c2b 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -10,7 +10,6 @@ from typing import ( Tuple, Type, Union, - cast, ) import databases @@ -56,25 +55,27 @@ class ModelMeta: abstract: bool -def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None: - alias_manager.add_relation_type(field.to.Meta.tablename, table_name) +def register_relation_on_build_new(new_model: Type["Model"], field_name: str) -> None: + alias_manager.add_relation_type_new(new_model, field_name) -def register_many_to_many_relation_on_build( - table_name: str, field: Type[ManyToManyField] +def register_many_to_many_relation_on_build_new( + new_model: Type["Model"], field: Type[ManyToManyField] ) -> None: - alias_manager.add_relation_type(field.through.Meta.tablename, table_name) - alias_manager.add_relation_type( - field.through.Meta.tablename, field.to.Meta.tablename + alias_manager.add_relation_type_new( + field.through, new_model.get_name(), is_multi=True + ) + alias_manager.add_relation_type_new( + field.through, field.to.get_name(), is_multi=True ) def reverse_field_not_already_registered( - child: Type["Model"], child_model_name: str, parent_model: Type["Model"] + child: Type["Model"], child_model_name: str, parent_model: Type["Model"] ) -> bool: return ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ) @@ -85,7 +86,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if reverse_field_not_already_registered( - child, child_model_name, parent_model + child, child_model_name, parent_model ): register_reverse_model_fields( parent_model, child, child_model_name, model_field @@ -93,10 +94,10 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: def register_reverse_model_fields( - model: Type["Model"], - child: Type["Model"], - child_model_name: str, - model_field: Type["ForeignKeyField"], + model: Type["Model"], + child: Type["Model"], + child_model_name: str, + model_field: Type["ForeignKeyField"], ) -> None: if issubclass(model_field, ManyToManyField): model.Meta.model_fields[child_model_name] = ManyToMany( @@ -111,7 +112,7 @@ def register_reverse_model_fields( def adjust_through_many_to_many_model( - model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( model, real_name=model.get_name(), ondelete="CASCADE" @@ -128,7 +129,7 @@ def adjust_through_many_to_many_model( def create_pydantic_field( - field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] + field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.__fields__[field_name] = ModelField( name=field_name, @@ -150,7 +151,7 @@ def get_pydantic_field(field_name: str, model: Type["Model"]) -> "ModelField": def create_and_append_m2m_fk( - model: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: column = sqlalchemy.Column( model.get_name(), @@ -166,7 +167,7 @@ def create_and_append_m2m_fk( def check_pk_column_validity( - field_name: str, field: BaseField, pkname: Optional[str] + field_name: str, field: BaseField, pkname: Optional[str] ) -> Optional[str]: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") @@ -175,8 +176,27 @@ def check_pk_column_validity( return field_name +def validate_related_names_in_relations( + model_fields: Dict, new_model: Type["Model"] +) -> None: + already_registered: Dict[str, List[Optional[str]]] = dict() + for field in model_fields.values(): + if issubclass(field, ForeignKeyField): + previous_related_names = already_registered.setdefault(field.to, []) + if field.related_name in previous_related_names: + raise ModelDefinitionError( + f"Multiple fields declared on {new_model.get_name(lower=False)} " + f"model leading to {field.to.get_name(lower=False)} model without " + f"related_name property set. \nThere can be only one relation with " + f"default/empty name: '{new_model.get_name() + 's'}'" + f"\nTip: provide different related_name for FK and/or M2M fields" + ) + else: + previous_related_names.append(field.related_name) + + def sqlalchemy_columns_from_model_fields( - model_fields: Dict, table_name: str + model_fields: Dict, new_model: Type["Model"] ) -> Tuple[Optional[str], List[sqlalchemy.Column]]: columns = [] pkname = None @@ -186,30 +206,30 @@ def sqlalchemy_columns_from_model_fields( "Table {table_name} had no fields so auto " "Integer primary key named `id` created." ) + validate_related_names_in_relations(model_fields, new_model) for field_name, field in model_fields.items(): if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) if ( - not field.pydantic_only - and not field.virtual - and not issubclass(field, ManyToManyField) + not field.pydantic_only + and not field.virtual + and not issubclass(field, ManyToManyField) ): columns.append(field.get_column(field.get_alias())) - register_relation_in_alias_manager(table_name, field) return pkname, columns -def register_relation_in_alias_manager( - table_name: str, field: Type[ForeignKeyField] +def register_relation_in_alias_manager_new( + new_model: Type["Model"], field: Type[ForeignKeyField], field_name: str ) -> None: if issubclass(field, ManyToManyField): - register_many_to_many_relation_on_build(table_name, field) + register_many_to_many_relation_on_build_new(new_model=new_model, field=field) elif issubclass(field, ForeignKeyField): - register_relation_on_build(table_name, field) + register_relation_on_build_new(new_model=new_model, field_name=field_name) def populate_default_pydantic_field_value( - ormar_field: Type[BaseField], field_name: str, attrs: dict + ormar_field: Type[BaseField], field_name: str, attrs: dict ) -> dict: curr_def_value = attrs.get(field_name, ormar.Undefined) if lenient_issubclass(curr_def_value, ormar.fields.BaseField): @@ -254,7 +274,7 @@ def extract_annotations_and_default_vals(attrs: dict) -> Tuple[Dict, Dict]: def populate_meta_tablename_columns_and_pk( - name: str, new_model: Type["Model"] + name: str, new_model: Type["Model"] ) -> Type["Model"]: tablename = name.lower() + "s" new_model.Meta.tablename = ( @@ -267,7 +287,7 @@ def populate_meta_tablename_columns_and_pk( pkname = new_model.Meta.pkname else: pkname, columns = sqlalchemy_columns_from_model_fields( - new_model.Meta.model_fields, new_model.Meta.tablename + new_model.Meta.model_fields, new_model ) if pkname is None: @@ -275,12 +295,11 @@ def populate_meta_tablename_columns_and_pk( new_model.Meta.columns = columns new_model.Meta.pkname = pkname - return new_model def populate_meta_sqlalchemy_table_if_required( - new_model: Type["Model"], + new_model: Type["Model"], ) -> Type["Model"]: """ Constructs sqlalchemy table out of columns and parameters set on Meta class. @@ -371,7 +390,7 @@ def populate_choices_validators(model: Type["Model"]) -> None: # noqa CCR001 def populate_default_options_values( - new_model: Type["Model"], model_fields: Dict + new_model: Type["Model"], model_fields: Dict ) -> None: """ Sets all optional Meta values to it's defaults @@ -445,15 +464,18 @@ def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: :param attrs: :type attrs: Dict[str, str] """ + props = set() + for var_name, value in attrs.items(): + if isinstance(value, property): + value = value.fget + field_config = getattr(value, "__property_field__", None) + if field_config: + props.add(var_name) + if meta_field_not_set(model=new_model, field_name="property_fields"): - props = set() - for var_name, value in attrs.items(): - if isinstance(value, property): - value = value.fget - field_config = getattr(value, "__property_field__", None) - if field_config: - props.add(var_name) new_model.Meta.property_fields = props + else: + new_model.Meta.property_fields = new_model.Meta.property_fields.union(props) def register_signals(new_model: Type["Model"]) -> None: # noqa: CCR001 @@ -490,11 +512,11 @@ def get_potential_fields(attrs: Dict) -> Dict: def check_conflicting_fields( - new_fields: Set, - attrs: Dict, - base_class: type, - curr_class: type, - previous_fields: Set = None, + new_fields: Set, + attrs: Dict, + base_class: type, + curr_class: type, + previous_fields: Set = None, ) -> None: """ You cannot redefine fields with same names in inherited classes. @@ -524,11 +546,11 @@ def check_conflicting_fields( def update_attrs_and_fields( - attrs: Dict, - new_attrs: Dict, - model_fields: Dict, - new_model_fields: Dict, - new_fields: Set, + attrs: Dict, + new_attrs: Dict, + model_fields: Dict, + new_model_fields: Dict, + new_fields: Set, ) -> None: """ Updates __annotations__, values of model fields (so pydantic FieldInfos) @@ -551,7 +573,7 @@ def update_attrs_and_fields( model_fields.update(new_model_fields) -def update_attrs_from_base_meta(base_class: "Model", attrs: Dict,) -> None: +def update_attrs_from_base_meta(base_class: "Model", attrs: Dict, ) -> None: """ Updates Meta parameters in child from parent if needed. @@ -560,33 +582,24 @@ def update_attrs_from_base_meta(base_class: "Model", attrs: Dict,) -> None: :param attrs: new namespace for class being constructed :type attrs: Dict """ - params_to_update = ["metadata", "database", "constraints", "property_fields"] + params_to_update = ["metadata", "database", "constraints"] for param in params_to_update: - if hasattr(base_class.Meta, param): - if hasattr(attrs["Meta"], param): - curr_value = getattr(attrs["Meta"], param) - if isinstance(curr_value, list): - curr_value.extend(getattr(base_class.Meta, param)) - elif isinstance(curr_value, dict): # pragma: no cover - curr_value.update(getattr(base_class.Meta, param)) - elif isinstance(curr_value, Set): - curr_value.union(getattr(base_class.Meta, param)) - else: - # overwrite with child value if both set and its param / object - setattr( - attrs["Meta"], param, getattr(base_class.Meta, param) - ) # pragma: no cover + current_value = attrs.get('Meta', {}).__dict__.get(param, ormar.Undefined) + parent_value = base_class.Meta.__dict__.get(param) if hasattr(base_class, 'Meta') else None + if parent_value: + if isinstance(current_value, list): + current_value.extend(parent_value) else: - setattr(attrs["Meta"], param, getattr(base_class.Meta, param)) + setattr(attrs["Meta"], param, parent_value) -def extract_mixin_fields_from_dict( - base_class: type, - curr_class: type, - attrs: Dict, - model_fields: Dict[ - str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] - ], +def extract_from_parents_definition( + base_class: type, + curr_class: type, + attrs: Dict, + model_fields: Dict[ + str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] + ], ) -> Tuple[Dict, Dict]: """ Extracts fields from base classes if they have valid oramr fields. @@ -629,8 +642,11 @@ def extract_mixin_fields_from_dict( f"{curr_class.__name__} cannot inherit " f"from non abstract class {base_class.__name__}" ) - update_attrs_from_base_meta(base_class=base_class, attrs=attrs) # type: ignore - model_fields.update(base_class.Meta.model_fields) + update_attrs_from_base_meta( + base_class=base_class, # type: ignore + attrs=attrs, + ) + model_fields.update(base_class.Meta.model_fields) # type: ignore return attrs, model_fields key = "__annotations__" @@ -687,14 +703,14 @@ def extract_mixin_fields_from_dict( class ModelMetaclass(pydantic.main.ModelMetaclass): - def __new__( # type: ignore - mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict + def __new__( # type: ignore # noqa: CCR001 + mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict ) -> "ModelMetaclass": attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name attrs, model_fields = extract_annotations_and_default_vals(attrs) - for ind, base in enumerate(reversed(bases)): - attrs, model_fields = extract_mixin_fields_from_dict( + for base in reversed(bases): + attrs, model_fields = extract_from_parents_definition( base_class=base, curr_class=mcs, attrs=attrs, model_fields=model_fields ) new_model = super().__new__( # type: ignore @@ -713,6 +729,8 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): new_model = populate_meta_tablename_columns_and_pk(name, new_model) new_model = populate_meta_sqlalchemy_table_if_required(new_model) expand_reverse_relationships(new_model) + for field_name, field in new_model.Meta.model_fields.items(): + register_relation_in_alias_manager_new(new_model, field, field_name) if new_model.Meta.pkname not in attrs["__annotations__"]: field_name = new_model.Meta.pkname diff --git a/ormar/models/model.py b/ormar/models/model.py index 8e7a8f5..12b5d80 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -58,7 +58,8 @@ class Model(NewBaseModel): row: sqlalchemy.engine.ResultProxy, select_related: List = None, related_models: Any = None, - previous_table: str = None, + previous_model: Type[T] = None, + related_name: str = None, fields: Optional[Union[Dict, Set]] = None, exclude_fields: Optional[Union[Dict, Set]] = None, ) -> Optional[T]: @@ -69,28 +70,32 @@ class Model(NewBaseModel): if select_related: related_models = group_related_list(select_related) - if ( - previous_table - and previous_table in cls.Meta.model_fields - and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField) - ): - previous_table = cls.Meta.model_fields[ - previous_table - ].through.Meta.tablename + rel_name2 = related_name - if previous_table: - table_prefix = cls.Meta.alias_manager.resolve_relation_join( - previous_table, cls.Meta.table.name + if ( + previous_model + and related_name + and issubclass( + previous_model.Meta.model_fields[related_name], ManyToManyField + ) + ): + through_field = previous_model.Meta.model_fields[related_name] + rel_name2 = previous_model.resolve_relation_name( + through_field.through, through_field.to, explicit_multi=True + ) + previous_model = through_field.through # type: ignore + + if previous_model and rel_name2: + table_prefix = cls.Meta.alias_manager.resolve_relation_join_new( + previous_model, rel_name2 ) else: table_prefix = "" - previous_table = cls.Meta.table.name item = cls.populate_nested_models_from_row( item=item, row=row, related_models=related_models, - previous_table=previous_table, fields=fields, exclude_fields=exclude_fields, ) @@ -111,7 +116,6 @@ class Model(NewBaseModel): instance.set_save_status(True) else: instance = None - return instance @classmethod @@ -120,7 +124,6 @@ class Model(NewBaseModel): item: dict, row: sqlalchemy.engine.ResultProxy, related_models: Any, - previous_table: sqlalchemy.Table, fields: Optional[Union[Dict, Set]] = None, exclude_fields: Optional[Union[Dict, Set]] = None, ) -> dict: @@ -135,7 +138,8 @@ class Model(NewBaseModel): child = model_cls.from_row( row, related_models=remainder, - previous_table=previous_table, + previous_model=cls, + related_name=related, fields=fields, exclude_fields=exclude_fields, ) @@ -146,7 +150,8 @@ class Model(NewBaseModel): exclude_fields = cls.get_excluded(exclude_fields, related) child = model_cls.from_row( row, - previous_table=previous_table, + previous_model=cls, + related_name=related, fields=fields, exclude_fields=exclude_fields, ) diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index e330028..993a045 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -21,7 +21,7 @@ from ormar.exceptions import ModelPersistenceError, RelationshipInstanceError from ormar.queryset.utils import translate_list_to_dict, update import ormar # noqa: I100 -from ormar.fields import BaseField +from ormar.fields import BaseField, ManyToManyField from ormar.fields.foreign_key import ForeignKeyField from ormar.models.metaclass import ModelMeta @@ -291,12 +291,21 @@ class ModelTableProxy: "ModelTableProxy", Type["ModelTableProxy"], ], + explicit_multi: bool = False, ) -> str: for name, field in item.Meta.model_fields.items(): - if issubclass(field, ForeignKeyField): - # fastapi is creating clones of response model - # that's why it can be a subclass of the original model - # so we need to compare Meta too as this one is copied as is + # fastapi is creating clones of response model + # that's why it can be a subclass of the original model + # so we need to compare Meta too as this one is copied as is + if issubclass(field, ManyToManyField): + attrib = "to" if not explicit_multi else "through" + if ( + getattr(field, attrib) == related.__class__ + or getattr(field, attrib).Meta == related.Meta + ): + return name + + elif issubclass(field, ForeignKeyField): if field.to == related.__class__ or field.to.Meta == related.Meta: return name diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 1a0b05f..b8955a3 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -96,7 +96,7 @@ class NewBaseModel( k: self._convert_json( k, self.Meta.model_fields[k].expand_relationship( - v, self, to_register=False + v, self, to_register=False, relation_name=k ), "dumps", ) @@ -125,7 +125,7 @@ class NewBaseModel( # register the columns models after initialization for related in self.extract_related_names(): self.Meta.model_fields[related].expand_relationship( - new_kwargs.get(related), self, to_register=True + new_kwargs.get(related), self, to_register=True, relation_name=related ) def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 @@ -135,7 +135,9 @@ class NewBaseModel( object.__setattr__(self, self.Meta.pkname, value) self.set_save_status(False) elif name in self._orm: - model = self.Meta.model_fields[name].expand_relationship(value, self) + model = self.Meta.model_fields[name].expand_relationship( + value=value, child=self, relation_name=name + ) if isinstance(self.__dict__.get(name), list): # virtual foreign key or many to many self.__dict__[name].append(model) diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index 362ba85..e5f84f7 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -131,17 +131,19 @@ class QueryClause: # Walk the relationships to the actual model class # against which the comparison is being made. - previous_table = model_cls.Meta.tablename + previous_model = model_cls for part in related_parts: + part2 = part if issubclass(model_cls.Meta.model_fields[part], ManyToManyField): - previous_table = model_cls.Meta.model_fields[ - part - ].through.Meta.tablename - current_table = model_cls.Meta.model_fields[part].to.Meta.tablename + through_field = model_cls.Meta.model_fields[part] + previous_model = through_field.through + part2 = model_cls.resolve_relation_name( + through_field.through, through_field.to, explicit_multi=True + ) manager = model_cls.Meta.alias_manager - table_prefix = manager.resolve_relation_join(previous_table, current_table) + table_prefix = manager.resolve_relation_join_new(previous_model, part2) model_cls = model_cls.Meta.model_fields[part].to - previous_table = current_table + previous_model = model_cls return select_related, table_prefix, model_cls def _compile_clause( diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index 1628017..fa32fd6 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -135,8 +135,8 @@ class SqlJoin: model_cls = join_params.model_cls.Meta.model_fields[part].to to_table = model_cls.Meta.table.name - alias = model_cls.Meta.alias_manager.resolve_relation_join( - join_params.from_table, to_table + alias = model_cls.Meta.alias_manager.resolve_relation_join_new( + join_params.prev_model, part ) if alias not in self.used_aliases: self._process_join( @@ -267,7 +267,9 @@ class SqlJoin: model_cls, join_params.prev_model ) to_key = model_cls.get_column_alias(to_field) - from_key = join_params.prev_model.get_column_alias(model_cls.Meta.pkname) + from_key = join_params.prev_model.get_column_alias( + join_params.prev_model.Meta.pkname + ) else: to_key = model_cls.get_column_alias(model_cls.Meta.pkname) from_key = join_params.prev_model.get_column_alias(part) diff --git a/ormar/queryset/prefetch_query.py b/ormar/queryset/prefetch_query.py index 13ad785..1a06d31 100644 --- a/ormar/queryset/prefetch_query.py +++ b/ormar/queryset/prefetch_query.py @@ -318,9 +318,8 @@ class PrefetchQuery: if issubclass(target_field, ManyToManyField): query_target = target_field.through select_related = [target_name] - table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join( - from_table=query_target.Meta.tablename, - to_table=target_field.to.Meta.tablename, + table_prefix = target_field.to.Meta.alias_manager.resolve_relation_join_new( + query_target, target_name ) self.already_extracted.setdefault(target_name, {})["prefix"] = table_prefix diff --git a/ormar/relations/alias_manager.py b/ormar/relations/alias_manager.py index 2eead1e..477217f 100644 --- a/ormar/relations/alias_manager.py +++ b/ormar/relations/alias_manager.py @@ -1,11 +1,14 @@ import string import uuid from random import choices -from typing import Dict, List +from typing import Dict, List, TYPE_CHECKING, Type import sqlalchemy from sqlalchemy import text +if TYPE_CHECKING: # pragma: no cover + from ormar import Model + def get_table_alias() -> str: alias = "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] @@ -15,6 +18,7 @@ def get_table_alias() -> str: class AliasManager: def __init__(self) -> None: self._aliases: Dict[str, str] = dict() + self._aliases_new: Dict[str, str] = dict() @staticmethod def prefixed_columns( @@ -35,11 +39,25 @@ class AliasManager: def prefixed_table_name(alias: str, name: str) -> text: return text(f"{name} {alias}_{name}") - def add_relation_type(self, to_table_name: str, table_name: str,) -> None: - if f"{table_name}_{to_table_name}" not in self._aliases: - self._aliases[f"{table_name}_{to_table_name}"] = get_table_alias() - if f"{to_table_name}_{table_name}" not in self._aliases: - self._aliases[f"{to_table_name}_{table_name}"] = get_table_alias() + def add_relation_type_new( + self, source_model: Type["Model"], relation_name: str, is_multi: bool = False + ) -> None: + parent_key = f"{source_model.get_name()}_{relation_name}" + if parent_key not in self._aliases_new: + self._aliases_new[parent_key] = get_table_alias() + to_field = source_model.Meta.model_fields[relation_name] + child_model = to_field.to + related_name = to_field.related_name + if not related_name: + related_name = child_model.resolve_relation_name( + child_model, source_model, explicit_multi=is_multi + ) + child_key = f"{child_model.get_name()}_{related_name}" + if child_key not in self._aliases_new: + self._aliases_new[child_key] = get_table_alias() - def resolve_relation_join(self, from_table: str, to_table: str) -> str: - return self._aliases.get(f"{from_table}_{to_table}", "") + def resolve_relation_join_new( + self, from_model: Type["Model"], relation_name: str + ) -> str: + alias = self._aliases_new.get(f"{from_model.get_name()}_{relation_name}", "") + return alias diff --git a/ormar/relations/relation.py b/ormar/relations/relation.py index b0183c3..95b49fd 100644 --- a/ormar/relations/relation.py +++ b/ormar/relations/relation.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Optional, TYPE_CHECKING, Type, TypeVar, Union +from typing import List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union import ormar # noqa I100 from ormar.exceptions import RelationshipInstanceError # noqa I100 @@ -31,6 +31,7 @@ class Relation: self.manager = manager self._owner: "Model" = manager.owner self._type: RelationType = type_ + self._to_remove: Set = set() self.to: Type["T"] = to self.through: Optional[Type["T"]] = through self.related_models: Optional[Union[RelationProxy, "T"]] = ( @@ -39,17 +40,32 @@ class Relation: else None ) + def _clean_related(self) -> None: + cleaned_data = [ + x + for i, x in enumerate(self.related_models) # type: ignore + if i not in self._to_remove + ] + self.related_models = RelationProxy( + relation=self, type_=self._type, data_=cleaned_data + ) + relation_name = self._owner.resolve_relation_name(self._owner, self.to) + self._owner.__dict__[relation_name] = cleaned_data + self._to_remove = set() + def _find_existing( self, child: Union["NewBaseModel", Type["NewBaseModel"]] ) -> Optional[int]: if not isinstance(self.related_models, RelationProxy): # pragma nocover raise ValueError("Cannot find existing models in parent relation type") + if self._to_remove: + self._clean_related() for ind, relation_child in enumerate(self.related_models[:]): try: if relation_child == child: return ind except ReferenceError: # pragma no cover - self.related_models.pop(ind) + self._to_remove.add(ind) return None def add(self, child: "T") -> None: @@ -83,4 +99,6 @@ class Relation: return self.related_models def __repr__(self) -> str: # pragma no cover + if self._to_remove: + self._clean_related() return str(self.related_models) diff --git a/ormar/relations/relation_manager.py b/ormar/relations/relation_manager.py index 81183f4..6462e48 100644 --- a/ormar/relations/relation_manager.py +++ b/ormar/relations/relation_manager.py @@ -56,8 +56,14 @@ class RelationsManager: return None @staticmethod - def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None: - to_field: Type[BaseField] = child.resolve_relation_field(child, parent) + def add( + parent: "Model", + child: "Model", + child_name: str, + virtual: bool, + relation_name: str, + ) -> None: + to_field: Type[BaseField] = child.Meta.model_fields[relation_name] (parent, child, child_name, to_name,) = get_relations_sides_and_names( to_field, parent, child, child_name, virtual diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index c8eb944..28f177f 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -11,8 +11,10 @@ if TYPE_CHECKING: # pragma no cover class RelationProxy(list): - def __init__(self, relation: "Relation", type_: "RelationType") -> None: - super().__init__() + def __init__( + self, relation: "Relation", type_: "RelationType", data_: Any = None + ) -> None: + super().__init__(data_ or ()) self.relation: "Relation" = relation self.type_: "RelationType" = type_ self._owner: "Model" = self.relation.manager.owner diff --git a/ormar/relations/utils.py b/ormar/relations/utils.py index c7a3fff..9fa09b0 100644 --- a/ormar/relations/utils.py +++ b/ormar/relations/utils.py @@ -18,8 +18,11 @@ def get_relations_sides_and_names( to_name = to_field.name if issubclass(to_field, ManyToManyField): child_name, to_name = ( - child.resolve_relation_name(parent, child), - child.resolve_relation_name(child, parent), + to_field.related_name + or child.resolve_relation_name( + parent, to_field.through, explicit_multi=True + ), + to_name, ) child = proxy(child) elif virtual: diff --git a/tests/test_inheritance_concrete.py b/tests/test_inheritance_concrete.py index 65cd3e8..978c020 100644 --- a/tests/test_inheritance_concrete.py +++ b/tests/test_inheritance_concrete.py @@ -8,7 +8,7 @@ import sqlalchemy as sa from sqlalchemy import create_engine import ormar -from ormar import ModelDefinitionError +from ormar import ModelDefinitionError, property_field from ormar.exceptions import ModelError from tests.settings import DATABASE_URL @@ -24,6 +24,10 @@ class AuditModel(ormar.Model): created_by: str = ormar.String(max_length=100) updated_by: str = ormar.String(max_length=100, default="Sam") + @property_field + def audit(self): # pragma: no cover + return f"{self.created_by} {self.updated_by}" + class DateFieldsModelNoSubclass(ormar.Model): class Meta: @@ -41,6 +45,7 @@ class DateFieldsModel(ormar.Model): abstract = True metadata = metadata database = db + constraints = [ormar.UniqueColumns("created_date", "updated_date")] created_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) updated_date: datetime.datetime = ormar.DateTime(default=datetime.datetime.now) @@ -49,11 +54,20 @@ class DateFieldsModel(ormar.Model): class Category(DateFieldsModel, AuditModel): class Meta(ormar.ModelMeta): tablename = "categories" + constraints = [ormar.UniqueColumns("name", "code")] id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=50, unique=True, index=True) code: int = ormar.Integer() + @property_field + def code_name(self): + return f"{self.code}:{self.name}" + + @property_field + def audit(self): + return f"{self.created_by} {self.updated_by}" + class Subject(DateFieldsModel): class Meta(ormar.ModelMeta): @@ -99,6 +113,13 @@ def test_model_subclassing_non_abstract_raises_error(): id: int = ormar.Integer(primary_key=True) +def test_params_are_inherited(): + assert Category.Meta.metadata == metadata + assert Category.Meta.database == db + assert len(Category.Meta.constraints) == 2 + assert len(Category.Meta.property_fields) == 2 + + def round_date_to_seconds( date: datetime.datetime, ) -> datetime.datetime: # pragma: no cover @@ -132,7 +153,9 @@ async def test_fields_inherited_from_mixin(): inspector = sa.inspect(engine) assert "categories" in inspector.get_table_names() table_columns = [x.get("name") for x in inspector.get_columns("categories")] - assert all(col in table_columns for col in mixin_columns) # + mixin2_columns) + assert all( + col in table_columns for col in mixin_columns + ) # + mixin2_columns) assert "subjects" in inspector.get_table_names() table_columns = [x.get("name") for x in inspector.get_columns("subjects")] diff --git a/tests/test_inheritance_concrete_fastapi.py b/tests/test_inheritance_concrete_fastapi.py index c48a223..c568d37 100644 --- a/tests/test_inheritance_concrete_fastapi.py +++ b/tests/test_inheritance_concrete_fastapi.py @@ -7,7 +7,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient from tests.settings import DATABASE_URL -from tests.test_inheritance_concrete import Category, Subject, metadata +from tests.test_inheritance_concrete import Category, Subject, metadata # type: ignore app = FastAPI() database = databases.Database(DATABASE_URL, force_rollback=True) @@ -53,25 +53,25 @@ def test_read_main(): test_category = dict(name="Foo", code=123, created_by="Sam", updated_by="Max") test_subject = dict(name="Bar") - response = client.post( - "/categories/", json=test_category - ) + response = client.post("/categories/", json=test_category) assert response.status_code == 200 cat = Category(**response.json()) - assert cat.name == 'Foo' - assert cat.created_by == 'Sam' + assert cat.name == "Foo" + assert cat.created_by == "Sam" assert cat.created_date is not None assert cat.id == 1 cat_dict = cat.dict() - cat_dict['updated_date'] = cat_dict['updated_date'].strftime("%Y-%m-%d %H:%M:%S.%f") - cat_dict['created_date'] = cat_dict['created_date'].strftime("%Y-%m-%d %H:%M:%S.%f") - test_subject['category'] = cat_dict - response = client.post( - "/subjects/", json=test_subject + cat_dict["updated_date"] = cat_dict["updated_date"].strftime( + "%Y-%m-%d %H:%M:%S.%f" ) + cat_dict["created_date"] = cat_dict["created_date"].strftime( + "%Y-%m-%d %H:%M:%S.%f" + ) + test_subject["category"] = cat_dict + response = client.post("/subjects/", json=test_subject) assert response.status_code == 200 sub = Subject(**response.json()) - assert sub.name == 'Bar' + assert sub.name == "Bar" assert sub.category.pk == cat.pk assert isinstance(sub.updated_date, datetime.datetime) diff --git a/tests/test_inheritance_mixins_fastapi.py b/tests/test_inheritance_mixins_fastapi.py index 6bece1d..884f839 100644 --- a/tests/test_inheritance_mixins_fastapi.py +++ b/tests/test_inheritance_mixins_fastapi.py @@ -7,7 +7,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient from tests.settings import DATABASE_URL -from tests.test_inheritance_mixins import Category, Subject, metadata +from tests.test_inheritance_mixins import Category, Subject, metadata # type: ignore app = FastAPI() database = databases.Database(DATABASE_URL, force_rollback=True) @@ -53,25 +53,25 @@ def test_read_main(): test_category = dict(name="Foo", code=123, created_by="Sam", updated_by="Max") test_subject = dict(name="Bar") - response = client.post( - "/categories/", json=test_category - ) + response = client.post("/categories/", json=test_category) assert response.status_code == 200 cat = Category(**response.json()) - assert cat.name == 'Foo' - assert cat.created_by == 'Sam' + assert cat.name == "Foo" + assert cat.created_by == "Sam" assert cat.created_date is not None assert cat.id == 1 cat_dict = cat.dict() - cat_dict['updated_date'] = cat_dict['updated_date'].strftime("%Y-%m-%d %H:%M:%S.%f") - cat_dict['created_date'] = cat_dict['created_date'].strftime("%Y-%m-%d %H:%M:%S.%f") - test_subject['category'] = cat_dict - response = client.post( - "/subjects/", json=test_subject + cat_dict["updated_date"] = cat_dict["updated_date"].strftime( + "%Y-%m-%d %H:%M:%S.%f" ) + cat_dict["created_date"] = cat_dict["created_date"].strftime( + "%Y-%m-%d %H:%M:%S.%f" + ) + test_subject["category"] = cat_dict + response = client.post("/subjects/", json=test_subject) assert response.status_code == 200 sub = Subject(**response.json()) - assert sub.name == 'Bar' + assert sub.name == "Bar" assert sub.category.pk == cat.pk assert isinstance(sub.updated_date, datetime.datetime) diff --git a/tests/test_query_with_m2m_and_diff_pk_name.py b/tests/test_query_with_m2m_and_diff_pk_name.py deleted file mode 100644 index a12e4cf..0000000 --- a/tests/test_query_with_m2m_and_diff_pk_name.py +++ /dev/null @@ -1,115 +0,0 @@ -import asyncio -from datetime import date -from typing import List, Optional, Union - -import databases -import pytest -import sqlalchemy - -import ormar - -from tests.settings import DATABASE_URL - -database = databases.Database(DATABASE_URL) -metadata = sqlalchemy.MetaData() - - -class MainMeta(ormar.ModelMeta): - metadata = metadata - database = database - - -class Role(ormar.Model): - class Meta(MainMeta): - pass - - name: str = ormar.Text(primary_key=True) - order: int = ormar.Integer(default=0) - description: str = ormar.Text() - - -class Company(ormar.Model): - class Meta(MainMeta): - pass - - name: str = ormar.Text(primary_key=True) - - -class UserRoleCompany(ormar.Model): - class Meta(MainMeta): - pass - - -class User(ormar.Model): - class Meta(MainMeta): - pass - - registrationnumber: str = ormar.Text(primary_key=True) - company: Company = ormar.ForeignKey(Company) - name: str = ormar.Text() - role: Optional[Role] = ormar.ForeignKey(Role) - roleforcompanies: Optional[Union[Company, List[Company]]] = ormar.ManyToMany(Company, through=UserRoleCompany) - lastupdate: date = ormar.DateTime(server_default=sqlalchemy.func.now()) - - -@pytest.mark.asyncio -async def test_create_primary_models(): - async with database: - print("adding role") - role_0 = await Role.objects.create(name="user", order=0, description="no administration right") - role_1 = await Role.objects.create(name="admin", order=1, description="standard administration right") - role_2 = await Role.objects.create(name="super_admin", order=2, description="super administration right") - assert await Role.objects.count() == 3 - - print("adding company") - company_0 = await Company.objects.create(name="Company") - company_1 = await Company.objects.create(name="Subsidiary Company 1") - company_2 = await Company.objects.create(name="Subsidiary Company 2") - company_3 = await Company.objects.create(name="Subsidiary Company 3") - assert await Company.objects.count() == 4 - - print("adding user") - user = await User.objects.create(registrationnumber="00-00000", company=company_0, name="admin", role=role_1) - assert await User.objects.count() == 1 - - print("removing user") - await user.delete() - assert await User.objects.count() == 0 - - print("adding user with company-role") - companies: List[Company] = [company_1, company_2] - # user = await User.objects.create(registrationnumber="00-00000", company=company_0, name="admin", role=role_1, roleforcompanies=companies) - user = await User.objects.create(registrationnumber="00-00000", company=company_0, name="admin", role=role_1) - # print(User.__fields__) - await user.roleforcompanies.add(company_1) - await user.roleforcompanies.add(company_2) - - users = await User.objects.select_related("roleforcompanies").all() - # print(jsonpickle.encode(jsonable_encoder(users), unpicklable=False, keys=True)) - - """ - - This is the request generated: - 'SELECT - users.registrationnumber as registrationnumber, - users.company as company, - users.name as name, users.role as role, - users.lastupdate as lastupdate, - cy24b4_userrolecompanys.id as cy24b4_id, - cy24b4_userrolecompanys.company as cy24b4_company, - cy24b4_userrolecompanys.user as cy24b4_user, - jn50a4_companys.name as jn50a4_name \n - FROM users - LEFT OUTER JOIN userrolecompanys cy24b4_userrolecompanys ON cy24b4_userrolecompanys.user=users.id - LEFT OUTER JOIN companys jn50a4_companys ON jn50a4_companys.name=cy24b4_userrolecompanys.company - ORDER BY users.registrationnumber, jn50a4_companys.name' - - There is an error in the First LEFT OUTER JOIN generated: - ... companys.user=users.id - should be: - ... companys.user=users.registrationnumber - - There is also a \n in the midle of the string... - - The execution produce the error: column users.id does not exist - """ diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 2375613..afba731 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -102,7 +102,7 @@ async def test_model_multiple_instances_of_same_table_in_schema(): async with database.transaction(force_rollback=True): await create_data() classes = await SchoolClass.objects.select_related( - ["teachers__category", "students"] + ["teachers__category", "students__schoolclass"] ).all() assert classes[0].name == "Math" assert classes[0].students[0].name == "Jane" diff --git a/tests/test_select_related_with_limit.py b/tests/test_select_related_with_limit.py index 27e4460..7fd801e 100644 --- a/tests/test_select_related_with_limit.py +++ b/tests/test_select_related_with_limit.py @@ -56,8 +56,7 @@ class SecondaryModel(ormar.Model): id: int = ormar.Integer(primary_key=True) name: str = ormar.String(max_length=100) primary_model: PrimaryModel = ormar.ForeignKey( - PrimaryModel, - related_name="secondary_models", + PrimaryModel, related_name="secondary_models", ) @@ -74,7 +73,8 @@ async def test_create_primary_models(): ("Primary 7", "Some text 7", "Some other text 7"), ("Primary 8", "Some text 8", "Some other text 8"), ("Primary 9", "Some text 9", "Some other text 9"), - ("Primary 10", "Some text 10", "Some other text 10")]: + ("Primary 10", "Some text 10", "Some other text 10"), + ]: await PrimaryModel( name=name, some_text=some_text, some_other_text=some_other_text ).save() diff --git a/tests/test_select_related_with_m2m_and_pk_name_set.py b/tests/test_select_related_with_m2m_and_pk_name_set.py new file mode 100644 index 0000000..85cdea6 --- /dev/null +++ b/tests/test_select_related_with_m2m_and_pk_name_set.py @@ -0,0 +1,130 @@ +# type: ignore +from datetime import date +from typing import List, Optional, Union + +import databases +import pytest +import sqlalchemy +from sqlalchemy import create_engine + +import ormar +from ormar import ModelDefinitionError +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL) +metadata = sqlalchemy.MetaData() + + +class MainMeta(ormar.ModelMeta): + metadata = metadata + database = database + + +class Role(ormar.Model): + class Meta(MainMeta): + pass + + name: str = ormar.String(primary_key=True, max_length=1000) + order: int = ormar.Integer(default=0, name="sort_order") + description: str = ormar.Text() + + +class Company(ormar.Model): + class Meta(MainMeta): + pass + + name: str = ormar.String(primary_key=True, max_length=1000) + + +class UserRoleCompany(ormar.Model): + class Meta(MainMeta): + pass + + +class User(ormar.Model): + class Meta(MainMeta): + pass + + registrationnumber: str = ormar.String(primary_key=True, max_length=1000) + company: Company = ormar.ForeignKey(Company) + company2: Company = ormar.ForeignKey(Company, related_name="secondary_users") + name: str = ormar.Text() + role: Optional[Role] = ormar.ForeignKey(Role) + roleforcompanies: Optional[Union[Company, List[Company]]] = ormar.ManyToMany( + Company, through=UserRoleCompany, related_name="role_users" + ) + lastupdate: date = ormar.DateTime(server_default=sqlalchemy.func.now()) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_wrong_model(): + with pytest.raises(ModelDefinitionError): + + class User(ormar.Model): + class Meta(MainMeta): + pass + + registrationnumber: str = ormar.Text(primary_key=True) + company: Company = ormar.ForeignKey(Company) + company2: Company = ormar.ForeignKey(Company) + + +@pytest.mark.asyncio +async def test_create_primary_models(): + async with database: + await Role.objects.create( + name="user", order=0, description="no administration right" + ) + role_1 = await Role.objects.create( + name="admin", order=1, description="standard administration right" + ) + await Role.objects.create( + name="super_admin", order=2, description="super administration right" + ) + assert await Role.objects.count() == 3 + + company_0 = await Company.objects.create(name="Company") + company_1 = await Company.objects.create(name="Subsidiary Company 1") + company_2 = await Company.objects.create(name="Subsidiary Company 2") + company_3 = await Company.objects.create(name="Subsidiary Company 3") + assert await Company.objects.count() == 4 + + user = await User.objects.create( + registrationnumber="00-00000", company=company_0, name="admin", role=role_1 + ) + assert await User.objects.count() == 1 + + await user.delete() + assert await User.objects.count() == 0 + + user = await User.objects.create( + registrationnumber="00-00000", + company=company_0, + company2=company_3, + name="admin", + role=role_1, + ) + await user.roleforcompanies.add(company_1) + await user.roleforcompanies.add(company_2) + + users = await User.objects.select_related( + ["company", "company2", "roleforcompanies"] + ).all() + assert len(users) == 1 + assert len(users[0].roleforcompanies) == 2 + assert len(users[0].roleforcompanies[0].role_users) == 1 + assert users[0].company.name == "Company" + assert len(users[0].company.users) == 1 + assert users[0].company2.name == "Subsidiary Company 3" + assert len(users[0].company2.secondary_users) == 1 + + users = await User.objects.select_related("roleforcompanies").all() + assert len(users) == 1 + assert len(users[0].roleforcompanies) == 2 diff --git a/tests/test_selecting_proper_table_prefix.py b/tests/test_selecting_proper_table_prefix.py new file mode 100644 index 0000000..c4e93ec --- /dev/null +++ b/tests/test_selecting_proper_table_prefix.py @@ -0,0 +1,93 @@ +from typing import List, Optional + +import databases +import pytest +import sqlalchemy +from sqlalchemy import create_engine + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL) +metadata = sqlalchemy.MetaData() + + +class User(ormar.Model): + class Meta: + metadata = metadata + database = database + tablename = "test_users" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=50) + + +class Signup(ormar.Model): + class Meta: + metadata = metadata + database = database + tablename = "test_signup" + + id: int = ormar.Integer(primary_key=True) + + +class Session(ormar.Model): + class Meta: + metadata = metadata + database = database + tablename = "test_sessions" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=255, index=True) + some_text: str = ormar.Text() + some_other_text: Optional[str] = ormar.Text(nullable=True) + students: Optional[List[User]] = ormar.ManyToMany(User, through=Signup) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_list_sessions_for_user(): + async with database: + for user_id in [1, 2, 3, 4, 5]: + await User.objects.create(name=f"User {user_id}") + + for name, some_text, some_other_text in [ + ("Session 1", "Some text 1", "Some other text 1"), + ("Session 2", "Some text 2", "Some other text 2"), + ("Session 3", "Some text 3", "Some other text 3"), + ("Session 4", "Some text 4", "Some other text 4"), + ("Session 5", "Some text 5", "Some other text 5"), + ]: + await Session( + name=name, some_text=some_text, some_other_text=some_other_text + ).save() + + s1 = await Session.objects.get(pk=1) + s2 = await Session.objects.get(pk=2) + + users = {} + for i in range(1, 6): + user = await User.objects.get(pk=i) + users[f"user_{i}"] = user + if i % 2 == 0: + await s1.students.add(user) + else: + await s2.students.add(user) + + assert len(s1.students) == 2 + assert len(s2.students) == 3 + + assert [x.pk for x in s1.students] == [2, 4] + assert [x.pk for x in s2.students] == [1, 3, 5] + + user = await User.objects.select_related("sessions").get(pk=1) + + assert user.sessions is not None + assert len(user.sessions) > 0