diff --git a/.coverage b/.coverage index 317b35b..0c5ad2c 100644 Binary files a/.coverage and b/.coverage differ diff --git a/ormar/fields/base.py b/ormar/fields/base.py index d3f5e6c..88e1313 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -64,7 +64,7 @@ class BaseField: @classmethod def get_column(cls, name: str) -> sqlalchemy.Column: return sqlalchemy.Column( - name, + cls.name or name, cls.column_type, *cls.constraints, primary_key=cls.primary_key, diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 957c9e5..2959f32 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -38,7 +38,7 @@ def ForeignKey( # noqa CFQ002 onupdate: str = None, ondelete: str = None, ) -> Type["ForeignKeyField"]: - fk_string = to.Meta.tablename + "." + to.Meta.pkname + fk_string = to.Meta.tablename + "." + to.get_column_alias(to.Meta.pkname) to_field = to.__fields__[to.Meta.pkname] namespace = dict( to=to, diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index 587d92f..5462865 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -12,7 +12,7 @@ from ormar.fields.base import BaseField # noqa I101 def is_field_nullable( - nullable: Optional[bool], default: Any, server_default: Any + nullable: Optional[bool], default: Any, server_default: Any ) -> bool: if nullable is None: return default is not None or server_default is not None @@ -61,15 +61,15 @@ class String(ModelFieldFactory): _type = str def __new__( # type: ignore # noqa CFQ002 - cls, - *, - allow_blank: bool = True, - strip_whitespace: bool = False, - min_length: int = None, - max_length: int = None, - curtail_length: int = None, - regex: str = None, - **kwargs: Any + cls, + *, + allow_blank: bool = True, + strip_whitespace: bool = False, + min_length: int = None, + max_length: int = None, + curtail_length: int = None, + regex: str = None, + **kwargs: Any ) -> Type[BaseField]: # type: ignore kwargs = { **kwargs, @@ -79,7 +79,7 @@ class String(ModelFieldFactory): if k not in ["cls", "__class__", "kwargs"] }, } - kwargs['allow_blank'] = kwargs.get('nullable', True) + kwargs["allow_blank"] = kwargs.get("nullable", True) return super().__new__(cls, **kwargs) @classmethod @@ -100,12 +100,12 @@ class Integer(ModelFieldFactory): _type = int def __new__( # type: ignore - cls, - *, - minimum: int = None, - maximum: int = None, - multiple_of: int = None, - **kwargs: Any + cls, + *, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + **kwargs: Any ) -> Type[BaseField]: autoincrement = kwargs.pop("autoincrement", None) autoincrement = ( @@ -135,7 +135,7 @@ class Text(ModelFieldFactory): _type = str def __new__( # type: ignore - cls, *, allow_blank: bool = True, strip_whitespace: bool = False, **kwargs: Any + cls, *, allow_blank: bool = True, strip_whitespace: bool = False, **kwargs: Any ) -> Type[BaseField]: kwargs = { **kwargs, @@ -145,7 +145,7 @@ class Text(ModelFieldFactory): if k not in ["cls", "__class__", "kwargs"] }, } - kwargs['allow_blank'] = kwargs.get('nullable', True) + kwargs["allow_blank"] = kwargs.get("nullable", True) return super().__new__(cls, **kwargs) @classmethod @@ -158,12 +158,12 @@ class Float(ModelFieldFactory): _type = float def __new__( # type: ignore - cls, - *, - minimum: float = None, - maximum: float = None, - multiple_of: int = None, - **kwargs: Any + cls, + *, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + **kwargs: Any ) -> Type[BaseField]: kwargs = { **kwargs, @@ -232,12 +232,12 @@ class BigInteger(Integer): _type = int def __new__( # type: ignore - cls, - *, - minimum: int = None, - maximum: int = None, - multiple_of: int = None, - **kwargs: Any + cls, + *, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + **kwargs: Any ) -> Type[BaseField]: autoincrement = kwargs.pop("autoincrement", None) autoincrement = ( @@ -267,16 +267,16 @@ class Decimal(ModelFieldFactory): _type = decimal.Decimal def __new__( # type: ignore # noqa CFQ002 - cls, - *, - minimum: float = None, - maximum: float = None, - multiple_of: int = None, - precision: int = None, - scale: int = None, - max_digits: int = None, - decimal_places: int = None, - **kwargs: Any + cls, + *, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + precision: int = None, + scale: int = None, + max_digits: int = None, + decimal_places: int = None, + **kwargs: Any ) -> Type[BaseField]: kwargs = { **kwargs, diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 0bd711d..d7c57b9 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -41,7 +41,7 @@ def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> def register_many_to_many_relation_on_build( - table_name: str, field: Type[ManyToManyField] + table_name: str, field: Type[ManyToManyField] ) -> None: alias_manager.add_relation_type(field.through.Meta.tablename, table_name) alias_manager.add_relation_type( @@ -50,11 +50,11 @@ def register_many_to_many_relation_on_build( def reverse_field_not_already_registered( - child: Type["Model"], child_model_name: str, parent_model: Type["Model"] + child: Type["Model"], child_model_name: str, parent_model: Type["Model"] ) -> bool: return ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ) @@ -65,7 +65,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if reverse_field_not_already_registered( - child, child_model_name, parent_model + child, child_model_name, parent_model ): register_reverse_model_fields( parent_model, child, child_model_name, model_field @@ -73,10 +73,10 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: def register_reverse_model_fields( - model: Type["Model"], - child: Type["Model"], - child_model_name: str, - model_field: Type["ForeignKeyField"], + model: Type["Model"], + child: Type["Model"], + child_model_name: str, + model_field: Type["ForeignKeyField"], ) -> None: if issubclass(model_field, ManyToManyField): model.Meta.model_fields[child_model_name] = ManyToMany( @@ -91,7 +91,7 @@ def register_reverse_model_fields( def adjust_through_many_to_many_model( - model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( model, name=model.get_name(), ondelete="CASCADE" @@ -108,7 +108,7 @@ def adjust_through_many_to_many_model( def create_pydantic_field( - field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] + field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.__fields__[field_name] = ModelField( name=field_name, @@ -120,13 +120,13 @@ def create_pydantic_field( def create_and_append_m2m_fk( - model: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: column = sqlalchemy.Column( model.get_name(), - model.Meta.table.columns.get(model.Meta.pkname).type, + model.Meta.table.columns.get(model.get_column_alias(model.Meta.pkname)).type, sqlalchemy.schema.ForeignKey( - model.Meta.tablename + "." + model.Meta.pkname, + model.Meta.tablename + "." + model.get_column_alias(model.Meta.pkname), ondelete="CASCADE", onupdate="CASCADE", ), @@ -136,7 +136,7 @@ def create_and_append_m2m_fk( def check_pk_column_validity( - field_name: str, field: BaseField, pkname: Optional[str] + field_name: str, field: BaseField, pkname: Optional[str] ) -> Optional[str]: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") @@ -146,7 +146,7 @@ def check_pk_column_validity( def sqlalchemy_columns_from_model_fields( - model_fields: Dict, table_name: str + model_fields: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column]]: columns = [] pkname = None @@ -160,9 +160,9 @@ def sqlalchemy_columns_from_model_fields( if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) if ( - not field.pydantic_only - and not field.virtual - and not issubclass(field, ManyToManyField) + not field.pydantic_only + and not field.virtual + and not issubclass(field, ManyToManyField) ): columns.append(field.get_column(field_name)) register_relation_in_alias_manager(table_name, field) @@ -170,7 +170,7 @@ def sqlalchemy_columns_from_model_fields( def register_relation_in_alias_manager( - table_name: str, field: Type[ForeignKeyField] + table_name: str, field: Type[ForeignKeyField] ) -> None: if issubclass(field, ManyToManyField): register_many_to_many_relation_on_build(table_name, field) @@ -179,7 +179,7 @@ def register_relation_in_alias_manager( def populate_default_pydantic_field_value( - type_: Type[BaseField], field: str, attrs: dict + type_: Type[BaseField], field: str, attrs: dict ) -> dict: def_value = type_.default_value() curr_def_value = attrs.get(field, "NONE") @@ -208,7 +208,7 @@ def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict: def populate_meta_orm_model_fields( - attrs: dict, new_model: Type["Model"] + attrs: dict, new_model: Type["Model"] ) -> Type["Model"]: model_fields = { field_name: field @@ -220,10 +220,12 @@ def populate_meta_orm_model_fields( def populate_meta_tablename_columns_and_pk( - name: str, new_model: Type["Model"] + name: str, new_model: Type["Model"] ) -> Type["Model"]: tablename = name.lower() + "s" - new_model.Meta.tablename = new_model.Meta.tablename if hasattr(new_model.Meta, 'tablename') else tablename + new_model.Meta.tablename = ( + new_model.Meta.tablename if hasattr(new_model.Meta, "tablename") else tablename + ) pkname: Optional[str] if hasattr(new_model.Meta, "columns"): @@ -244,7 +246,7 @@ def populate_meta_tablename_columns_and_pk( def populate_meta_sqlalchemy_table_if_required( - new_model: Type["Model"], + new_model: Type["Model"], ) -> Type["Model"]: if not hasattr(new_model.Meta, "table"): new_model.Meta.table = sqlalchemy.Table( @@ -286,7 +288,7 @@ def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, A def populate_choices_validators( # noqa CCR001 - model: Type["Model"], attrs: Dict + model: Type["Model"], attrs: Dict ) -> None: if model_initialized_and_has_model_fields(model): for _, field in model.Meta.model_fields.items(): @@ -299,7 +301,7 @@ def populate_choices_validators( # noqa CCR001 class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__( # type: ignore - mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict + mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict ) -> "ModelMetaclass": attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name diff --git a/ormar/models/model.py b/ormar/models/model.py index fd71efe..fea010f 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -90,13 +90,13 @@ class Model(NewBaseModel): previous_table=previous_table, fields=fields, ) - item[first_part] = child + item[model_cls.get_column_name_from_alias(first_part)] = child else: model_cls = cls.Meta.model_fields[related].to child = model_cls.from_row( row, previous_table=previous_table, fields=fields ) - item[related] = child + item[model_cls.get_column_name_from_alias(related)] = child return item @@ -113,13 +113,16 @@ class Model(NewBaseModel): # databases does not keep aliases in Record for postgres, change to raw row source = row._row if isinstance(row, Record) else row - selected_columns = cls.own_table_columns(cls, fields or [], nested=nested) + selected_columns = cls.own_table_columns( + cls, fields or [], nested=nested, use_alias=True + ) for column in cls.Meta.table.columns: - if column.name not in item and column.name in selected_columns: + alias = cls.get_column_name_from_alias(column.name) + if alias not in item and alias in selected_columns: prefixed_name = ( f'{table_prefix + "_" if table_prefix else ""}{column.name}' ) - item[column.name] = source[prefixed_name] + item[alias] = source[prefixed_name] return item diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 0dadbb2..2e24201 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -47,6 +47,20 @@ class ModelTableProxy: model_dict[field] = field_value.get(target_pkname) return model_dict + @classmethod + def get_column_alias(cls, field_name: str) -> str: + field = cls.Meta.model_fields.get(field_name) + if field and field.name is not None and field.name != field_name: + return field.name + return field_name + + @classmethod + def get_column_name_from_alias(cls, alias: str) -> str: + for field_name, field in cls.Meta.model_fields.items(): + if field and field.name == alias: + return field_name + return alias # if not found it's not an alias but actual name + @classmethod def extract_related_names(cls) -> Set: related_names = set() @@ -151,10 +165,16 @@ class ModelTableProxy: return other @staticmethod - def own_table_columns( - model: Type["Model"], fields: List, nested: bool = False + def own_table_columns( # noqa: CCR001 + model: Type["Model"], + fields: List, + nested: bool = False, + use_alias: bool = False, ) -> List[str]: - column_names = [col.name for col in model.Meta.table.columns] + column_names = [ + model.get_column_name_from_alias(col.name) if use_alias else col.name + for col in model.Meta.table.columns + ] if not fields: return column_names @@ -175,6 +195,11 @@ class ModelTableProxy: columns = column_names # always has to return pk column - if model.Meta.pkname not in columns: - columns.append(model.Meta.pkname) + pk_alias = ( + model.get_column_alias(model.Meta.pkname) + if use_alias + else model.Meta.pkname + ) + if pk_alias not in columns: + columns.append(pk_alias) return columns diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 140d118..88664fb 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -134,8 +134,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass def _extract_related_model_instead_of_field( self, item: str ) -> Optional[Union["Model", List["Model"]]]: - if item in self._orm: - return self._orm.get(item) + alias = self.get_column_alias(item) + if alias in self._orm: + return self._orm.get(alias) return None def __eq__(self, other: object) -> bool: diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index 6b9d5a4..c5d3cd4 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -44,7 +44,7 @@ class QueryClause: ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: if kwargs.get("pk"): - pk_name = self.model_cls.Meta.pkname + pk_name = self.model_cls.get_column_alias(self.model_cls.Meta.pkname) kwargs[pk_name] = kwargs.pop("pk") filter_clauses, select_related = self._populate_filter_clauses(**kwargs) diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index f79f353..42803ff 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -106,7 +106,9 @@ class SqlJoin: self.select_from = sqlalchemy.sql.outerjoin( self.select_from, target_table, on_clause ) - self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}")) + + pkname_alias = model_cls.get_column_alias(model_cls.Meta.pkname) + self.order_bys.append(text(f"{alias}_{to_table}.{pkname_alias}")) self_related_fields = model_cls.own_table_columns( model_cls, self.fields, nested=True ) @@ -125,12 +127,13 @@ class SqlJoin: part: str, ) -> Tuple[str, str]: if join_params.prev_model.Meta.model_fields[part].virtual or is_multi: - to_field = model_cls.resolve_relation_field( + to_field = model_cls.resolve_relation_name( model_cls, join_params.prev_model ) - to_key = to_field.name - from_key = model_cls.Meta.pkname + to_key = model_cls.get_column_alias(to_field) + from_key = join_params.prev_model.get_column_alias(model_cls.Meta.pkname) else: - to_key = model_cls.Meta.pkname - from_key = part + to_key = model_cls.get_column_alias(model_cls.Meta.pkname) + from_key = join_params.prev_model.get_column_alias(part) + return to_key, from_key diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index ada3437..880b3c1 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -40,7 +40,8 @@ class Query: @property def prefixed_pk_name(self) -> str: - return f"{self.table.name}.{self.model_cls.Meta.pkname}" + pkname_alias = self.model_cls.get_column_alias(self.model_cls.Meta.pkname) + return f"{self.table.name}.{pkname_alias}" def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: self_related_fields = self.model_cls.own_table_columns( diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index adece00..676374e 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -70,12 +70,29 @@ class QuerySet: return self.model.merge_instances_list(result_rows) # type: ignore return result_rows + def _prepare_model_to_save(self, new_kwargs: dict) -> dict: + new_kwargs = self._remove_pk_from_kwargs(new_kwargs) + new_kwargs = self.model.substitute_models_with_pks(new_kwargs) + new_kwargs = self._populate_default_values(new_kwargs) + new_kwargs = self._translate_columns_to_aliases(new_kwargs) + return new_kwargs + def _populate_default_values(self, new_kwargs: dict) -> dict: for field_name, field in self.model_meta.model_fields.items(): if field_name not in new_kwargs and field.has_default(): new_kwargs[field_name] = field.get_default() return new_kwargs + def _translate_columns_to_aliases(self, new_kwargs: dict) -> dict: + for field_name, field in self.model_meta.model_fields.items(): + if ( + field_name in new_kwargs + and field.name is not None + and field.name != field_name + ): + new_kwargs[field.name] = new_kwargs.pop(field_name) + return new_kwargs + def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict: pkname = self.model_meta.pkname pk = self.model_meta.model_fields[pkname] @@ -278,9 +295,7 @@ class QuerySet: async def create(self, **kwargs: Any) -> "Model": new_kwargs = dict(**kwargs) - new_kwargs = self._remove_pk_from_kwargs(new_kwargs) - new_kwargs = self.model.substitute_models_with_pks(new_kwargs) - new_kwargs = self._populate_default_values(new_kwargs) + new_kwargs = self._prepare_model_to_save(new_kwargs) expr = self.table.insert() expr = expr.values(**new_kwargs) @@ -288,7 +303,7 @@ class QuerySet: instance = self.model(**kwargs) pk = await self.database.execute(expr) - pk_name = self.model_meta.pkname + pk_name = self.model.get_column_alias(self.model_meta.pkname) if pk_name not in kwargs and pk_name in new_kwargs: instance.pk = new_kwargs[self.model_meta.pkname] if pk and isinstance(pk, self.model.pk_type()): @@ -300,9 +315,7 @@ class QuerySet: ready_objects = [] for objt in objects: new_kwargs = objt.dict() - new_kwargs = self._remove_pk_from_kwargs(new_kwargs) - new_kwargs = self.model.substitute_models_with_pks(new_kwargs) - new_kwargs = self._populate_default_values(new_kwargs) + new_kwargs = self._prepare_model_to_save(new_kwargs) ready_objects.append(new_kwargs) expr = self.table.insert() diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index 3863679..88130d5 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -39,7 +39,7 @@ class RelationProxy(list): def _set_queryset(self) -> "QuerySet": owner_table = self.relation._owner.Meta.tablename - pkname = self.relation._owner.Meta.pkname + pkname = self.relation._owner.get_column_alias(self.relation._owner.Meta.pkname) pk_value = self.relation._owner.pk if not pk_value: raise RelationshipInstanceError( diff --git a/tests/test_aliases.py b/tests/test_aliases.py new file mode 100644 index 0000000..f8359d9 --- /dev/null +++ b/tests/test_aliases.py @@ -0,0 +1,95 @@ +import databases +import pytest +import sqlalchemy + +import ormar +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Child(ormar.Model): + class Meta: + tablename = "children" + metadata = metadata + database = database + + id: ormar.Integer(name='child_id', primary_key=True) + first_name: ormar.String(name='fname', max_length=100) + last_name: ormar.String(name='lname', max_length=100) + born_year: ormar.Integer(name='year_born') + + +class ArtistChildren(ormar.Model): + class Meta: + tablename = "children_x_artists" + metadata = metadata + database = database + + +class Artist(ormar.Model): + class Meta: + tablename = "artists" + metadata = metadata + database = database + + id: ormar.Integer(name='artist_id', primary_key=True) + first_name: ormar.String(name='fname', max_length=100) + last_name: ormar.String(name='lname', max_length=100) + born_year: ormar.Integer(name='year') + children: ormar.ManyToMany(Child, through=ArtistChildren) + + +class Album(ormar.Model): + class Meta: + tablename = "music_albums" + metadata = metadata + database = database + + id: ormar.Integer(name='album_id', primary_key=True) + name: ormar.String(name='album_name', max_length=100) + artist: ormar.ForeignKey(Artist, name='artist_id') + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_table_structure(): + assert 'album_id' in [x.name for x in Album.Meta.table.columns] + assert 'album_name' in [x.name for x in Album.Meta.table.columns] + assert 'fname' in [x.name for x in Artist.Meta.table.columns] + assert 'lname' in [x.name for x in Artist.Meta.table.columns] + assert 'year' in [x.name for x in Artist.Meta.table.columns] + + +@pytest.mark.asyncio +async def test_working_with_aliases(): + async with database: + async with database.transaction(force_rollback=True): + artist = await Artist.objects.create(first_name='Ted', last_name='Mosbey', born_year=1975) + await Album.objects.create(name="Aunt Robin", artist=artist) + + await artist.children.create(first_name='Son', last_name='1', born_year=1990) + await artist.children.create(first_name='Son', last_name='2', born_year=1995) + + album = await Album.objects.select_related('artist').first() + assert album.artist.last_name == 'Mosbey' + + assert album.artist.id is not None + assert album.artist.first_name == 'Ted' + assert album.artist.born_year == 1975 + + assert album.name == 'Aunt Robin' + + artist = await Artist.objects.select_related('children').get() + assert len(artist.children) == 2 + assert artist.children[0].first_name == 'Son' + assert artist.children[1].last_name == '2' +