diff --git a/.coverage b/.coverage index 5fd9217..367e452 100644 Binary files a/.coverage and b/.coverage differ diff --git a/orm/models/fakepydantic.py b/orm/models/fakepydantic.py index 55fe51d..5e1d6b4 100644 --- a/orm/models/fakepydantic.py +++ b/orm/models/fakepydantic.py @@ -11,6 +11,7 @@ from typing import ( TYPE_CHECKING, Type, TypeVar, + Union, ) import databases @@ -62,15 +63,10 @@ class FakePydantic(list, metaclass=ModelMetaclass): def __setattr__(self, key: str, value: Any) -> None: if key in self.__fields__: - if self._is_conversion_to_json_needed(key) and not isinstance(value, str): - try: - value = json.dumps(value) - except TypeError: # pragma no cover - pass - + value = self._convert_json(key, value, op="dumps") value = self.__model_fields__[key].expand_relationship(value, self) - relation_key = self.__class__.__name__.title() + "_" + key + relation_key = self.get_name(title=True) + "_" + key if not self._orm_relationship_manager.contains(relation_key, self): setattr(self.values, key, value) else: @@ -78,20 +74,12 @@ class FakePydantic(list, metaclass=ModelMetaclass): def __getattribute__(self, key: str) -> Any: if key != "__fields__" and key in self.__fields__: - relation_key = self.__class__.__name__.title() + "_" + key + relation_key = self.get_name(title=True) + "_" + key if self._orm_relationship_manager.contains(relation_key, self): return self._orm_relationship_manager.get(relation_key, self) item = getattr(self.values, key, None) - if ( - item is not None - and self._is_conversion_to_json_needed(key) - and isinstance(item, str) - ): - try: - item = json.loads(item) - except TypeError: # pragma no cover - pass + item = self._convert_json(key, item, op="loads") return item return super().__getattribute__(key) @@ -145,6 +133,23 @@ class FakePydantic(list, metaclass=ModelMetaclass): for key, value in value_dict.items(): setattr(self, key, value) + def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]: + + if not self._is_conversion_to_json_needed(column_name): + return value + + condition = ( + isinstance(value, str) if op == "loads" else not isinstance(value, str) + ) + operand = json.loads if op == "loads" else json.dumps + + if condition: + try: + return operand(value) + except TypeError: # pragma no cover + pass + return value + def _is_conversion_to_json_needed(self, column_name: str) -> bool: return self.__model_fields__.get(column_name).__type__ == pydantic.Json @@ -188,7 +193,6 @@ class FakePydantic(list, metaclass=ModelMetaclass): @classmethod def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": for field in one.__model_fields__.keys(): - # print(field, one.dict(), other.dict()) if isinstance(getattr(one, field), list) and not isinstance( getattr(one, field), orm.Model ): diff --git a/orm/models/metaclass.py b/orm/models/metaclass.py index 4ddb21e..d10806b 100644 --- a/orm/models/metaclass.py +++ b/orm/models/metaclass.py @@ -66,19 +66,21 @@ def register_reverse_model_fields( def sqlalchemy_columns_from_model_fields( name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: - pkname: Optional[str] = None - columns: List[sqlalchemy.Column] = [] - model_fields: Dict[str, BaseField] = {} + columns = [] + pkname = None + model_fields = { + field_name: field + for field_name, field in object_dict.items() + if isinstance(field, BaseField) + } + for field_name, field in model_fields.items(): + if field.primary_key: + pkname = field_name + if not field.pydantic_only: + columns.append(field.get_column(field_name)) + if isinstance(field, ForeignKey): + register_relation_on_build(table_name, field, name) - for field_name, field in object_dict.items(): - if isinstance(field, BaseField): - model_fields[field_name] = field - if not field.pydantic_only: - if field.primary_key: - pkname = field_name - if isinstance(field, ForeignKey): - register_relation_on_build(table_name, field, name) - columns.append(field.get_column(field_name)) return pkname, columns, model_fields diff --git a/orm/relations.py b/orm/relations.py index 7cf5ecf..c7ef8b6 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -49,9 +49,8 @@ class RelationshipManager: ) def deregister(self, model: "FakePydantic") -> None: - # print(f'deregistering {model.__class__.__name__}, {model._orm_id}') for rel_type in self._relations.keys(): - if model.__class__.__name__.lower() in rel_type.lower(): + if model.get_name() in rel_type.lower(): if model._orm_id in self._relations[rel_type]: del self._relations[rel_type][model._orm_id]