diff --git a/ormar/models/mixins/pydantic_mixin.py b/ormar/models/mixins/pydantic_mixin.py index 1ecdd89..5d186bb 100644 --- a/ormar/models/mixins/pydantic_mixin.py +++ b/ormar/models/mixins/pydantic_mixin.py @@ -22,6 +22,9 @@ from ormar.queryset.utils import translate_list_to_dict class PydanticMixin(RelationMixin): + + __cache__: Dict[str, Type[pydantic.BaseModel]] = {} + if TYPE_CHECKING: # pragma: no cover __fields__: Dict[str, ModelField] _skip_ellipsis: Callable @@ -68,6 +71,11 @@ class PydanticMixin(RelationMixin): fields_to_process.sort( key=lambda x: list(cls.Meta.model_fields.keys()).index(x) ) + + cache_key = f"{cls.__name__}_{str(include)}_{str(exclude)}" + if cache_key in cls.__cache__: + return cls.__cache__[cache_key] + for name in fields_to_process: field = cls._determine_pydantic_field_type( name=name, @@ -85,6 +93,7 @@ class PydanticMixin(RelationMixin): ) model = cast(Type[pydantic.BaseModel], model) cls._copy_field_validators(model=model) + cls.__cache__[cache_key] = model return model @classmethod diff --git a/tests/test_inheritance_and_pydantic_generation/test_geting_pydantic_models.py b/tests/test_inheritance_and_pydantic_generation/test_geting_pydantic_models.py index 06401b3..146c87f 100644 --- a/tests/test_inheritance_and_pydantic_generation/test_geting_pydantic_models.py +++ b/tests/test_inheritance_and_pydantic_generation/test_geting_pydantic_models.py @@ -60,6 +60,7 @@ class MutualB(ormar.Model): tablename = "mutual_b" id: int = ormar.Integer(primary_key=True) + name = ormar.String(max_length=100, default="test") mutual_a = ormar.ForeignKey(MutualA, related_name="mutuals_b") @@ -183,7 +184,26 @@ def test_getting_pydantic_model_self_ref(): assert len(InnerSelf2.__fields__) == 2 assert set(InnerSelf2.__fields__.keys()) == {"id", "name"} - assert InnerSelf2 != InnerSelf + +def test_getting_pydantic_model_self_ref_exclude(): + PydanticSelfRef = SelfRef.get_pydantic(exclude={"children": {"name"}}) + assert len(PydanticSelfRef.__fields__) == 4 + assert set(PydanticSelfRef.__fields__.keys()) == { + "id", + "name", + "parent", + "children", + } + + InnerSelf = PydanticSelfRef.__fields__["parent"].type_ + assert len(InnerSelf.__fields__) == 2 + assert set(InnerSelf.__fields__.keys()) == {"id", "name"} + + PydanticSelfRefChildren = PydanticSelfRef.__fields__["children"].type_ + assert len(PydanticSelfRefChildren.__fields__) == 1 + assert set(PydanticSelfRefChildren.__fields__.keys()) == {"id"} + assert PydanticSelfRef != PydanticSelfRefChildren + assert InnerSelf != PydanticSelfRefChildren def test_getting_pydantic_model_mutual_rels(): @@ -193,10 +213,23 @@ def test_getting_pydantic_model_mutual_rels(): MutualB1 = MutualAPydantic.__fields__["mutual_b"].type_ MutualB2 = MutualAPydantic.__fields__["mutuals_b"].type_ - assert MutualB1 != MutualB2 + assert len(MutualB1.__fields__) == 2 + assert set(MutualB1.__fields__.keys()) == {"id", "name"} + assert len(MutualB2.__fields__) == 2 + assert set(MutualB2.__fields__.keys()) == {"id", "name"} + assert MutualB1 == MutualB2 + + +def test_getting_pydantic_model_mutual_rels_exclude(): + MutualAPydantic = MutualA.get_pydantic(exclude={"mutual_b": {"name"}}) + assert len(MutualAPydantic.__fields__) == 3 + assert set(MutualAPydantic.__fields__.keys()) == {"id", "mutual_b", "mutuals_b"} + + MutualB1 = MutualAPydantic.__fields__["mutual_b"].type_ + MutualB2 = MutualAPydantic.__fields__["mutuals_b"].type_ assert len(MutualB1.__fields__) == 1 - assert "id" in MutualB1.__fields__ - - assert len(MutualB2.__fields__) == 1 - assert "id" in MutualB2.__fields__ + assert set(MutualB1.__fields__.keys()) == {"id"} + assert len(MutualB2.__fields__) == 2 + assert set(MutualB2.__fields__.keys()) == {"id", "name"} + assert MutualB1 != MutualB2