import json import uuid from typing import ( AbstractSet, Any, Dict, List, Mapping, Optional, TYPE_CHECKING, Type, TypeVar, Union, ) import databases import pydantic import sqlalchemy from pydantic import BaseModel import ormar # noqa I100 from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.modelproxy import ModelTableProxy from ormar.relations.alias_manager import AliasManager from ormar.relations.relation_manager import RelationsManager if TYPE_CHECKING: # pragma no cover from ormar.models.model import Model IntStr = Union[int, str] DictStrAny = Dict[str, Any] AbstractSetIntStr = AbstractSet[IntStr] MappingIntStrAny = Mapping[IntStr, Any] class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass): __slots__ = ("_orm_id", "_orm_saved", "_orm") if TYPE_CHECKING: # pragma no cover __model_fields__: Dict[str, TypeVar[BaseField]] __table__: sqlalchemy.Table __fields__: Dict[str, pydantic.fields.ModelField] __pydantic_model__: Type[BaseModel] __pkname__: str __tablename__: str __metadata__: sqlalchemy.MetaData __database__: databases.Database _orm_relationship_manager: AliasManager _orm: RelationsManager Meta: ModelMeta # noinspection PyMissingConstructor def __init__(self, *args: Any, **kwargs: Any) -> None: object.__setattr__(self, "_orm_id", uuid.uuid4().hex) object.__setattr__(self, "_orm_saved", False) object.__setattr__( self, "_orm", RelationsManager( related_fields=[ field for name, field in self.Meta.model_fields.items() if issubclass(field, ForeignKeyField) ], owner=self, ), ) pk_only = kwargs.pop("__pk_only__", False) if "pk" in kwargs: kwargs[self.Meta.pkname] = kwargs.pop("pk") # build the models to set them and validate but don't register kwargs = { k: self._convert_json( k, self.Meta.model_fields[k].expand_relationship( v, self, to_register=False ), "dumps", ) for k, v in kwargs.items() } values, fields_set, validation_error = pydantic.validate_model(self, kwargs) if validation_error and not pk_only: raise validation_error object.__setattr__(self, "__dict__", values) object.__setattr__(self, "__fields_set__", fields_set) # register the related models after initialization for related in self.extract_related_names(): self.Meta.model_fields[related].expand_relationship( kwargs.get(related), self, to_register=True ) def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 if name in self.__slots__: object.__setattr__(self, name, value) elif name == "pk": object.__setattr__(self, self.Meta.pkname, value) elif name in self._orm: model = self.Meta.model_fields[name].expand_relationship(value, self) if isinstance(self.__dict__.get(name), list): self.__dict__[name].append(model) else: self.__dict__[name] = model else: value = ( self._convert_json(name, value, "dumps") if name in self.__fields__ else value ) super().__setattr__(name, value) def __getattribute__(self, item: str) -> Any: if item in ("_orm_id", "_orm_saved", "_orm", "__fields__"): return object.__getattribute__(self, item) if item != "extract_related_names" and item in self.extract_related_names(): return self._extract_related_model_instead_of_field(item) if item == "pk": return self.__dict__.get(self.Meta.pkname, None) if item != "__fields__" and item in self.__fields__: value = self.__dict__.get(item, None) value = self._convert_json(item, value, "loads") return value return super().__getattribute__(item) def _extract_related_model_instead_of_field( self, item: str ) -> Optional[Union["Model", List["Model"]]]: if item in self._orm: return self._orm.get(item) def __eq__(self, other: "Model") -> bool: if isinstance(other, NewBaseModel): return self.__same__(other) return super().__eq__(other) # pragma no cover def __same__(self, other: "Model") -> bool: return ( self._orm_id == other._orm_id or self.dict() == other.dict() or (self.pk == other.pk and self.pk is not None) ) @classmethod def get_name(cls, lower: bool = True) -> str: name = cls.__name__ if lower: name = name.lower() return name @property def pk_column(self) -> sqlalchemy.Column: return self.Meta.table.primary_key.columns.values()[0] @classmethod def pk_type(cls) -> Any: return cls.Meta.model_fields[cls.Meta.pkname].__type__ def remove(self, name: "Model") -> None: self._orm.remove_parent(self, name) def dict( # noqa A003 self, *, include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, by_alias: bool = False, skip_defaults: bool = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, nested: bool = False ) -> "DictStrAny": # noqa: A003' dict_instance = super().dict( include=include, exclude=self._exclude_related_names_not_required(nested), by_alias=by_alias, skip_defaults=skip_defaults, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) for field in self.extract_related_names(): nested_model = getattr(self, field) if self.Meta.model_fields[field].virtual and nested: continue if isinstance(nested_model, list): result = [] for model in nested_model: try: result.append(model.dict(nested=True)) except ReferenceError: # pragma no cover continue dict_instance[field] = result elif nested_model is not None: dict_instance[field] = nested_model.dict(nested=True) else: dict_instance[field] = None return dict_instance def from_dict(self, value_dict: Dict) -> "Model": for key, value in value_dict.items(): setattr(self, key, value) return self def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]: if not self._is_conversion_to_json_needed(column_name): return value condition = ( isinstance(value, str) if op == "loads" else not isinstance(value, str) ) operand = json.loads if op == "loads" else json.dumps if condition: try: return operand(value) except TypeError: # pragma no cover pass return value def _is_conversion_to_json_needed(self, column_name: str) -> bool: return self.Meta.model_fields.get(column_name).__type__ == pydantic.Json