From 929e979d37c1d934256594d4b832086c6382f261 Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 19 Mar 2021 16:52:47 +0100 Subject: [PATCH] improve types -> make queryset generic --- ormar/fields/foreign_key.py | 12 ++-- ormar/fields/many_to_many.py | 2 +- ormar/models/__init__.py | 4 +- ormar/models/helpers/relations.py | 2 +- ormar/models/helpers/sqlalchemy.py | 6 +- ormar/models/metaclass.py | 13 +++- ormar/models/model.py | 3 - ormar/queryset/queryset.py | 74 +++++++++----------- tests/test_order_by.py | 2 +- tests/test_types.py | 108 +++++++++++++++++++++++++++++ 10 files changed, 165 insertions(+), 61 deletions(-) create mode 100644 tests/test_types.py diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 9458ce5..65a5311 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -15,7 +15,7 @@ from ormar.exceptions import ModelDefinitionError, RelationshipInstanceError from ormar.fields.base import BaseField if TYPE_CHECKING: # pragma no cover - from ormar.models import Model, NewBaseModel + from ormar.models import Model, NewBaseModel, T from ormar.fields import ManyToManyField if sys.version_info < (3, 7): @@ -24,7 +24,7 @@ if TYPE_CHECKING: # pragma no cover ToType = Union[Type["Model"], "ForwardRef"] -def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": +def create_dummy_instance(fk: Type["T"], pk: Any = None) -> "T": """ Ormar never returns you a raw data. So if you have a related field that has a value populated @@ -55,7 +55,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": def create_dummy_model( - base_model: Type["Model"], + base_model: Type["T"], pk_field: Union[BaseField, "ForeignKeyField", "ManyToManyField"], ) -> Type["BaseModel"]: """ @@ -83,7 +83,7 @@ def create_dummy_model( def populate_fk_params_based_on_to_model( - to: Type["Model"], nullable: bool, onupdate: str = None, ondelete: str = None, + to: Type["T"], nullable: bool, onupdate: str = None, ondelete: str = None, ) -> Tuple[Any, List, Any]: """ Based on target to model to which relation leads to populates the type of the @@ -169,7 +169,7 @@ class ForeignKeyConstraint: def ForeignKey( # noqa CFQ002 - to: "ToType", + to: Type["T"], *, name: str = None, unique: bool = False, @@ -179,7 +179,7 @@ def ForeignKey( # noqa CFQ002 onupdate: str = None, ondelete: str = None, **kwargs: Any, -) -> Any: +) -> "T": """ Despite a name it's a function that returns constructed ForeignKeyField. This function is actually used in model declaration (as ormar.ForeignKey(ToModel)). diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py index 4f44121..0ad11a5 100644 --- a/ormar/fields/many_to_many.py +++ b/ormar/fields/many_to_many.py @@ -65,7 +65,7 @@ def ManyToMany( unique: bool = False, virtual: bool = False, **kwargs: Any, -) -> Any: +): """ Despite a name it's a function that returns constructed ManyToManyField. This function is actually used in model declaration diff --git a/ormar/models/__init__.py b/ormar/models/__init__.py index eb6bdd7..694043e 100644 --- a/ormar/models/__init__.py +++ b/ormar/models/__init__.py @@ -6,7 +6,7 @@ ass well as vast number of helper functions for pydantic, sqlalchemy and relatio from ormar.models.newbasemodel import NewBaseModel # noqa I100 from ormar.models.model_row import ModelRow # noqa I100 -from ormar.models.model import Model # noqa I100 +from ormar.models.model import Model, T # noqa I100 from ormar.models.excludable import ExcludableItems # noqa I100 -__all__ = ["NewBaseModel", "Model", "ModelRow", "ExcludableItems"] +__all__ = ["NewBaseModel", "Model", "ModelRow", "ExcludableItems", "T"] diff --git a/ormar/models/helpers/relations.py b/ormar/models/helpers/relations.py index ac35d9a..2eb01c2 100644 --- a/ormar/models/helpers/relations.py +++ b/ormar/models/helpers/relations.py @@ -117,7 +117,7 @@ def register_reverse_model_fields(model_field: "ForeignKeyField") -> None: register_through_shortcut_fields(model_field=model_field) adjust_through_many_to_many_model(model_field=model_field) else: - model_field.to.Meta.model_fields[related_name] = ForeignKey( + model_field.to.Meta.model_fields[related_name] = ForeignKey( # type: ignore model_field.owner, real_name=related_name, virtual=True, diff --git a/ormar/models/helpers/sqlalchemy.py b/ormar/models/helpers/sqlalchemy.py index fdb23c2..527536e 100644 --- a/ormar/models/helpers/sqlalchemy.py +++ b/ormar/models/helpers/sqlalchemy.py @@ -26,14 +26,14 @@ def adjust_through_many_to_many_model(model_field: "ManyToManyField") -> None: """ parent_name = model_field.default_target_field_name() child_name = model_field.default_source_field_name() - - model_field.through.Meta.model_fields[parent_name] = ormar.ForeignKey( + model_fields = model_field.through.Meta.model_fields + model_fields[parent_name] = ormar.ForeignKey( # type: ignore model_field.to, real_name=parent_name, ondelete="CASCADE", owner=model_field.through, ) - model_field.through.Meta.model_fields[child_name] = ormar.ForeignKey( + model_fields[child_name] = ormar.ForeignKey( # type: ignore model_field.owner, real_name=child_name, ondelete="CASCADE", diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 6c529f2..34fd454 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -18,6 +18,7 @@ from sqlalchemy.sql.schema import ColumnCollectionConstraint import ormar # noqa I100 from ormar import ModelDefinitionError # noqa I100 +from ormar.exceptions import ModelError from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.many_to_many import ManyToManyField @@ -44,6 +45,7 @@ from ormar.signals import Signal, SignalEmitter if TYPE_CHECKING: # pragma no cover from ormar import Model + from ormar.models import T CONFIG_KEY = "Config" PARSED_FIELDS_KEY = "__parsed_fields__" @@ -545,6 +547,15 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): field_name=field_name, model=new_model ) new_model.Meta.alias_manager = alias_manager - new_model.objects = QuerySet(new_model) return new_model + + @property + def objects(cls: Type["T"]) -> "QuerySet[T]": # type: ignore + if cls.Meta.requires_ref_update: + raise ModelError( + f"Model {cls.get_name()} has not updated " + f"ForwardRefs. \nBefore using the model you " + f"need to call update_forward_refs()." + ) + return QuerySet(model_cls=cls) diff --git a/ormar/models/model.py b/ormar/models/model.py index a0c2abc..6a46696 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -15,8 +15,6 @@ from ormar.models import NewBaseModel # noqa I100 from ormar.models.metaclass import ModelMeta from ormar.models.model_row import ModelRow -if TYPE_CHECKING: # pragma nocover - from ormar import QuerySet T = TypeVar("T", bound="Model") @@ -25,7 +23,6 @@ class Model(ModelRow): __abstract__ = False if TYPE_CHECKING: # pragma nocover Meta: ModelMeta - objects: "QuerySet" def __repr__(self) -> str: # pragma nocover _repr = {k: getattr(self, k) for k, v in self.Meta.model_fields.items()} diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 515f425..a3a98ab 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1,13 +1,13 @@ from typing import ( Any, Dict, - List, + Generic, List, Optional, Sequence, Set, TYPE_CHECKING, Type, - Union, + TypeVar, Union, cast, ) @@ -26,19 +26,22 @@ from ormar.queryset.query import Query if TYPE_CHECKING: # pragma no cover from ormar import Model + from ormar.models import T from ormar.models.metaclass import ModelMeta from ormar.relations.querysetproxy import QuerysetProxy from ormar.models.excludable import ExcludableItems +else: + T = TypeVar("T", bound="Model") -class QuerySet: +class QuerySet(Generic[T]): """ Main class to perform database queries, exposed on each model as objects attribute. """ def __init__( # noqa CFQ002 self, - model_cls: Optional[Type["Model"]] = None, + model_cls: Optional[Type["T"]] = None, filter_clauses: List = None, exclude_clauses: List = None, select_related: List = None, @@ -62,21 +65,6 @@ class QuerySet: self.order_bys = order_bys or [] self.limit_sql_raw = limit_raw_sql - def __get__( - self, - instance: Optional[Union["QuerySet", "QuerysetProxy"]], - owner: Union[Type["Model"], Type["QuerysetProxy"]], - ) -> "QuerySet": - if issubclass(owner, ormar.Model): - if owner.Meta.requires_ref_update: - raise ModelError( - f"Model {owner.get_name()} has not updated " - f"ForwardRefs. \nBefore using the model you " - f"need to call update_forward_refs()." - ) - owner = cast(Type["Model"], owner) - return self.__class__(model_cls=owner) - return self.__class__() # pragma: no cover @property def model_meta(self) -> "ModelMeta": @@ -91,7 +79,7 @@ class QuerySet: return self.model_cls.Meta @property - def model(self) -> Type["Model"]: + def model(self) -> Type["T"]: """ Shortcut to model class set on QuerySet. @@ -148,8 +136,8 @@ class QuerySet: ) async def _prefetch_related_models( - self, models: List[Optional["Model"]], rows: List - ) -> List[Optional["Model"]]: + self, models: List[Optional["T"]], rows: List + ) -> List[Optional["T"]]: """ Performs prefetch query for selected models names. @@ -169,7 +157,7 @@ class QuerySet: ) return await query.prefetch_related(models=models, rows=rows) # type: ignore - def _process_query_result_rows(self, rows: List) -> List[Optional["Model"]]: + def _process_query_result_rows(self, rows: List) -> List[Optional["T"]]: """ Process database rows and initialize ormar Model from each of the rows. @@ -190,7 +178,7 @@ class QuerySet: ] if result_rows: return self.model.merge_instances_list(result_rows) # type: ignore - return result_rows + return cast(List[Optional["T"]], result_rows) def _resolve_filter_groups(self, groups: Any) -> List[FilterGroup]: """ @@ -221,7 +209,7 @@ class QuerySet: return filter_groups @staticmethod - def check_single_result_rows_count(rows: Sequence[Optional["Model"]]) -> None: + def check_single_result_rows_count(rows: Sequence[Optional["T"]]) -> None: """ Verifies if the result has one and only one row. @@ -286,7 +274,7 @@ class QuerySet: def filter( # noqa: A003 self, *args: Any, _exclude: bool = False, **kwargs: Any - ) -> "QuerySet": + ) -> "QuerySet[T]": """ Allows you to filter by any `Model` attribute/field as well as to fetch instances, with a filter across an FK relationship. @@ -337,7 +325,7 @@ class QuerySet: select_related=select_related, ) - def exclude(self, *args: Any, **kwargs: Any) -> "QuerySet": # noqa: A003 + def exclude(self, *args: Any, **kwargs: Any) -> "QuerySet[T]": # noqa: A003 """ Works exactly the same as filter and all modifiers (suffixes) are the same, but returns a *not* condition. @@ -358,7 +346,7 @@ class QuerySet: """ return self.filter(_exclude=True, *args, **kwargs) - def select_related(self, related: Union[List, str]) -> "QuerySet": + def select_related(self, related: Union[List, str]) -> "QuerySet[T]": """ Allows to prefetch related models during the same query. @@ -381,7 +369,7 @@ class QuerySet: related = sorted(list(set(list(self._select_related) + related))) return self.rebuild_self(select_related=related,) - def prefetch_related(self, related: Union[List, str]) -> "QuerySet": + def prefetch_related(self, related: Union[List, str]) -> "QuerySet[T]": """ Allows to prefetch related models during query - but opposite to `select_related` each subsequent model is fetched in a separate database query. @@ -407,7 +395,7 @@ class QuerySet: def fields( self, columns: Union[List, str, Set, Dict], _is_exclude: bool = False - ) -> "QuerySet": + ) -> "QuerySet[T]": """ With `fields()` you can select subset of model columns to limit the data load. @@ -461,7 +449,7 @@ class QuerySet: return self.rebuild_self(excludable=excludable,) - def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet": + def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet[T]": """ With `exclude_fields()` you can select subset of model columns that will be excluded to limit the data load. @@ -490,7 +478,7 @@ class QuerySet: """ return self.fields(columns=columns, _is_exclude=True) - def order_by(self, columns: Union[List, str]) -> "QuerySet": + def order_by(self, columns: Union[List, str]) -> "QuerySet[T]": """ With `order_by()` you can order the results from database based on your choice of fields. @@ -680,7 +668,7 @@ class QuerySet: ) return await self.database.execute(expr) - def paginate(self, page: int, page_size: int = 20) -> "QuerySet": + def paginate(self, page: int, page_size: int = 20) -> "QuerySet[T]": """ You can paginate the result which is a combination of offset and limit clauses. Limit is set to page size and offset is set to (page-1) * page_size. @@ -699,7 +687,7 @@ class QuerySet: query_offset = (page - 1) * page_size return self.rebuild_self(limit_count=limit_count, offset=query_offset,) - def limit(self, limit_count: int, limit_raw_sql: bool = None) -> "QuerySet": + def limit(self, limit_count: int, limit_raw_sql: bool = None) -> "QuerySet[T]": """ You can limit the results to desired number of parent models. @@ -716,7 +704,7 @@ class QuerySet: limit_raw_sql = self.limit_sql_raw if limit_raw_sql is None else limit_raw_sql return self.rebuild_self(limit_count=limit_count, limit_raw_sql=limit_raw_sql,) - def offset(self, offset: int, limit_raw_sql: bool = None) -> "QuerySet": + def offset(self, offset: int, limit_raw_sql: bool = None) -> "QuerySet[T]": """ You can also offset the results by desired number of main models. @@ -733,7 +721,7 @@ class QuerySet: limit_raw_sql = self.limit_sql_raw if limit_raw_sql is None else limit_raw_sql return self.rebuild_self(offset=offset, limit_raw_sql=limit_raw_sql,) - async def first(self, **kwargs: Any) -> "Model": + async def first(self, **kwargs: Any) -> "T": """ Gets the first row from the db ordered by primary key column ascending. @@ -764,7 +752,7 @@ class QuerySet: self.check_single_result_rows_count(processed_rows) return processed_rows[0] # type: ignore - async def get(self, **kwargs: Any) -> "Model": + async def get(self, **kwargs: Any) -> "T": """ Get's the first row from the db meeting the criteria set by kwargs. @@ -803,7 +791,7 @@ class QuerySet: self.check_single_result_rows_count(processed_rows) return processed_rows[0] # type: ignore - async def get_or_create(self, **kwargs: Any) -> "Model": + async def get_or_create(self, **kwargs: Any) -> "T": """ Combination of create and get methods. @@ -821,7 +809,7 @@ class QuerySet: except NoMatch: return await self.create(**kwargs) - async def update_or_create(self, **kwargs: Any) -> "Model": + async def update_or_create(self, **kwargs: Any) -> "T": """ Updates the model, or in case there is no match in database creates a new one. @@ -838,7 +826,7 @@ class QuerySet: model = await self.get(pk=kwargs[pk_name]) return await model.update(**kwargs) - async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003 + async def all(self, **kwargs: Any) -> List[Optional["T"]]: # noqa: A003 """ Returns all rows from a database for given model for set filter options. @@ -862,7 +850,7 @@ class QuerySet: return result_rows - async def create(self, **kwargs: Any) -> "Model": + async def create(self, **kwargs: Any) -> "T": """ Creates the model instance, saves it in a database and returns the updates model (with pk populated if not passed and autoincrement is set). @@ -905,7 +893,7 @@ class QuerySet: ) return instance - async def bulk_create(self, objects: List["Model"]) -> None: + async def bulk_create(self, objects: List["T"]) -> None: """ Performs a bulk update in one database session to speed up the process. @@ -931,7 +919,7 @@ class QuerySet: objt.set_save_status(True) async def bulk_update( # noqa: CCR001 - self, objects: List["Model"], columns: List[str] = None + self, objects: List["T"], columns: List[str] = None ) -> None: """ Performs bulk update in one database session to speed up the process. diff --git a/tests/test_order_by.py b/tests/test_order_by.py index ad688d2..7c2f10d 100644 --- a/tests/test_order_by.py +++ b/tests/test_order_by.py @@ -50,7 +50,7 @@ class AliasTest(ormar.Model): id: int = ormar.Integer(name="alias_id", primary_key=True) name: str = ormar.String(name="alias_name", max_length=100) - nested: str = ormar.ForeignKey(AliasNested, name="nested_alias") + nested = ormar.ForeignKey(AliasNested, name="nested_alias") class Toy(ormar.Model): diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 0000000..e06daf1 --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,108 @@ +from typing import Any, Optional, TYPE_CHECKING + +import databases +import pytest +import sqlalchemy + +import ormar +from ormar.relations.querysetproxy import QuerysetProxy +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL) +metadata = sqlalchemy.MetaData() + + +class BaseMeta(ormar.ModelMeta): + metadata = metadata + database = database + + +class Publisher(ormar.Model): + class Meta(BaseMeta): + tablename = "publishers" + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + + +class Author(ormar.Model): + class Meta(BaseMeta): + tablename = "authors" + order_by = ["-name"] + + id: int = ormar.Integer(primary_key=True) + name: str = ormar.String(max_length=100) + publishers = ormar.ManyToMany(Publisher) + + +class Book(ormar.Model): + class Meta(BaseMeta): + tablename = "books" + order_by = ["year", "-ranking"] + + id: int = ormar.Integer(primary_key=True) + author = ormar.ForeignKey(Author) + title: str = ormar.String(max_length=100) + year: int = ormar.Integer(nullable=True) + ranking: int = ormar.Integer(nullable=True) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def assert_type(book: Book): + print(book) + + +@pytest.mark.asyncio +async def test_types() -> None: + async with database: + query = Book.objects + publisher = await Publisher(name="Test publisher").save() + author = await Author.objects.create(name="Test Author") + await author.publishers.add(publisher) + author2 = await Author.objects.select_related("publishers").get() + publishers = author2.publishers + publisher2 = await Publisher.objects.select_related("authors").get() + authors = publisher2.authors + assert authors[0] == author + for author in authors: + if TYPE_CHECKING: # pragma: no cover + reveal_type(author) # iter of relation proxy + book = await Book.objects.create(title="Test", author=author) + book2 = await Book.objects.select_related("author").get() + books = await Book.objects.select_related("author").all() + author_books = await author.books.all() + assert book.author.name == "Test Author" + assert book2.author.name == "Test Author" + if TYPE_CHECKING: # pragma: no cover + reveal_type(publisher) # model method + reveal_type(publishers) # many to many + reveal_type(publishers[0]) # item in m2m list + # getting relation without __getattribute__ + to_model = Publisher.Meta.model_fields['authors'].to + reveal_type( + publisher2._extract_related_model_instead_of_field("authors") + ) # TODO: wrong + reveal_type(publisher.Meta.model_fields['authors'].to) # TODO: wrong + reveal_type(authors) # reverse many to many # TODO: wrong + reveal_type(book2) # queryset get + reveal_type(books) # queryset all + reveal_type(book) # queryset - create + reveal_type(query) # queryset itself + reveal_type(book.author) # fk + reveal_type( + author.books.queryset_proxy + ) # queryset in querysetproxy # TODO: wrong + reveal_type(author.books) # reverse fk # TODO: wrong + reveal_type(author) # another test for queryset get different model + reveal_type(book.author.name) # field on related model + reveal_type(author_books) # querysetproxy for fk # TODO: wrong + reveal_type(author_books[0]) # item i qs proxy for fk # TODO: wrong + assert_type(book)