diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py index 0fed9e8..1f73a0d 100644 --- a/ormar/fields/many_to_many.py +++ b/ormar/fields/many_to_many.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Type +from typing import Dict, TYPE_CHECKING, Type from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField @@ -6,6 +6,8 @@ from ormar.fields.foreign_key import ForeignKeyField if TYPE_CHECKING: # pragma no cover from ormar.models import Model +REF_PREFIX = "#/components/schemas/" + def ManyToMany( to: Type["Model"], @@ -31,6 +33,9 @@ def ManyToMany( pydantic_only=False, default=None, server_default=None, + __pydantic_model__=to, + # __origin__=List, + # __args__=[Optional[to]] ) return type("ManyToMany", (ManyToManyField, BaseField), namespace) @@ -38,3 +43,10 @@ def ManyToMany( class ManyToManyField(ForeignKeyField): through: Type["Model"] + + @classmethod + def __modify_schema__(cls, field_schema: Dict) -> None: + field_schema["type"] = "array" + field_schema["title"] = cls.name.title() + field_schema["definitions"] = {f"{cls.to.__name__}": cls.to.schema()} + field_schema["items"] = {"$ref": f"{REF_PREFIX}{cls.to.__name__}"} diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 8bb0168..d7c57b9 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -41,7 +41,7 @@ def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> def register_many_to_many_relation_on_build( - table_name: str, field: Type[ManyToManyField] + table_name: str, field: Type[ManyToManyField] ) -> None: alias_manager.add_relation_type(field.through.Meta.tablename, table_name) alias_manager.add_relation_type( @@ -50,11 +50,11 @@ def register_many_to_many_relation_on_build( def reverse_field_not_already_registered( - child: Type["Model"], child_model_name: str, parent_model: Type["Model"] + child: Type["Model"], child_model_name: str, parent_model: Type["Model"] ) -> bool: return ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ) @@ -65,7 +65,7 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if reverse_field_not_already_registered( - child, child_model_name, parent_model + child, child_model_name, parent_model ): register_reverse_model_fields( parent_model, child, child_model_name, model_field @@ -73,10 +73,10 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: def register_reverse_model_fields( - model: Type["Model"], - child: Type["Model"], - child_model_name: str, - model_field: Type["ForeignKeyField"], + model: Type["Model"], + child: Type["Model"], + child_model_name: str, + model_field: Type["ForeignKeyField"], ) -> None: if issubclass(model_field, ManyToManyField): model.Meta.model_fields[child_model_name] = ManyToMany( @@ -91,7 +91,7 @@ def register_reverse_model_fields( def adjust_through_many_to_many_model( - model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], child: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.Meta.model_fields[model.get_name()] = ForeignKey( model, name=model.get_name(), ondelete="CASCADE" @@ -108,7 +108,7 @@ def adjust_through_many_to_many_model( def create_pydantic_field( - field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] + field_name: str, model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: model_field.through.__fields__[field_name] = ModelField( name=field_name, @@ -120,7 +120,7 @@ def create_pydantic_field( def create_and_append_m2m_fk( - model: Type["Model"], model_field: Type[ManyToManyField] + model: Type["Model"], model_field: Type[ManyToManyField] ) -> None: column = sqlalchemy.Column( model.get_name(), @@ -136,7 +136,7 @@ def create_and_append_m2m_fk( def check_pk_column_validity( - field_name: str, field: BaseField, pkname: Optional[str] + field_name: str, field: BaseField, pkname: Optional[str] ) -> Optional[str]: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") @@ -146,7 +146,7 @@ def check_pk_column_validity( def sqlalchemy_columns_from_model_fields( - model_fields: Dict, table_name: str + model_fields: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column]]: columns = [] pkname = None @@ -160,9 +160,9 @@ def sqlalchemy_columns_from_model_fields( if field.primary_key: pkname = check_pk_column_validity(field_name, field, pkname) if ( - not field.pydantic_only - and not field.virtual - and not issubclass(field, ManyToManyField) + not field.pydantic_only + and not field.virtual + and not issubclass(field, ManyToManyField) ): columns.append(field.get_column(field_name)) register_relation_in_alias_manager(table_name, field) @@ -170,7 +170,7 @@ def sqlalchemy_columns_from_model_fields( def register_relation_in_alias_manager( - table_name: str, field: Type[ForeignKeyField] + table_name: str, field: Type[ForeignKeyField] ) -> None: if issubclass(field, ManyToManyField): register_many_to_many_relation_on_build(table_name, field) @@ -179,7 +179,7 @@ def register_relation_in_alias_manager( def populate_default_pydantic_field_value( - type_: Type[BaseField], field: str, attrs: dict + type_: Type[BaseField], field: str, attrs: dict ) -> dict: def_value = type_.default_value() curr_def_value = attrs.get(field, "NONE") @@ -208,7 +208,7 @@ def extract_annotations_and_default_vals(attrs: dict, bases: Tuple) -> dict: def populate_meta_orm_model_fields( - attrs: dict, new_model: Type["Model"] + attrs: dict, new_model: Type["Model"] ) -> Type["Model"]: model_fields = { field_name: field @@ -220,7 +220,7 @@ def populate_meta_orm_model_fields( def populate_meta_tablename_columns_and_pk( - name: str, new_model: Type["Model"] + name: str, new_model: Type["Model"] ) -> Type["Model"]: tablename = name.lower() + "s" new_model.Meta.tablename = ( @@ -246,7 +246,7 @@ def populate_meta_tablename_columns_and_pk( def populate_meta_sqlalchemy_table_if_required( - new_model: Type["Model"], + new_model: Type["Model"], ) -> Type["Model"]: if not hasattr(new_model.Meta, "table"): new_model.Meta.table = sqlalchemy.Table( @@ -288,7 +288,7 @@ def choices_validator(cls: Type["Model"], values: Dict[str, Any]) -> Dict[str, A def populate_choices_validators( # noqa CCR001 - model: Type["Model"], attrs: Dict + model: Type["Model"], attrs: Dict ) -> None: if model_initialized_and_has_model_fields(model): for _, field in model.Meta.model_fields.items(): @@ -301,7 +301,7 @@ def populate_choices_validators( # noqa CCR001 class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__( # type: ignore - mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict + mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict ) -> "ModelMetaclass": attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 7acfc0a..c965154 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -76,10 +76,10 @@ class ModelTableProxy: related_names = set() for name, field in cls.Meta.model_fields.items(): if ( - inspect.isclass(field) - and issubclass(field, ForeignKeyField) - and not issubclass(field, ManyToManyField) - and not field.virtual + inspect.isclass(field) + and issubclass(field, ForeignKeyField) + and not issubclass(field, ManyToManyField) + and not field.virtual ): related_names.add(name) return related_names @@ -91,9 +91,9 @@ class ModelTableProxy: related_names = set() for name, field in cls.Meta.model_fields.items(): if ( - inspect.isclass(field) - and issubclass(field, ForeignKeyField) - and field.nullable + inspect.isclass(field) + and issubclass(field, ForeignKeyField) + and field.nullable ): related_names.add(name) return related_names @@ -112,10 +112,10 @@ class ModelTableProxy: return self_fields @staticmethod - def resolve_relation_name( - item: Union["NewBaseModel", Type["NewBaseModel"]], - related: Union["NewBaseModel", Type["NewBaseModel"]], - register_missing: bool = True + def resolve_relation_name( # noqa CCR001 + item: Union["NewBaseModel", Type["NewBaseModel"]], + related: Union["NewBaseModel", Type["NewBaseModel"]], + register_missing: bool = True, ) -> str: for name, field in item.Meta.model_fields.items(): if issubclass(field, ForeignKeyField): @@ -126,8 +126,10 @@ class ModelTableProxy: return name # fallback for not registered relation if register_missing: - expand_reverse_relationships(related.__class__) - return ModelTableProxy.resolve_relation_name(item, related, register_missing=False) + expand_reverse_relationships(related.__class__) # type: ignore + return ModelTableProxy.resolve_relation_name( + item, related, register_missing=False + ) raise ValueError( f"No relation between {item.get_name()} and {related.get_name()}" @@ -135,7 +137,7 @@ class ModelTableProxy: @staticmethod def resolve_relation_field( - item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]] + item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]] ) -> Union[Type[BaseField], Type[ForeignKeyField]]: name = ModelTableProxy.resolve_relation_name(item, related) to_field = item.Meta.model_fields.get(name) @@ -147,18 +149,18 @@ class ModelTableProxy: return to_field @classmethod - def translate_columns_to_aliases(cls, new_kwargs: dict) -> dict: + def translate_columns_to_aliases(cls, new_kwargs: Dict) -> Dict: for field_name, field in cls.Meta.model_fields.items(): if ( - field_name in new_kwargs - and field.name is not None - and field.name != field_name + field_name in new_kwargs + and field.name is not None + and field.name != field_name ): new_kwargs[field.name] = new_kwargs.pop(field_name) return new_kwargs @classmethod - def translate_aliases_to_columns(cls, new_kwargs: dict) -> dict: + def translate_aliases_to_columns(cls, new_kwargs: Dict) -> Dict: for field_name, field in cls.Meta.model_fields.items(): if field.name in new_kwargs and field.name != field_name: new_kwargs[field_name] = new_kwargs.pop(field.name) @@ -179,12 +181,12 @@ class ModelTableProxy: for field in one.Meta.model_fields.keys(): current_field = getattr(one, field) if isinstance(current_field, list) and not isinstance( - current_field, ormar.Model + current_field, ormar.Model ): setattr(other, field, current_field + getattr(other, field)) elif ( - isinstance(current_field, ormar.Model) - and current_field.pk == getattr(other, field).pk + isinstance(current_field, ormar.Model) + and current_field.pk == getattr(other, field).pk ): setattr( other, @@ -195,10 +197,10 @@ class ModelTableProxy: @staticmethod def _get_not_nested_columns_from_fields( - model: Type["Model"], - fields: List, - column_names: List[str], - use_alias: bool = False, + model: Type["Model"], + fields: List, + column_names: List[str], + use_alias: bool = False, ) -> List[str]: fields = [model.get_column_alias(k) if not use_alias else k for k in fields] columns = [name for name in fields if "__" not in name and name in column_names] @@ -206,11 +208,11 @@ class ModelTableProxy: @staticmethod def _get_nested_columns_from_fields( - model: Type["Model"], fields: List, use_alias: bool = False, + model: Type["Model"], fields: List, use_alias: bool = False, ) -> List[str]: model_name = f"{model.get_name()}__" columns = [ - name[(name.find(model_name) + len(model_name)):] # noqa: E203 + name[(name.find(model_name) + len(model_name)) :] # noqa: E203 for name in fields if f"{model.get_name()}__" in name ] @@ -219,10 +221,10 @@ class ModelTableProxy: @staticmethod def own_table_columns( - model: Type["Model"], - fields: List, - nested: bool = False, - use_alias: bool = False, + model: Type["Model"], + fields: List, + nested: bool = False, + use_alias: bool = False, ) -> List[str]: column_names = [ model.get_column_name_from_alias(col.name) if use_alias else col.name diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 2809db7..5cf28bc 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -100,7 +100,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass ) def __setattr__(self, name: str, value: Any) -> None: # noqa CCR001 - if name in self.__slots__: + if name in ("_orm_id", "_orm_saved", "_orm"): object.__setattr__(self, name, value) elif name == "pk": object.__setattr__(self, self.Meta.pkname, value) @@ -132,7 +132,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass return super().__getattribute__(item) def _extract_related_model_instead_of_field( - self, item: str + self, item: str ) -> Optional[Union["Model", List["Model"]]]: alias = self.get_column_alias(item) if alias in self._orm: @@ -146,9 +146,9 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass def __same__(self, other: "NewBaseModel") -> bool: return ( - self._orm_id == other._orm_id - or self.dict() == other.dict() - or (self.pk == other.pk and self.pk is not None) + self._orm_id == other._orm_id + or self.dict() == other.dict() + or (self.pk == other.pk and self.pk is not None) ) @classmethod @@ -170,16 +170,16 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass self._orm.remove_parent(self, name) def dict( # noqa A003 - self, - *, - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - by_alias: bool = False, - skip_defaults: bool = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - nested: bool = False + self, + *, + include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + nested: bool = False ) -> "DictStrAny": # noqa: A003' dict_instance = super().dict( include=include, diff --git a/ormar/relations/relation_manager.py b/ormar/relations/relation_manager.py index c716d61..6e7eb24 100644 --- a/ormar/relations/relation_manager.py +++ b/ormar/relations/relation_manager.py @@ -17,9 +17,9 @@ if TYPE_CHECKING: # pragma no cover class RelationsManager: def __init__( - self, - related_fields: List[Type[ForeignKeyField]] = None, - owner: "NewBaseModel" = None, + self, + related_fields: List[Type[ForeignKeyField]] = None, + owner: "NewBaseModel" = None, ) -> None: self.owner = proxy(owner) self._related_fields = related_fields or [] @@ -76,7 +76,7 @@ class RelationsManager: child_relation.add(parent) def remove( - self, name: str, child: Union["NewBaseModel", Type["NewBaseModel"]] + self, name: str, child: Union["NewBaseModel", Type["NewBaseModel"]] ) -> None: relation = self._get(name) if relation: @@ -84,7 +84,7 @@ class RelationsManager: @staticmethod def remove_parent( - item: Union["NewBaseModel", Type["NewBaseModel"]], name: "Model" + item: Union["NewBaseModel", Type["NewBaseModel"]], name: "Model" ) -> None: related_model = name rel_name = item.resolve_relation_name(item, related_model) diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index f658e57..29e4b97 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -6,7 +6,7 @@ from ormar.relations.querysetproxy import QuerysetProxy if TYPE_CHECKING: # pragma no cover from ormar import Model - from ormar.relations import Relation, register_missing_relation + from ormar.relations import Relation from ormar.queryset import QuerySet @@ -33,8 +33,8 @@ class RelationProxy(list): def _check_if_queryset_is_initialized(self) -> bool: return ( - hasattr(self.queryset_proxy, "queryset") - and self.queryset_proxy.queryset is not None + hasattr(self.queryset_proxy, "queryset") + and self.queryset_proxy.queryset is not None ) def _set_queryset(self) -> "QuerySet": @@ -48,8 +48,8 @@ class RelationProxy(list): kwargs = {f"{owner_table}__{pkname}": pk_value} queryset = ( ormar.QuerySet(model_cls=self.relation.to) - .select_related(owner_table) - .filter(**kwargs) + .select_related(owner_table) + .filter(**kwargs) ) return queryset @@ -72,6 +72,6 @@ class RelationProxy(list): if self.relation._type == ormar.RelationType.MULTIPLE: await self.queryset_proxy.create_through_instance(item) rel_name = item.resolve_relation_name(item, self._owner) - if not rel_name in item._orm: + if rel_name not in item._orm: item._orm._add_relation(item.Meta.model_fields[rel_name]) setattr(item, rel_name, self._owner) diff --git a/tests/test_aliases.py b/tests/test_aliases.py index f169c30..b48bb3f 100644 --- a/tests/test_aliases.py +++ b/tests/test_aliases.py @@ -15,10 +15,10 @@ class Child(ormar.Model): metadata = metadata database = database - id: ormar.Integer(name='child_id', primary_key=True) - first_name: ormar.String(name='fname', max_length=100) - last_name: ormar.String(name='lname', max_length=100) - born_year: ormar.Integer(name='year_born', nullable=True) + id: ormar.Integer(name="child_id", primary_key=True) + first_name: ormar.String(name="fname", max_length=100) + last_name: ormar.String(name="lname", max_length=100) + born_year: ormar.Integer(name="year_born", nullable=True) class ArtistChildren(ormar.Model): @@ -34,10 +34,10 @@ class Artist(ormar.Model): metadata = metadata database = database - id: ormar.Integer(name='artist_id', primary_key=True) - first_name: ormar.String(name='fname', max_length=100) - last_name: ormar.String(name='lname', max_length=100) - born_year: ormar.Integer(name='year') + id: ormar.Integer(name="artist_id", primary_key=True) + first_name: ormar.String(name="fname", max_length=100) + last_name: ormar.String(name="lname", max_length=100) + born_year: ormar.Integer(name="year") children: ormar.ManyToMany(Child, through=ArtistChildren) @@ -47,9 +47,9 @@ class Album(ormar.Model): metadata = metadata database = database - id: ormar.Integer(name='album_id', primary_key=True) - name: ormar.String(name='album_name', max_length=100) - artist: ormar.ForeignKey(Artist, name='artist_id') + id: ormar.Integer(name="album_id", primary_key=True) + name: ormar.String(name="album_name", max_length=100) + artist: ormar.ForeignKey(Artist, name="artist_id") @pytest.fixture(autouse=True, scope="module") @@ -62,70 +62,87 @@ def create_test_database(): def test_table_structure(): - assert 'album_id' in [x.name for x in Album.Meta.table.columns] - assert 'album_name' in [x.name for x in Album.Meta.table.columns] - assert 'fname' in [x.name for x in Artist.Meta.table.columns] - assert 'lname' in [x.name for x in Artist.Meta.table.columns] - assert 'year' in [x.name for x in Artist.Meta.table.columns] + assert "album_id" in [x.name for x in Album.Meta.table.columns] + assert "album_name" in [x.name for x in Album.Meta.table.columns] + assert "fname" in [x.name for x in Artist.Meta.table.columns] + assert "lname" in [x.name for x in Artist.Meta.table.columns] + assert "year" in [x.name for x in Artist.Meta.table.columns] @pytest.mark.asyncio async def test_working_with_aliases(): async with database: async with database.transaction(force_rollback=True): - artist = await Artist.objects.create(first_name='Ted', last_name='Mosbey', born_year=1975) + artist = await Artist.objects.create( + first_name="Ted", last_name="Mosbey", born_year=1975 + ) await Album.objects.create(name="Aunt Robin", artist=artist) - await artist.children.create(first_name='Son', last_name='1', born_year=1990) - await artist.children.create(first_name='Son', last_name='2', born_year=1995) + await artist.children.create( + first_name="Son", last_name="1", born_year=1990 + ) + await artist.children.create( + first_name="Son", last_name="2", born_year=1995 + ) - album = await Album.objects.select_related('artist').first() - assert album.artist.last_name == 'Mosbey' + album = await Album.objects.select_related("artist").first() + assert album.artist.last_name == "Mosbey" assert album.artist.id is not None - assert album.artist.first_name == 'Ted' + assert album.artist.first_name == "Ted" assert album.artist.born_year == 1975 - assert album.name == 'Aunt Robin' + assert album.name == "Aunt Robin" - artist = await Artist.objects.select_related('children').get() + artist = await Artist.objects.select_related("children").get() assert len(artist.children) == 2 - assert artist.children[0].first_name == 'Son' - assert artist.children[1].last_name == '2' + assert artist.children[0].first_name == "Son" + assert artist.children[1].last_name == "2" - await artist.update(last_name='Bundy') + await artist.update(last_name="Bundy") await Artist.objects.filter(pk=artist.pk).update(born_year=1974) - artist = await Artist.objects.select_related('children').get() - assert artist.last_name == 'Bundy' + artist = await Artist.objects.select_related("children").get() + assert artist.last_name == "Bundy" assert artist.born_year == 1974 - artist = await Artist.objects.select_related('children').fields( - ['first_name', 'last_name', 'born_year', 'child__first_name', 'child__last_name']).get() + artist = ( + await Artist.objects.select_related("children") + .fields( + [ + "first_name", + "last_name", + "born_year", + "child__first_name", + "child__last_name", + ] + ) + .get() + ) assert artist.children[0].born_year is None @pytest.mark.asyncio async def test_bulk_operations_and_fields(): async with database: - d1 = Child(first_name='Daughter', last_name='1', born_year=1990) - d2 = Child(first_name='Daughter', last_name='2', born_year=1991) + d1 = Child(first_name="Daughter", last_name="1", born_year=1990) + d2 = Child(first_name="Daughter", last_name="2", born_year=1991) await Child.objects.bulk_create([d1, d2]) - children = await Child.objects.filter(first_name='Daughter').all() + children = await Child.objects.filter(first_name="Daughter").all() assert len(children) == 2 - assert children[0].last_name == '1' + assert children[0].last_name == "1" for child in children: child.born_year = child.born_year - 100 await Child.objects.bulk_update(children) - children = await Child.objects.filter(first_name='Daughter').all() + children = await Child.objects.filter(first_name="Daughter").all() assert len(children) == 2 assert children[0].born_year == 1890 - children = await Child.objects.fields(['first_name', 'last_name']).all() + children = await Child.objects.fields(["first_name", "last_name"]).all() assert len(children) == 2 for child in children: assert child.born_year is None @@ -140,17 +157,21 @@ async def test_bulk_operations_and_fields(): async def test_working_with_aliases_get_or_create(): async with database: async with database.transaction(force_rollback=True): - artist = await Artist.objects.get_or_create(first_name='Teddy', last_name='Bear', born_year=2020) + artist = await Artist.objects.get_or_create( + first_name="Teddy", last_name="Bear", born_year=2020 + ) assert artist.pk is not None - artist2 = await Artist.objects.get_or_create(first_name='Teddy', last_name='Bear', born_year=2020) + artist2 = await Artist.objects.get_or_create( + first_name="Teddy", last_name="Bear", born_year=2020 + ) assert artist == artist2 art3 = artist2.dict() - art3['born_year'] = 2019 + art3["born_year"] = 2019 await Artist.objects.update_or_create(**art3) - artist3 = await Artist.objects.get(last_name='Bear') + artist3 = await Artist.objects.get(last_name="Bear") assert artist3.born_year == 2019 artists = await Artist.objects.all() diff --git a/tests/test_fastapi_docs.py b/tests/test_fastapi_docs.py index c6cd35b..030815f 100644 --- a/tests/test_fastapi_docs.py +++ b/tests/test_fastapi_docs.py @@ -44,7 +44,7 @@ class Category(ormar.Model): class ItemsXCategories(ormar.Model): class Meta(LocalMeta): - tablename = 'items_x_categories' + tablename = "items_x_categories" class Item(ormar.Model): @@ -96,9 +96,7 @@ def test_all_endpoints(): response = client.post("/categories/", json={"name": "test cat2"}) category2 = response.json() - response = client.post( - "/items/", json={"name": "test", "id": 1} - ) + response = client.post("/items/", json={"name": "test", "id": 1}) item = Item(**response.json()) assert item.pk is not None @@ -107,7 +105,7 @@ def test_all_endpoints(): ) item = Item(**response.json()) assert len(item.categories) == 1 - assert item.categories[0].name == 'test cat' + assert item.categories[0].name == "test cat" client.post( "/items/add_category/", json={"item": item.dict(), "category": category2} @@ -117,9 +115,21 @@ def test_all_endpoints(): items = [Item(**item) for item in response.json()] assert items[0] == item assert len(items[0].categories) == 2 - assert items[0].categories[0].name == 'test cat' - assert items[0].categories[1].name == 'test cat2' + assert items[0].categories[0].name == "test cat" + assert items[0].categories[1].name == "test cat2" response = client.get("/docs/") assert response.status_code == 200 - assert b'FastAPI - Swagger UI' in response.content + assert b"FastAPI - Swagger UI" in response.content + + +def test_schema_modification(): + schema = Item.schema() + assert schema["properties"]["categories"]["type"] == "array" + assert schema["properties"]["categories"]["title"] == "Categories" + + +def test_schema_gen(): + schema = app.openapi() + assert "Category" in schema["components"]["schemas"] + assert "Item" in schema["components"]["schemas"] diff --git a/tests/test_models.py b/tests/test_models.py index 26de3f1..47b5d50 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -437,4 +437,3 @@ async def test_start_and_end_filters(): users = await User.objects.filter(name__endswith="igo").all() assert len(users) == 2 - diff --git a/tests/test_more_reallife_fastapi.py b/tests/test_more_reallife_fastapi.py index 3a5e909..b538b5c 100644 --- a/tests/test_more_reallife_fastapi.py +++ b/tests/test_more_reallife_fastapi.py @@ -119,9 +119,7 @@ def test_all_endpoints(): items = response.json() assert len(items) == 0 - client.post( - "/items/", json={"name": "test_2", "id": 2, "category": category} - ) + client.post("/items/", json={"name": "test_2", "id": 2, "category": category}) response = client.get("/items/") items = response.json() assert len(items) == 1 @@ -132,4 +130,3 @@ def test_all_endpoints(): response = client.get("/docs/") assert response.status_code == 200 -