diff --git a/.coverage b/.coverage index 6db7a66..ecaa279 100644 Binary files a/.coverage and b/.coverage differ diff --git a/orm/fields.py b/orm/fields.py index 9c9f3f7..eece9ea 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -1,6 +1,6 @@ import datetime import decimal -from typing import Any, List, Optional, TYPE_CHECKING, Type, Union +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, Union import orm from orm.exceptions import ModelDefinitionError, RelationshipInstanceError @@ -29,6 +29,9 @@ class BaseField: name = args.pop(0) self.name = name + self._populate_from_kwargs(kwargs) + + def _populate_from_kwargs(self, kwargs: Dict) -> None: self.primary_key = kwargs.pop("primary_key", False) self.autoincrement = kwargs.pop( "autoincrement", self.primary_key and self.__type__ == int @@ -79,7 +82,7 @@ class BaseField: index=self.index, unique=self.unique, default=self.default, - server_default=self.server_default + server_default=self.server_default, ) def get_column_type(self) -> sqlalchemy.types.TypeEngine: @@ -228,13 +231,13 @@ class ForeignKey(BaseField): def expand_relationship( self, value: Any, child: "Model" ) -> Union["Model", List["Model"]]: - if not isinstance(value, (self.to, dict, int, str, list)) or ( - isinstance(value, orm.models.Model) and not isinstance(value, self.to) - ): + + if isinstance(value, orm.models.Model) and not isinstance(value, self.to): raise RelationshipInstanceError( - "Relationship model can be build only from orm.Model, " - "dict and integer or string (pk)." + f"Relationship error - expecting: {self.to.__name__}, " + f"but {value.__class__.__name__} encountered." ) + if isinstance(value, list) and not isinstance(value, self.to): model = [self.expand_relationship(val, child) for val in value] return model @@ -244,9 +247,19 @@ class ForeignKey(BaseField): elif isinstance(value, dict): model = self.to(**value) else: + if not isinstance(value, self.to.pk_type()): + raise RelationshipInstanceError( + f"Relationship error - ForeignKey {self.to.__name__} is of type {self.to.pk_type()} " + f"of type {self.__type__} while {type(value)} passed as a parameter." + ) model = create_dummy_instance(fk=self.to, pk=value) - child_model_name = self.related_name or child.__class__.__name__.lower() + "s" + self.add_to_relationship_registry(model, child) + + return model + + def add_to_relationship_registry(self, model: "Model", child: "Model") -> None: + child_model_name = self.related_name or child.get_name() + "s" model._orm_relationship_manager.add_relation( model.__class__.__name__.lower(), child.__class__.__name__.lower(), @@ -257,16 +270,20 @@ class ForeignKey(BaseField): if ( child_model_name not in model.__fields__ - and child.__class__.__name__.lower() not in model.__fields__ + and child.get_name() not in model.__fields__ ): - model.__fields__[child_model_name] = ModelField( - name=child_model_name, - type_=Optional[child.__pydantic_model__], - model_config=child.__pydantic_model__.__config__, - class_validators=child.__pydantic_model__.__validators__, - ) - model.__model_fields__[child_model_name] = ForeignKey( - child.__class__, name=child_model_name, virtual=True - ) + self.register_reverse_model_fields(model, child, child_model_name) - return model + @staticmethod + def register_reverse_model_fields( + model: "Model", child: "Model", child_model_name: str + ) -> None: + model.__fields__[child_model_name] = ModelField( + name=child_model_name, + type_=Optional[child.__pydantic_model__], + model_config=child.__pydantic_model__.__config__, + class_validators=child.__pydantic_model__.__validators__, + ) + model.__model_fields__[child_model_name] = ForeignKey( + child.__class__, name=child_model_name, virtual=True + ) diff --git a/orm/models.py b/orm/models.py index c16b21d..6dc7fa3 100644 --- a/orm/models.py +++ b/orm/models.py @@ -32,8 +32,17 @@ def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple return pydantic_fields +def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: + child_relation_name = field.to.get_name(title=True) + "_" + name.lower() + "s" + reverse_name = field.related_name or child_relation_name + relation_name = name.lower().title() + "_" + field.to.get_name() + relationship_manager.add_relation_type( + relation_name, reverse_name, field, table_name + ) + + def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, tablename: str + name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: pkname: Optional[str] = None columns: List[sqlalchemy.Column] = [] @@ -46,14 +55,7 @@ def sqlalchemy_columns_from_model_fields( if field.primary_key: pkname = field_name if isinstance(field, ForeignKey): - child_relation_name = ( - field.to.get_name(title=True) + "_" + name.lower() + "s" - ) - reverse_name = field.related_name or child_relation_name - relation_name = name.lower().title() + "_" + field.to.get_name() - relationship_manager.add_relation_type( - relation_name, reverse_name, field, tablename - ) + register_relation_on_build(table_name, field, name) columns.append(field.get_column(field_name)) return pkname, columns, model_fields @@ -109,8 +111,8 @@ class ModelMetaclass(type): return new_model -class Model(list, metaclass=ModelMetaclass): - # Model inherits from list in order to be treated as +class FakePydantic(list, metaclass=ModelMetaclass): + # FakePydantic inherits from list in order to be treated as # request.Body parameter in fastapi routes, # inheriting from pydantic.BaseModel causes metaclass conflicts __abstract__ = True @@ -125,9 +127,8 @@ class Model(list, metaclass=ModelMetaclass): __database__: databases.Database _orm_relationship_manager: RelationshipManager - objects = qry.QuerySet() - def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() self._orm_id: str = uuid.uuid4().hex self._orm_saved: bool = False self.values: Optional[BaseModel] = None @@ -145,7 +146,7 @@ class Model(list, metaclass=ModelMetaclass): def __setattr__(self, key: str, value: Any) -> None: if key in self.__fields__: - if self.is_conversion_to_json_needed(key) and not isinstance(value, str): + if self._is_conversion_to_json_needed(key) and not isinstance(value, str): try: value = json.dumps(value) except TypeError: # pragma no cover @@ -168,7 +169,7 @@ class Model(list, metaclass=ModelMetaclass): item = getattr(self.values, key, None) if ( item is not None - and self.is_conversion_to_json_needed(key) + and self._is_conversion_to_json_needed(key) and isinstance(item, str) ): try: @@ -191,6 +192,79 @@ class Model(list, metaclass=ModelMetaclass): def __repr__(self) -> str: # pragma no cover return self.values.__repr__() + @classmethod + def __get_validators__(cls) -> Callable: # pragma no cover + yield cls.__pydantic_model__.validate + + @classmethod + def get_name(cls, title: bool = False, lower: bool = True) -> str: + name = cls.__name__ + if lower: + name = name.lower() + if title: + name = name.title() + return name + + @property + def pk_column(self) -> sqlalchemy.Column: + return self.__table__.primary_key.columns.values()[0] + + @classmethod + def pk_type(cls): + return cls.__model_fields__[cls.__pkname__].__type__ + + def dict(self) -> Dict: # noqa: A003 + dict_instance = self.values.dict() + for field in self._extract_related_names(): + nested_model = getattr(self, field) + if isinstance(nested_model, list): + dict_instance[field] = [x.dict() for x in nested_model] + else: + dict_instance[field] = ( + nested_model.dict() if nested_model is not None else {} + ) + return dict_instance + + def from_dict(self, value_dict: Dict) -> None: + for key, value in value_dict.items(): + setattr(self, key, value) + + def _is_conversion_to_json_needed(self, column_name: str) -> bool: + return self.__model_fields__.get(column_name).__type__ == pydantic.Json + + def _extract_own_model_fields(self) -> Dict: + related_names = self._extract_related_names() + self_fields = {k: v for k, v in self.dict().items() if k not in related_names} + return self_fields + + @classmethod + def _extract_related_names(cls) -> Set: + related_names = set() + for name, field in cls.__fields__.items(): + if inspect.isclass(field.type_) and issubclass( + field.type_, pydantic.BaseModel + ): + related_names.add(name) + return related_names + + def _extract_model_db_fields(self) -> Dict: + self_fields = self._extract_own_model_fields() + self_fields = { + k: v for k, v in self_fields.items() if k in self.__table__.columns + } + for field in self._extract_related_names(): + if getattr(self, field) is not None: + self_fields[field] = getattr( + getattr(self, field), self.__model_fields__[field].to.__pkname__ + ) + return self_fields + + +class Model(FakePydantic): + __abstract__ = True + + objects = qry.QuerySet() + @classmethod def from_row( cls, @@ -227,30 +301,6 @@ class Model(list, metaclass=ModelMetaclass): return cls(**item) - # @classmethod - # def validate(cls, value: Any) -> 'BaseModel': # pragma no cover - # return cls.__pydantic_model__.validate(value=value) - - @classmethod - def __get_validators__(cls) -> Callable: # pragma no cover - yield cls.__pydantic_model__.validate - - # @classmethod - # def schema(cls, by_alias: bool = True): # pragma no cover - # return cls.__pydantic_model__.schema(by_alias=by_alias) - - @classmethod - def get_name(cls, title: bool = False, lower: bool = True) -> str: - name = cls.__name__ - if lower: - name = name.lower() - if title: - name = name.title() - return name - - def is_conversion_to_json_needed(self, column_name: str) -> bool: - return self.__model_fields__.get(column_name).__type__ == pydantic.Json - @property def pk(self) -> str: return getattr(self.values, self.__pkname__) @@ -259,55 +309,8 @@ class Model(list, metaclass=ModelMetaclass): def pk(self, value: Any) -> None: setattr(self.values, self.__pkname__, value) - @property - def pk_column(self) -> sqlalchemy.Column: - return self.__table__.primary_key.columns.values()[0] - - def dict(self) -> Dict: # noqa: A003 - dict_instance = self.values.dict() - for field in self.extract_related_names(): - nested_model = getattr(self, field) - if isinstance(nested_model, list): - dict_instance[field] = [x.dict() for x in nested_model] - else: - dict_instance[field] = ( - nested_model.dict() if nested_model is not None else {} - ) - return dict_instance - - def from_dict(self, value_dict: Dict) -> None: - for key, value in value_dict.items(): - setattr(self, key, value) - - def extract_own_model_fields(self) -> Dict: - related_names = self.extract_related_names() - self_fields = {k: v for k, v in self.dict().items() if k not in related_names} - return self_fields - - @classmethod - def extract_related_names(cls) -> Set: - related_names = set() - for name, field in cls.__fields__.items(): - if inspect.isclass(field.type_) and issubclass( - field.type_, pydantic.BaseModel - ): - related_names.add(name) - return related_names - - def extract_model_db_fields(self) -> Dict: - self_fields = self.extract_own_model_fields() - self_fields = { - k: v for k, v in self_fields.items() if k in self.__table__.columns - } - for field in self.extract_related_names(): - if getattr(self, field) is not None: - self_fields[field] = getattr( - getattr(self, field), self.__model_fields__[field].to.__pkname__ - ) - return self_fields - async def save(self) -> int: - self_fields = self.extract_model_db_fields() + self_fields = self._extract_model_db_fields() if self.__model_fields__.get(self.__pkname__).autoincrement: self_fields.pop(self.__pkname__, None) expr = self.__table__.insert() @@ -321,7 +324,7 @@ class Model(list, metaclass=ModelMetaclass): new_values = {**self.dict(), **kwargs} self.from_dict(new_values) - self_fields = self.extract_model_db_fields() + self_fields = self._extract_model_db_fields() self_fields.pop(self.__pkname__) expr = ( self.__table__.update() diff --git a/orm/queryset.py b/orm/queryset.py index 1b5a5dc..23e4a76 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -527,7 +527,7 @@ class QuerySet: del new_kwargs[pkname] # substitute related models with their pk - for field in self.model_cls.extract_related_names(): + for field in self.model_cls._extract_related_names(): if field in new_kwargs and new_kwargs.get(field) is not None: new_kwargs[field] = getattr( new_kwargs.get(field), diff --git a/orm/relations.py b/orm/relations.py index 3232c8c..f541dfe 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -8,7 +8,7 @@ from weakref import proxy from orm.fields import ForeignKey if TYPE_CHECKING: # pragma no cover - from orm.models import Model + from orm.models import FakePydantic, Model def get_table_alias() -> str: @@ -48,7 +48,7 @@ class RelationshipManager: "reverse", table_name, field ) - def deregister(self, model: "Model") -> None: + def deregister(self, model: "FakePydantic") -> None: # print(f'deregistering {model.__class__.__name__}, {model._orm_id}') for rel_type in self._relations.keys(): if model.__class__.__name__.lower() in rel_type.lower(): @@ -59,8 +59,8 @@ class RelationshipManager: self, parent_name: str, child_name: str, - parent: "Model", - child: "Model", + parent: "FakePydantic", + child: "FakePydantic", virtual: bool = False, ) -> None: parent_id = parent._orm_id @@ -91,13 +91,13 @@ class RelationshipManager: relations_list.append(model) - def contains(self, relations_key: str, instance: "Model") -> bool: + def contains(self, relations_key: str, instance: "FakePydantic") -> bool: if relations_key in self._relations: return instance._orm_id in self._relations[relations_key] return False def get( - self, relations_key: str, instance: "Model" + self, relations_key: str, instance: "FakePydantic" ) -> Union["Model", List["Model"]]: if relations_key in self._relations: if instance._orm_id in self._relations[relations_key]: diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index dfba2da..660f3d4 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -67,6 +67,12 @@ def create_test_database(): metadata.drop_all(engine) +@pytest.mark.asyncio +async def test_wrong_query_foreign_key_type(): + with pytest.raises(RelationshipInstanceError): + Track(title="The Error", album="wrong_pk_type") + + @pytest.mark.asyncio async def test_model_crud(): async with database: