add fixes for fastapi model clones, add functionality to add and remove models to relation, add relation proxy, fix all tests, adding values also to pydantic model __dict__some refactors
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
from ormar.models.newbasemodel import NewBaseModel
|
||||
from ormar.models.model import Model
|
||||
from ormar.models.metaclass import expand_reverse_relationships
|
||||
|
||||
__all__ = ["NewBaseModel", "Model"]
|
||||
__all__ = ["NewBaseModel", "Model", "expand_reverse_relationships"]
|
||||
|
||||
@ -29,17 +29,8 @@ class ModelMeta:
|
||||
alias_manager: AliasManager
|
||||
|
||||
|
||||
def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None:
|
||||
child_relation_name = (
|
||||
field.to.get_name(title=True)
|
||||
+ "_"
|
||||
+ (field.related_name or (name.lower() + "s"))
|
||||
)
|
||||
reverse_name = child_relation_name
|
||||
relation_name = name.lower().title() + "_" + field.to.get_name()
|
||||
relationship_manager.add_relation_type(
|
||||
relation_name, reverse_name, field, table_name
|
||||
)
|
||||
def register_relation_on_build(table_name: str, field: ForeignKey) -> None:
|
||||
relationship_manager.add_relation_type(field, table_name)
|
||||
|
||||
|
||||
def expand_reverse_relationships(model: Type["Model"]) -> None:
|
||||
@ -64,15 +55,10 @@ def register_reverse_model_fields(
|
||||
|
||||
|
||||
def sqlalchemy_columns_from_model_fields(
|
||||
name: str, object_dict: Dict, table_name: str
|
||||
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
|
||||
model_fields: Dict, table_name: str
|
||||
) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
|
||||
columns = []
|
||||
pkname = None
|
||||
model_fields = {
|
||||
field_name: field
|
||||
for field_name, field in object_dict["__annotations__"].items()
|
||||
if issubclass(field, BaseField)
|
||||
}
|
||||
for field_name, field in model_fields.items():
|
||||
if field.primary_key:
|
||||
if pkname is not None:
|
||||
@ -83,9 +69,9 @@ def sqlalchemy_columns_from_model_fields(
|
||||
if not field.pydantic_only:
|
||||
columns.append(field.get_column(field_name))
|
||||
if issubclass(field, ForeignKeyField):
|
||||
register_relation_on_build(table_name, field, name)
|
||||
register_relation_on_build(table_name, field)
|
||||
|
||||
return pkname, columns, model_fields
|
||||
return pkname, columns
|
||||
|
||||
|
||||
def populate_pydantic_default_values(attrs: Dict) -> Dict:
|
||||
@ -125,21 +111,29 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
attrs["__annotations__"] = annotations
|
||||
attrs = populate_pydantic_default_values(attrs)
|
||||
|
||||
attrs["__module__"] = attrs["__module__"] or bases[0].__module__
|
||||
attrs["__annotations__"] = (
|
||||
attrs["__annotations__"] or bases[0].__annotations__
|
||||
)
|
||||
|
||||
tablename = name.lower() + "s"
|
||||
new_model.Meta.tablename = new_model.Meta.tablename or tablename
|
||||
|
||||
# sqlalchemy table creation
|
||||
|
||||
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
|
||||
name, attrs, new_model.Meta.tablename
|
||||
)
|
||||
model_fields = {
|
||||
field_name: field
|
||||
for field_name, field in attrs["__annotations__"].items()
|
||||
if issubclass(field, BaseField)
|
||||
}
|
||||
|
||||
if hasattr(new_model.Meta, "model_fields") and not pkname:
|
||||
model_fields = new_model.Meta.model_fields
|
||||
for fieldname, field in new_model.Meta.model_fields.items():
|
||||
if field.primary_key:
|
||||
pkname = fieldname
|
||||
if hasattr(new_model.Meta, "columns"):
|
||||
columns = new_model.Meta.table.columns
|
||||
pkname = new_model.Meta.pkname
|
||||
else:
|
||||
pkname, columns = sqlalchemy_columns_from_model_fields(
|
||||
model_fields, new_model.Meta.tablename
|
||||
)
|
||||
|
||||
if not hasattr(new_model.Meta, "table"):
|
||||
new_model.Meta.table = sqlalchemy.Table(
|
||||
@ -153,10 +147,11 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
|
||||
raise ModelDefinitionError("Table has to have a primary key.")
|
||||
|
||||
new_model.Meta.model_fields = model_fields
|
||||
expand_reverse_relationships(new_model)
|
||||
|
||||
new_model = super().__new__( # type: ignore
|
||||
mcs, name, bases, attrs
|
||||
)
|
||||
expand_reverse_relationships(new_model)
|
||||
|
||||
new_model.Meta.alias_manager = relationship_manager
|
||||
new_model.objects = QuerySet(new_model)
|
||||
|
||||
@ -69,7 +69,8 @@ class Model(NewBaseModel):
|
||||
|
||||
async def save(self) -> "Model":
|
||||
self_fields = self._extract_model_db_fields()
|
||||
if self.Meta.model_fields.get(self.Meta.pkname).autoincrement:
|
||||
|
||||
if not self.pk and self.Meta.model_fields.get(self.Meta.pkname).autoincrement:
|
||||
self_fields.pop(self.Meta.pkname, None)
|
||||
expr = self.Meta.table.insert()
|
||||
expr = expr.values(**self_fields)
|
||||
@ -77,7 +78,7 @@ class Model(NewBaseModel):
|
||||
setattr(self, self.Meta.pkname, item_id)
|
||||
return self
|
||||
|
||||
async def update(self, **kwargs: Any) -> int:
|
||||
async def update(self, **kwargs: Any) -> "Model":
|
||||
if kwargs:
|
||||
new_values = {**self.dict(), **kwargs}
|
||||
self.from_dict(new_values)
|
||||
@ -89,8 +90,8 @@ class Model(NewBaseModel):
|
||||
.values(**self_fields)
|
||||
.where(self.pk_column == getattr(self, self.Meta.pkname))
|
||||
)
|
||||
result = await self.Meta.database.execute(expr)
|
||||
return result
|
||||
await self.Meta.database.execute(expr)
|
||||
return self
|
||||
|
||||
async def delete(self) -> int:
|
||||
expr = self.Meta.table.delete()
|
||||
|
||||
@ -24,7 +24,6 @@ class ModelTableProxy:
|
||||
|
||||
@classmethod
|
||||
def substitute_models_with_pks(cls, model_dict: dict) -> dict:
|
||||
model_dict = copy.deepcopy(model_dict)
|
||||
for field in cls._extract_related_names():
|
||||
if field in model_dict and model_dict.get(field) is not None:
|
||||
target_field = cls.Meta.model_fields[field]
|
||||
@ -76,10 +75,19 @@ class ModelTableProxy:
|
||||
}
|
||||
for field in self._extract_db_related_names():
|
||||
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
|
||||
if getattr(self, field) is not None:
|
||||
self_fields[field] = getattr(getattr(self, field), target_pk_name)
|
||||
target_field = getattr(self, field)
|
||||
self_fields[field] = getattr(target_field, target_pk_name, None)
|
||||
return self_fields
|
||||
|
||||
@staticmethod
|
||||
def resolve_relation_name(item: "Model", related: "Model"):
|
||||
for name, field in item.Meta.model_fields.items():
|
||||
if issubclass(field, ForeignKeyField):
|
||||
# fastapi is creating clones of response model that's why it can be a subclass
|
||||
# of the original one so we need to compare Meta too
|
||||
if field.to == related.__class__ or field.to.Meta == related.Meta:
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
|
||||
merged_rows = []
|
||||
|
||||
@ -71,9 +71,14 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
pk_only = kwargs.pop("__pk_only__", False)
|
||||
if "pk" in kwargs:
|
||||
kwargs[self.Meta.pkname] = kwargs.pop("pk")
|
||||
# build the models to set them and validate but don't register
|
||||
kwargs = {
|
||||
k: self._convert_json(
|
||||
k, self.Meta.model_fields[k].expand_relationship(v, self), "dumps"
|
||||
k,
|
||||
self.Meta.model_fields[k].expand_relationship(
|
||||
v, self, to_register=False
|
||||
),
|
||||
"dumps",
|
||||
)
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
@ -85,13 +90,20 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
object.__setattr__(self, "__dict__", values)
|
||||
object.__setattr__(self, "__fields_set__", fields_set)
|
||||
|
||||
# register the related models after initialization
|
||||
for related in self._extract_related_names():
|
||||
self.Meta.model_fields[related].expand_relationship(
|
||||
kwargs.get(related), self, to_register=True
|
||||
)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name in self.__slots__:
|
||||
object.__setattr__(self, name, value)
|
||||
elif name == "pk":
|
||||
object.__setattr__(self, self.Meta.pkname, value)
|
||||
elif name in self._orm:
|
||||
self.Meta.model_fields[name].expand_relationship(value, self)
|
||||
model = self.Meta.model_fields[name].expand_relationship(value, self)
|
||||
self.__dict__[name] = model
|
||||
else:
|
||||
value = (
|
||||
self._convert_json(name, value, "dumps")
|
||||
@ -113,19 +125,13 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
return value
|
||||
return super().__getattribute__(item)
|
||||
|
||||
# def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]:
|
||||
# return self._extract_related_model_instead_of_field(item)
|
||||
|
||||
def _extract_related_model_instead_of_field(
|
||||
self, item: str
|
||||
) -> Optional[Union["Model", List["Model"]]]:
|
||||
# relation_key = self.get_name(title=True) + "_" + item
|
||||
if item in self._orm:
|
||||
return self._orm.get(item)
|
||||
|
||||
def __same__(self, other: "Model") -> bool:
|
||||
if self.__class__ != other.__class__: # pragma no cover
|
||||
return False
|
||||
return (
|
||||
self._orm_id == other._orm_id
|
||||
or self.__dict__ == other.__dict__
|
||||
@ -137,8 +143,6 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
name = cls.__name__
|
||||
if lower:
|
||||
name = name.lower()
|
||||
if title:
|
||||
name = name.title()
|
||||
return name
|
||||
|
||||
@property
|
||||
@ -149,6 +153,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
def pk_type(cls) -> Any:
|
||||
return cls.Meta.model_fields[cls.Meta.pkname].__type__
|
||||
|
||||
def remove(self, name: "Model"):
|
||||
self._orm.remove_parent(self, name)
|
||||
|
||||
def dict( # noqa A003
|
||||
self,
|
||||
*,
|
||||
@ -176,14 +183,23 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
|
||||
if self.Meta.model_fields[field].virtual and nested:
|
||||
continue
|
||||
if isinstance(nested_model, list):
|
||||
dict_instance[field] = [x.dict(nested=True) for x in nested_model]
|
||||
result = []
|
||||
for model in nested_model:
|
||||
try:
|
||||
result.append(model.dict(nested=True))
|
||||
except ReferenceError: # pragma no cover
|
||||
continue
|
||||
dict_instance[field] = result
|
||||
elif nested_model is not None:
|
||||
dict_instance[field] = nested_model.dict(nested=True)
|
||||
else:
|
||||
dict_instance[field] = None
|
||||
return dict_instance
|
||||
|
||||
def from_dict(self, value_dict: Dict) -> None:
|
||||
def from_dict(self, value_dict: Dict) -> "Model":
|
||||
for key, value in value_dict.items():
|
||||
setattr(self, key, value)
|
||||
return self
|
||||
|
||||
def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]:
|
||||
if not self._is_conversion_to_json_needed(column_name):
|
||||
|
||||
Reference in New Issue
Block a user