diff --git a/docs/api/models/new-basemodel.md b/docs/api/models/new-basemodel.md index 3c5cdb1..4fb0dfd 100644 --- a/docs/api/models/new-basemodel.md +++ b/docs/api/models/new-basemodel.md @@ -341,11 +341,11 @@ Calls the pydantic method to evaluate pydantic fields. `(None)`: None - + #### \_get\_related\_not\_excluded\_fields ```python - | _get_related_not_excluded_fields(include: Optional[Dict], exclude: Optional[Dict]) -> List + | _get_not_excluded_fields(include: Optional[Dict], exclude: Optional[Dict]) -> List ``` Returns related field names applying on them include and exclude set. diff --git a/ormar/models/helpers/validation.py b/ormar/models/helpers/validation.py index 1ce2a9f..617381f 100644 --- a/ormar/models/helpers/validation.py +++ b/ormar/models/helpers/validation.py @@ -131,7 +131,7 @@ def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> D :type model: Type["Model"] :param relation_map: dict with relations to follow :type relation_map: Optional[Dict] - :return: + :return: dict with example values :rtype: Dict[str, int] """ example: Dict[str, Any] = dict() @@ -141,13 +141,9 @@ def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> D else translate_list_to_dict(model._iterate_related_models()) ) for name, field in model.Meta.model_fields.items(): - if not field.is_relation: - is_bytes_str = field.__type__ == bytes and field.represent_as_base64_str - example[name] = field.__sample__ if not is_bytes_str else "string" - elif isinstance(relation_map, dict) and name in relation_map: - example[name] = get_nested_model_example( - name=name, field=field, relation_map=relation_map - ) + populates_sample_fields_values( + example=example, name=name, field=field, relation_map=relation_map + ) to_exclude = {name for name in model.Meta.model_fields} pydantic_repr = generate_pydantic_example(pydantic_model=model, exclude=to_exclude) example.update(pydantic_repr) @@ -155,6 +151,30 @@ def generate_model_example(model: Type["Model"], relation_map: Dict = None) -> D return example +def populates_sample_fields_values( + example: Dict[str, Any], name: str, field: BaseField, relation_map: Dict = None +) -> None: + """ + Iterates the field and sets fields to sample values + + :param field: ormar field + :type field: BaseField + :param name: name of the field + :type name: str + :param example: example dict + :type example: Dict[str, Any] + :param relation_map: dict with relations to follow + :type relation_map: Optional[Dict] + """ + if not field.is_relation: + is_bytes_str = field.__type__ == bytes and field.represent_as_base64_str + example[name] = field.__sample__ if not is_bytes_str else "string" + elif isinstance(relation_map, dict) and name in relation_map: + example[name] = get_nested_model_example( + name=name, field=field, relation_map=relation_map + ) + + def get_nested_model_example( name: str, field: "BaseField", relation_map: Dict ) -> Union[List, Dict]: diff --git a/ormar/models/mixins/__init__.py b/ormar/models/mixins/__init__.py index 2a64e6b..e605ee8 100644 --- a/ormar/models/mixins/__init__.py +++ b/ormar/models/mixins/__init__.py @@ -8,6 +8,7 @@ from ormar.models.mixins.alias_mixin import AliasMixin from ormar.models.mixins.excludable_mixin import ExcludableMixin from ormar.models.mixins.merge_mixin import MergeModelMixin from ormar.models.mixins.prefetch_mixin import PrefetchQueryMixin +from ormar.models.mixins.pydantic_mixin import PydanticMixin from ormar.models.mixins.save_mixin import SavePrepareMixin __all__ = [ @@ -16,4 +17,5 @@ __all__ = [ "PrefetchQueryMixin", "SavePrepareMixin", "ExcludableMixin", + "PydanticMixin", ] diff --git a/ormar/models/mixins/pydantic_mixin.py b/ormar/models/mixins/pydantic_mixin.py new file mode 100644 index 0000000..76af068 --- /dev/null +++ b/ormar/models/mixins/pydantic_mixin.py @@ -0,0 +1,95 @@ +from typing import Any, Callable, Dict, List, Set, TYPE_CHECKING, Type, Union, cast + +import pydantic +from pydantic.fields import ModelField + +from ormar.models.mixins.relation_mixin import RelationMixin # noqa: I100, I202 +from ormar.queryset.utils import translate_list_to_dict + + +class PydanticMixin(RelationMixin): + if TYPE_CHECKING: # pragma: no cover + __fields__: Dict[str, ModelField] + _skip_ellipsis: Callable + _get_not_excluded_fields: Callable + + @classmethod + def get_pydantic( + cls, *, include: Union[Set, Dict] = None, exclude: Union[Set, Dict] = None, + ) -> Type[pydantic.BaseModel]: + """ + Returns a pydantic model out of ormar model. + + Converts also nested ormar models into pydantic models. + + Can be used to fully exclude certain fields in fastapi response and requests. + + :param include: fields of own and nested models to include + :type include: Union[Set, Dict, None] + :param exclude: fields of own and nested models to exclude + :type exclude: Union[Set, Dict, None] + """ + relation_map = translate_list_to_dict(cls._iterate_related_models()) + + return cls._convert_ormar_to_pydantic( + include=include, exclude=exclude, relation_map=relation_map + ) + + @classmethod + def _convert_ormar_to_pydantic( + cls, + relation_map: Dict[str, Any], + include: Union[Set, Dict] = None, + exclude: Union[Set, Dict] = None, + ) -> Type[pydantic.BaseModel]: + if include and isinstance(include, Set): + include = translate_list_to_dict(include) + if exclude and isinstance(exclude, Set): + exclude = translate_list_to_dict(exclude) + fields_dict: Dict[str, Any] = dict() + defaults: Dict[str, Any] = dict() + fields_to_process = cls._get_not_excluded_fields( + fields={*cls.Meta.model_fields.keys()}, include=include, exclude=exclude + ) + for name in fields_to_process: + field = cls._determine_pydantic_field_type( + name=name, + defaults=defaults, + include=include, + exclude=exclude, + relation_map=relation_map, + ) + if field is not None: + fields_dict[name] = field + model = type( + cls.__name__, + (pydantic.BaseModel,), + {"__annotations__": fields_dict, **defaults}, + ) + return cast(Type[pydantic.BaseModel], model) + + @classmethod + def _determine_pydantic_field_type( + cls, + name: str, + defaults: Dict, + include: Union[Set, Dict, None], + exclude: Union[Set, Dict, None], + relation_map: Dict[str, Any], + ) -> Any: + field = cls.Meta.model_fields[name] + if field.is_relation and name in relation_map: # type: ignore + target = field.to._convert_ormar_to_pydantic( + include=cls._skip_ellipsis(include, name), + exclude=cls._skip_ellipsis(exclude, name), + relation_map=cls._skip_ellipsis( + relation_map, field, default_return=dict() + ), + ) + if field.is_multi or field.virtual: + return List[target] # type: ignore + return target + elif not field.is_relation: + defaults[name] = cls.__fields__[name].field_info + return field.__type__ + return None diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index fd1b8e6..bcbd685 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -2,12 +2,17 @@ from ormar.models.mixins import ( ExcludableMixin, MergeModelMixin, PrefetchQueryMixin, + PydanticMixin, SavePrepareMixin, ) class ModelTableProxy( - PrefetchQueryMixin, MergeModelMixin, SavePrepareMixin, ExcludableMixin + PrefetchQueryMixin, + MergeModelMixin, + SavePrepareMixin, + ExcludableMixin, + PydanticMixin, ): """ Used to combine all mixins with different set of functionalities. diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index c2a9476..72033c6 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -454,8 +454,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass super().update_forward_refs(**localns) cls.Meta.requires_ref_update = False - def _get_related_not_excluded_fields( - self, include: Optional[Dict], exclude: Optional[Dict], + @staticmethod + def _get_not_excluded_fields( + fields: Union[List, Set], include: Optional[Dict], exclude: Optional[Dict], ) -> List: """ Returns related field names applying on them include and exclude set. @@ -467,7 +468,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass :return: :rtype: List of fields with relations that is not excluded """ - fields = [field for field in self.extract_related_names()] + fields = [*fields] if not isinstance(fields, list) else fields if include: fields = [field for field in fields if field in include] if exclude: @@ -519,8 +520,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass continue return result + @classmethod def _skip_ellipsis( - self, items: Union[Set, Dict, None], key: str, default_return: Any = None + cls, items: Union[Set, Dict, None], key: str, default_return: Any = None ) -> Union[Set, Dict, None]: """ Helper to traverse the include/exclude dictionaries. @@ -534,10 +536,11 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass :return: nested value of the items :rtype: Union[Set, Dict, None] """ - result = self.get_child(items, key) + result = cls.get_child(items, key) return result if result is not Ellipsis else default_return - def _convert_all(self, items: Union[Set, Dict, None]) -> Union[Set, Dict, None]: + @staticmethod + def _convert_all(items: Union[Set, Dict, None]) -> Union[Set, Dict, None]: """ Helper to convert __all__ pydantic special index to ormar which does not support index based exclusions. @@ -573,8 +576,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass :return: current model dict with child models converted to dictionaries :rtype: Dict """ - - fields = self._get_related_not_excluded_fields(include=include, exclude=exclude) + fields = self._get_not_excluded_fields( + fields=self.extract_related_names(), include=include, exclude=exclude + ) for field in fields: if not relation_map or field not in relation_map: diff --git a/ormar/models/quick_access_views.py b/ormar/models/quick_access_views.py index 89cbf08..a822375 100644 --- a/ormar/models/quick_access_views.py +++ b/ormar/models/quick_access_views.py @@ -26,7 +26,7 @@ quick_access_set = { "_extract_nested_models_from_list", "_extract_own_model_fields", "_extract_related_model_instead_of_field", - "_get_related_not_excluded_fields", + "_get_not_excluded_fields", "_get_value", "_init_private_attributes", "_is_conversion_to_json_needed", diff --git a/tests/test_fastapi/test_inheritance_concrete_fastapi.py b/tests/test_fastapi/test_inheritance_concrete_fastapi.py index e547ad2..721fdf3 100644 --- a/tests/test_fastapi/test_inheritance_concrete_fastapi.py +++ b/tests/test_fastapi/test_inheritance_concrete_fastapi.py @@ -7,7 +7,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient from tests.settings import DATABASE_URL -from tests.test_inheritance.test_inheritance_concrete import ( # type: ignore +from tests.test_inheritance_and_pydantic_generation.test_inheritance_concrete import ( # type: ignore Category, Subject, Person, diff --git a/tests/test_fastapi/test_inheritance_mixins_fastapi.py b/tests/test_fastapi/test_inheritance_mixins_fastapi.py index 681f5ef..1f74de6 100644 --- a/tests/test_fastapi/test_inheritance_mixins_fastapi.py +++ b/tests/test_fastapi/test_inheritance_mixins_fastapi.py @@ -6,7 +6,7 @@ from fastapi import FastAPI from starlette.testclient import TestClient from tests.settings import DATABASE_URL -from tests.test_inheritance.test_inheritance_mixins import Category, Subject, metadata, db as database # type: ignore +from tests.test_inheritance_and_pydantic_generation.test_inheritance_mixins import Category, Subject, metadata, db as database # type: ignore app = FastAPI() app.state.database = database diff --git a/tests/test_inheritance/__init__.py b/tests/test_inheritance_and_pydantic_generation/__init__.py similarity index 100% rename from tests/test_inheritance/__init__.py rename to tests/test_inheritance_and_pydantic_generation/__init__.py diff --git a/tests/test_inheritance_and_pydantic_generation/test_geting_the_pydantic_models.py b/tests/test_inheritance_and_pydantic_generation/test_geting_the_pydantic_models.py new file mode 100644 index 0000000..46e677f --- /dev/null +++ b/tests/test_inheritance_and_pydantic_generation/test_geting_the_pydantic_models.py @@ -0,0 +1,118 @@ +from typing import List, Optional + +import databases +import pydantic +import sqlalchemy +from pydantic import ConstrainedStr + +import ormar +from tests.settings import DATABASE_URL + +metadata = sqlalchemy.MetaData() +database = databases.Database(DATABASE_URL, force_rollback=True) + + +class Category(ormar.Model): + class Meta: + tablename = "categories" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Item(ormar.Model): + class Meta: + tablename = "items" + metadata = metadata + database = database + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100, default="test") + category: Optional[Category] = ormar.ForeignKey(Category, nullable=True) + + +def test_getting_pydantic_model(): + PydanticCategory = Category.get_pydantic() + assert issubclass(PydanticCategory, pydantic.BaseModel) + assert {*PydanticCategory.__fields__.keys()} == {"items", "id", "name"} + + assert not PydanticCategory.__fields__["id"].required + assert PydanticCategory.__fields__["id"].outer_type_ == int + assert PydanticCategory.__fields__["id"].default is None + + assert PydanticCategory.__fields__["name"].required + assert issubclass(PydanticCategory.__fields__["name"].outer_type_, ConstrainedStr) + assert PydanticCategory.__fields__["name"].default is None + + PydanticItem = PydanticCategory.__fields__["items"].type_ + assert PydanticCategory.__fields__["items"].outer_type_ == List[PydanticItem] + assert issubclass(PydanticItem, pydantic.BaseModel) + assert not PydanticItem.__fields__["name"].required + assert PydanticItem.__fields__["name"].default == "test" + assert issubclass(PydanticItem.__fields__["name"].outer_type_, ConstrainedStr) + assert "category" not in PydanticItem.__fields__ + + +def test_getting_pydantic_model_include(): + PydanticCategory = Category.get_pydantic(include={"id", "name"}) + assert len(PydanticCategory.__fields__) == 2 + assert "items" not in PydanticCategory.__fields__ + + +def test_getting_pydantic_model_nested_include_set(): + PydanticCategory = Category.get_pydantic(include={"id", "items__id"}) + assert len(PydanticCategory.__fields__) == 2 + assert "name" not in PydanticCategory.__fields__ + PydanticItem = PydanticCategory.__fields__["items"].type_ + assert len(PydanticItem.__fields__) == 1 + assert "id" in PydanticItem.__fields__ + + +def test_getting_pydantic_model_nested_include_dict(): + PydanticCategory = Category.get_pydantic(include={"id": ..., "items": {"id"}}) + assert len(PydanticCategory.__fields__) == 2 + assert "name" not in PydanticCategory.__fields__ + PydanticItem = PydanticCategory.__fields__["items"].type_ + assert len(PydanticItem.__fields__) == 1 + assert "id" in PydanticItem.__fields__ + + +def test_getting_pydantic_model_nested_include_nested_dict(): + PydanticCategory = Category.get_pydantic(include={"id": ..., "items": {"id": ...}}) + assert len(PydanticCategory.__fields__) == 2 + assert "name" not in PydanticCategory.__fields__ + PydanticItem = PydanticCategory.__fields__["items"].type_ + assert len(PydanticItem.__fields__) == 1 + assert "id" in PydanticItem.__fields__ + + +def test_getting_pydantic_model_include_exclude(): + PydanticCategory = Category.get_pydantic( + include={"id": ..., "items": {"id", "name"}}, exclude={"items__name"} + ) + assert len(PydanticCategory.__fields__) == 2 + assert "name" not in PydanticCategory.__fields__ + PydanticItem = PydanticCategory.__fields__["items"].type_ + assert len(PydanticItem.__fields__) == 1 + assert "id" in PydanticItem.__fields__ + + +def test_getting_pydantic_model_exclude(): + PydanticItem = Item.get_pydantic(exclude={"category__name"}) + assert len(PydanticItem.__fields__) == 3 + assert "category" in PydanticItem.__fields__ + PydanticCategory = PydanticItem.__fields__["category"].type_ + assert len(PydanticCategory.__fields__) == 1 + assert "name" not in PydanticCategory.__fields__ + + +def test_getting_pydantic_model_exclude_dict(): + PydanticItem = Item.get_pydantic(exclude={"id": ..., "category": {"name"}}) + assert len(PydanticItem.__fields__) == 2 + assert "category" in PydanticItem.__fields__ + assert "id" not in PydanticItem.__fields__ + PydanticCategory = PydanticItem.__fields__["category"].type_ + assert len(PydanticCategory.__fields__) == 1 + assert "name" not in PydanticCategory.__fields__ diff --git a/tests/test_inheritance/test_inheritance_concrete.py b/tests/test_inheritance_and_pydantic_generation/test_inheritance_concrete.py similarity index 100% rename from tests/test_inheritance/test_inheritance_concrete.py rename to tests/test_inheritance_and_pydantic_generation/test_inheritance_concrete.py diff --git a/tests/test_inheritance/test_inheritance_mixins.py b/tests/test_inheritance_and_pydantic_generation/test_inheritance_mixins.py similarity index 100% rename from tests/test_inheritance/test_inheritance_mixins.py rename to tests/test_inheritance_and_pydantic_generation/test_inheritance_mixins.py