diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 2959f32..854f83d 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -58,6 +58,7 @@ def ForeignKey( # noqa CFQ002 pydantic_only=False, default=None, server_default=None, + __pydantic_model__=to, ) return type("ForeignKey", (ForeignKeyField, BaseField), namespace) diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index d7c57b9..8bb0168 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -41,7 +41,7 @@ def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> def register_many_to_many_relation_on_build( - table_name: str, field: Type[ManyToManyField] + table_name: str, field: Type[ManyToManyField] ) -> None: alias_manager.add_relation_type(field.through.Meta.tablename, table_name) alias_manager.add_relation_type( @@ -50,11 +50,11 @@ def register_many_to_many_relation_on_build( def reverse_field_not_already_registered( - child: Type["Model"], child_model_name: str, parent_model: Type["Model"] + child: Type["Model"], child_model_name: str, parent_model: Type["Model"] ) -> bool: return ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ) @@ -65,7 +65,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if reverse_field_not_already_registered( - child, child_model_name, parent_model + child, child_model_name, parent_model ): register_reverse_model_fields( parent_model, child, child_model_name, model_field @@ -73,10 +73,10 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: def register_reverse_model_fields( - model: Type["Model"], - child: Type["Model"], - child_model_name: str, - model_field: Type["ForeignKeyField"], + model: Type["Model"], + child: Type["Model"], + child_model_name: str, + model_field: Type["ForeignKeyField"], ) -> None: if issubclass(model_field, ManyToManyField): model.Meta.model_fields[child_model_name] = ManyToMany( @@ -91,7 +91,7 @@ def register_reverse_model_fields( def adjust_through_many_to_many_model( - model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( model, name=model.get_name(), ondelete="CASCADE" @@ -108,7 +108,7 @@ def adjust_through_many_to_many_model( def create_pydantic_field( - field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] + field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.__fields__[field_name] = ModelField( name=field_name, @@ -120,7 +120,7 @@ def create_pydantic_field( def create_and_append_m2m_fk( - model: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: column = sqlalchemy.Column( model.get_name(), @@ -136,7 +136,7 @@ def create_and_append_m2m_fk( def check_pk_column_validity( - field_name: str, field: BaseField, pkname: Optional[str] + field_name: str, field: BaseField, pkname: Optional[str] ) -> Optional[str]: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") @@ -146,7 +146,7 @@ def check_pk_column_validity( def sqlalchemy_columns_from_model_fields( - model_fields: Dict, table_name: str + model_fields: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column]]: columns = [] pkname = None @@ -160,9 +160,9 @@ def sqlalchemy_columns_from_model_fields( if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) if ( - not field.pydantic_only - and not field.virtual - and not issubclass(field, ManyToManyField) + not field.pydantic_only + and not field.virtual + and not issubclass(field, ManyToManyField) ): columns.append(field.get_column(field_name)) register_relation_in_alias_manager(table_name, field) @@ -170,7 +170,7 @@ def sqlalchemy_columns_from_model_fields( def register_relation_in_alias_manager( - table_name: str, field: Type[ForeignKeyField] + table_name: str, field: Type[ForeignKeyField] ) -> None: if issubclass(field, ManyToManyField): register_many_to_many_relation_on_build(table_name, field) @@ -179,7 +179,7 @@ def register_relation_in_alias_manager( def populate_default_pydantic_field_value( - type_: Type[BaseField], field: str, attrs: dict + type_: Type[BaseField], field: str, attrs: dict ) -> dict: def_value = type_.default_value() curr_def_value = attrs.get(field, "NONE") @@ -208,7 +208,7 @@ def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict: def populate_meta_orm_model_fields( - attrs: dict, new_model: Type["Model"] + attrs: dict, new_model: Type["Model"] ) -> Type["Model"]: model_fields = { field_name: field @@ -220,7 +220,7 @@ def populate_meta_orm_model_fields( def populate_meta_tablename_columns_and_pk( - name: str, new_model: Type["Model"] + name: str, new_model: Type["Model"] ) -> Type["Model"]: tablename = name.lower() + "s" new_model.Meta.tablename = ( @@ -246,7 +246,7 @@ def populate_meta_tablename_columns_and_pk( def populate_meta_sqlalchemy_table_if_required( - new_model: Type["Model"], + new_model: Type["Model"], ) -> Type["Model"]: if not hasattr(new_model.Meta, "table"): new_model.Meta.table = sqlalchemy.Table( @@ -288,7 +288,7 @@ def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, A def populate_choices_validators( # noqa CCR001 - model: Type["Model"], attrs: Dict + model: Type["Model"], attrs: Dict ) -> None: if model_initialized_and_has_model_fields(model): for _, field in model.Meta.model_fields.items(): @@ -301,7 +301,7 @@ def populate_choices_validators( # noqa CCR001 class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__( # type: ignore - mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict + mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict ) -> "ModelMetaclass": attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 61d9bc8..7acfc0a 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -5,7 +5,7 @@ import ormar from ormar.exceptions import RelationshipInstanceError from ormar.fields import BaseField, ManyToManyField from ormar.fields.foreign_key import ForeignKeyField -from ormar.models.metaclass import ModelMeta +from ormar.models.metaclass import ModelMeta, expand_reverse_relationships if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -76,10 +76,10 @@ class ModelTableProxy: related_names = set() for name, field in cls.Meta.model_fields.items(): if ( - inspect.isclass(field) - and issubclass(field, ForeignKeyField) - and not issubclass(field, ManyToManyField) - and not field.virtual + inspect.isclass(field) + and issubclass(field, ForeignKeyField) + and not issubclass(field, ManyToManyField) + and not field.virtual ): related_names.add(name) return related_names @@ -91,9 +91,9 @@ class ModelTableProxy: related_names = set() for name, field in cls.Meta.model_fields.items(): if ( - inspect.isclass(field) - and issubclass(field, ForeignKeyField) - and field.nullable + inspect.isclass(field) + and issubclass(field, ForeignKeyField) + and field.nullable ): related_names.add(name) return related_names @@ -113,8 +113,9 @@ class ModelTableProxy: @staticmethod def resolve_relation_name( - item: Union["NewBaseModel", Type["NewBaseModel"]], - related: Union["NewBaseModel", Type["NewBaseModel"]], + item: Union["NewBaseModel", Type["NewBaseModel"]], + related: Union["NewBaseModel", Type["NewBaseModel"]], + register_missing: bool = True ) -> str: for name, field in item.Meta.model_fields.items(): if issubclass(field, ForeignKeyField): @@ -123,13 +124,18 @@ class ModelTableProxy: # so we need to compare Meta too as this one is copied as is if field.to == related.__class__ or field.to.Meta == related.Meta: return name + # fallback for not registered relation + if register_missing: + expand_reverse_relationships(related.__class__) + return ModelTableProxy.resolve_relation_name(item, related, register_missing=False) + raise ValueError( f"No relation between {item.get_name()} and {related.get_name()}" ) # pragma nocover @staticmethod def resolve_relation_field( - item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]] + item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]] ) -> Union[Type[BaseField], Type[ForeignKeyField]]: name = ModelTableProxy.resolve_relation_name(item, related) to_field = item.Meta.model_fields.get(name) @@ -144,9 +150,9 @@ class ModelTableProxy: def translate_columns_to_aliases(cls, new_kwargs: dict) -> dict: for field_name, field in cls.Meta.model_fields.items(): if ( - field_name in new_kwargs - and field.name is not None - and field.name != field_name + field_name in new_kwargs + and field.name is not None + and field.name != field_name ): new_kwargs[field.name] = new_kwargs.pop(field_name) return new_kwargs @@ -173,12 +179,12 @@ class ModelTableProxy: for field in one.Meta.model_fields.keys(): current_field = getattr(one, field) if isinstance(current_field, list) and not isinstance( - current_field, ormar.Model + current_field, ormar.Model ): setattr(other, field, current_field + getattr(other, field)) elif ( - isinstance(current_field, ormar.Model) - and current_field.pk == getattr(other, field).pk + isinstance(current_field, ormar.Model) + and current_field.pk == getattr(other, field).pk ): setattr( other, @@ -189,10 +195,10 @@ class ModelTableProxy: @staticmethod def _get_not_nested_columns_from_fields( - model: Type["Model"], - fields: List, - column_names: List[str], - use_alias: bool = False, + model: Type["Model"], + fields: List, + column_names: List[str], + use_alias: bool = False, ) -> List[str]: fields = [model.get_column_alias(k) if not use_alias else k for k in fields] columns = [name for name in fields if "__" not in name and name in column_names] @@ -200,11 +206,11 @@ class ModelTableProxy: @staticmethod def _get_nested_columns_from_fields( - model: Type["Model"], fields: List, use_alias: bool = False, + model: Type["Model"], fields: List, use_alias: bool = False, ) -> List[str]: model_name = f"{model.get_name()}__" columns = [ - name[(name.find(model_name) + len(model_name)) :] # noqa: E203 + name[(name.find(model_name) + len(model_name)):] # noqa: E203 for name in fields if f"{model.get_name()}__" in name ] @@ -213,10 +219,10 @@ class ModelTableProxy: @staticmethod def own_table_columns( - model: Type["Model"], - fields: List, - nested: bool = False, - use_alias: bool = False, + model: Type["Model"], + fields: List, + nested: bool = False, + use_alias: bool = False, ) -> List[str]: column_names = [ model.get_column_name_from_alias(col.name) if use_alias else col.name diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 88664fb..2809db7 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -132,12 +132,12 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass return super().__getattribute__(item) def _extract_related_model_instead_of_field( - self, item: str + self, item: str ) -> Optional[Union["Model", List["Model"]]]: alias = self.get_column_alias(item) if alias in self._orm: return self._orm.get(alias) - return None + return None # pragma no cover def __eq__(self, other: object) -> bool: if isinstance(other, NewBaseModel): @@ -146,9 +146,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass def __same__(self, other: "NewBaseModel") -> bool: return ( - self._orm_id == other._orm_id - or self.dict() == other.dict() - or (self.pk == other.pk and self.pk is not None) + self._orm_id == other._orm_id + or self.dict() == other.dict() + or (self.pk == other.pk and self.pk is not None) ) @classmethod @@ -170,16 +170,16 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass self._orm.remove_parent(self, name) def dict( # noqa A003 - self, - *, - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - by_alias: bool = False, - skip_defaults: bool = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - nested: bool = False + self, + *, + include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + nested: bool = False ) -> "DictStrAny": # noqa: A003' dict_instance = super().dict( include=include, diff --git a/ormar/relations/relation_manager.py b/ormar/relations/relation_manager.py index a176bec..c716d61 100644 --- a/ormar/relations/relation_manager.py +++ b/ormar/relations/relation_manager.py @@ -17,9 +17,9 @@ if TYPE_CHECKING: # pragma no cover class RelationsManager: def __init__( - self, - related_fields: List[Type[ForeignKeyField]] = None, - owner: "NewBaseModel" = None, + self, + related_fields: List[Type[ForeignKeyField]] = None, + owner: "NewBaseModel" = None, ) -> None: self.owner = proxy(owner) self._related_fields = related_fields or [] @@ -40,6 +40,8 @@ class RelationsManager: to=field.to, through=getattr(field, "through", None), ) + if field.name not in self._related_names: + self._related_names.append(field.name) def __contains__(self, item: str) -> bool: return item in self._related_names @@ -74,7 +76,7 @@ class RelationsManager: child_relation.add(parent) def remove( - self, name: str, child: Union["NewBaseModel", Type["NewBaseModel"]] + self, name: str, child: Union["NewBaseModel", Type["NewBaseModel"]] ) -> None: relation = self._get(name) if relation: @@ -82,7 +84,7 @@ class RelationsManager: @staticmethod def remove_parent( - item: Union["NewBaseModel", Type["NewBaseModel"]], name: "Model" + item: Union["NewBaseModel", Type["NewBaseModel"]], name: "Model" ) -> None: related_model = name rel_name = item.resolve_relation_name(item, related_model) diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index 88130d5..f658e57 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -6,7 +6,7 @@ from ormar.relations.querysetproxy import QuerysetProxy if TYPE_CHECKING: # pragma no cover from ormar import Model - from ormar.relations import Relation + from ormar.relations import Relation, register_missing_relation from ormar.queryset import QuerySet @@ -33,8 +33,8 @@ class RelationProxy(list): def _check_if_queryset_is_initialized(self) -> bool: return ( - hasattr(self.queryset_proxy, "queryset") - and self.queryset_proxy.queryset is not None + hasattr(self.queryset_proxy, "queryset") + and self.queryset_proxy.queryset is not None ) def _set_queryset(self) -> "QuerySet": @@ -48,8 +48,8 @@ class RelationProxy(list): kwargs = {f"{owner_table}__{pkname}": pk_value} queryset = ( ormar.QuerySet(model_cls=self.relation.to) - .select_related(owner_table) - .filter(**kwargs) + .select_related(owner_table) + .filter(**kwargs) ) return queryset @@ -72,4 +72,6 @@ class RelationProxy(list): if self.relation._type == ormar.RelationType.MULTIPLE: await self.queryset_proxy.create_through_instance(item) rel_name = item.resolve_relation_name(item, self._owner) + if not rel_name in item._orm: + item._orm._add_relation(item.Meta.model_fields[rel_name]) setattr(item, rel_name, self._owner) diff --git a/tests/test_fastapi_docs.py b/tests/test_fastapi_docs.py new file mode 100644 index 0000000..c6cd35b --- /dev/null +++ b/tests/test_fastapi_docs.py @@ -0,0 +1,125 @@ +from typing import List + +import databases +import pytest +import sqlalchemy +from fastapi import FastAPI +from starlette.testclient import TestClient + +import ormar +from tests.settings import DATABASE_URL + +app = FastAPI() +metadata = sqlalchemy.MetaData() +database = databases.Database(DATABASE_URL, force_rollback=True) +app.state.database = database + + +@app.on_event("startup") +async def startup() -> None: + database_ = app.state.database + if not database_.is_connected: + await database_.connect() + + +@app.on_event("shutdown") +async def shutdown() -> None: + database_ = app.state.database + if database_.is_connected: + await database_.disconnect() + + +class LocalMeta: + metadata = metadata + database = database + + +class Category(ormar.Model): + class Meta(LocalMeta): + tablename = "categories" + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + + +class ItemsXCategories(ormar.Model): + class Meta(LocalMeta): + tablename = 'items_x_categories' + + +class Item(ormar.Model): + class Meta(LocalMeta): + pass + + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + categories: ormar.ManyToMany(Category, through=ItemsXCategories) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@app.get("/items/", response_model=List[Item]) +async def get_items(): + items = await Item.objects.select_related("categories").all() + return items + + +@app.post("/items/", response_model=Item) +async def create_item(item: Item): + await item.save() + return item + + +@app.post("/items/add_category/", response_model=Item) +async def create_item(item: Item, category: Category): + await item.categories.add(category) + return item + + +@app.post("/categories/", response_model=Category) +async def create_category(category: Category): + await category.save() + return category + + +def test_all_endpoints(): + client = TestClient(app) + with client as client: + response = client.post("/categories/", json={"name": "test cat"}) + category = response.json() + response = client.post("/categories/", json={"name": "test cat2"}) + category2 = response.json() + + response = client.post( + "/items/", json={"name": "test", "id": 1} + ) + item = Item(**response.json()) + assert item.pk is not None + + response = client.post( + "/items/add_category/", json={"item": item.dict(), "category": category} + ) + item = Item(**response.json()) + assert len(item.categories) == 1 + assert item.categories[0].name == 'test cat' + + client.post( + "/items/add_category/", json={"item": item.dict(), "category": category2} + ) + + response = client.get("/items/") + items = [Item(**item) for item in response.json()] + assert items[0] == item + assert len(items[0].categories) == 2 + assert items[0].categories[0].name == 'test cat' + assert items[0].categories[1].name == 'test cat2' + + response = client.get("/docs/") + assert response.status_code == 200 + assert b'