improve types -> make queryset generic

This commit is contained in:
collerek
2021-03-19 16:52:47 +01:00
parent 32695ffa1d
commit 929e979d37
10 changed files with 165 additions and 61 deletions

View File

@ -15,7 +15,7 @@ from ormar.exceptions import ModelDefinitionError, RelationshipInstanceError
from ormar.fields.base import BaseField
if TYPE_CHECKING: # pragma no cover
from ormar.models import Model, NewBaseModel
from ormar.models import Model, NewBaseModel, T
from ormar.fields import ManyToManyField
if sys.version_info < (3, 7):
@ -24,7 +24,7 @@ if TYPE_CHECKING: # pragma no cover
ToType = Union[Type["Model"], "ForwardRef"]
def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
def create_dummy_instance(fk: Type["T"], pk: Any = None) -> "T":
"""
Ormar never returns you a raw data.
So if you have a related field that has a value populated
@ -55,7 +55,7 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
def create_dummy_model(
base_model: Type["Model"],
base_model: Type["T"],
pk_field: Union[BaseField, "ForeignKeyField", "ManyToManyField"],
) -> Type["BaseModel"]:
"""
@ -83,7 +83,7 @@ def create_dummy_model(
def populate_fk_params_based_on_to_model(
to: Type["Model"], nullable: bool, onupdate: str = None, ondelete: str = None,
to: Type["T"], nullable: bool, onupdate: str = None, ondelete: str = None,
) -> Tuple[Any, List, Any]:
"""
Based on target to model to which relation leads to populates the type of the
@ -169,7 +169,7 @@ class ForeignKeyConstraint:
def ForeignKey( # noqa CFQ002
to: "ToType",
to: Type["T"],
*,
name: str = None,
unique: bool = False,
@ -179,7 +179,7 @@ def ForeignKey( # noqa CFQ002
onupdate: str = None,
ondelete: str = None,
**kwargs: Any,
) -> Any:
) -> "T":
"""
Despite a name it's a function that returns constructed ForeignKeyField.
This function is actually used in model declaration (as ormar.ForeignKey(ToModel)).

View File

@ -65,7 +65,7 @@ def ManyToMany(
unique: bool = False,
virtual: bool = False,
**kwargs: Any,
) -> Any:
):
"""
Despite a name it's a function that returns constructed ManyToManyField.
This function is actually used in model declaration

View File

@ -6,7 +6,7 @@ ass well as vast number of helper functions for pydantic, sqlalchemy and relatio
from ormar.models.newbasemodel import NewBaseModel # noqa I100
from ormar.models.model_row import ModelRow # noqa I100
from ormar.models.model import Model # noqa I100
from ormar.models.model import Model, T # noqa I100
from ormar.models.excludable import ExcludableItems # noqa I100
__all__ = ["NewBaseModel", "Model", "ModelRow", "ExcludableItems"]
__all__ = ["NewBaseModel", "Model", "ModelRow", "ExcludableItems", "T"]

View File

@ -117,7 +117,7 @@ def register_reverse_model_fields(model_field: "ForeignKeyField") -> None:
register_through_shortcut_fields(model_field=model_field)
adjust_through_many_to_many_model(model_field=model_field)
else:
model_field.to.Meta.model_fields[related_name] = ForeignKey(
model_field.to.Meta.model_fields[related_name] = ForeignKey( # type: ignore
model_field.owner,
real_name=related_name,
virtual=True,

View File

@ -26,14 +26,14 @@ def adjust_through_many_to_many_model(model_field: "ManyToManyField") -> None:
"""
parent_name = model_field.default_target_field_name()
child_name = model_field.default_source_field_name()
model_field.through.Meta.model_fields[parent_name] = ormar.ForeignKey(
model_fields = model_field.through.Meta.model_fields
model_fields[parent_name] = ormar.ForeignKey( # type: ignore
model_field.to,
real_name=parent_name,
ondelete="CASCADE",
owner=model_field.through,
)
model_field.through.Meta.model_fields[child_name] = ormar.ForeignKey(
model_fields[child_name] = ormar.ForeignKey( # type: ignore
model_field.owner,
real_name=child_name,
ondelete="CASCADE",

View File

@ -18,6 +18,7 @@ from sqlalchemy.sql.schema import ColumnCollectionConstraint
import ormar # noqa I100
from ormar import ModelDefinitionError # noqa I100
from ormar.exceptions import ModelError
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.fields.many_to_many import ManyToManyField
@ -44,6 +45,7 @@ from ormar.signals import Signal, SignalEmitter
if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.models import T
CONFIG_KEY = "Config"
PARSED_FIELDS_KEY = "__parsed_fields__"
@ -545,6 +547,15 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
field_name=field_name, model=new_model
)
new_model.Meta.alias_manager = alias_manager
new_model.objects = QuerySet(new_model)
return new_model
@property
def objects(cls: Type["T"]) -> "QuerySet[T]": # type: ignore
if cls.Meta.requires_ref_update:
raise ModelError(
f"Model {cls.get_name()} has not updated "
f"ForwardRefs. \nBefore using the model you "
f"need to call update_forward_refs()."
)
return QuerySet(model_cls=cls)

View File

@ -15,8 +15,6 @@ from ormar.models import NewBaseModel # noqa I100
from ormar.models.metaclass import ModelMeta
from ormar.models.model_row import ModelRow
if TYPE_CHECKING: # pragma nocover
from ormar import QuerySet
T = TypeVar("T", bound="Model")
@ -25,7 +23,6 @@ class Model(ModelRow):
__abstract__ = False
if TYPE_CHECKING: # pragma nocover
Meta: ModelMeta
objects: "QuerySet"
def __repr__(self) -> str: # pragma nocover
_repr = {k: getattr(self, k) for k, v in self.Meta.model_fields.items()}

View File

@ -1,13 +1,13 @@
from typing import (
Any,
Dict,
List,
Generic, List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
Type,
Union,
TypeVar, Union,
cast,
)
@ -26,19 +26,22 @@ from ormar.queryset.query import Query
if TYPE_CHECKING: # pragma no cover
from ormar import Model
from ormar.models import T
from ormar.models.metaclass import ModelMeta
from ormar.relations.querysetproxy import QuerysetProxy
from ormar.models.excludable import ExcludableItems
else:
T = TypeVar("T", bound="Model")
class QuerySet:
class QuerySet(Generic[T]):
"""
Main class to perform database queries, exposed on each model as objects attribute.
"""
def __init__( # noqa CFQ002
self,
model_cls: Optional[Type["Model"]] = None,
model_cls: Optional[Type["T"]] = None,
filter_clauses: List = None,
exclude_clauses: List = None,
select_related: List = None,
@ -62,21 +65,6 @@ class QuerySet:
self.order_bys = order_bys or []
self.limit_sql_raw = limit_raw_sql
def __get__(
self,
instance: Optional[Union["QuerySet", "QuerysetProxy"]],
owner: Union[Type["Model"], Type["QuerysetProxy"]],
) -> "QuerySet":
if issubclass(owner, ormar.Model):
if owner.Meta.requires_ref_update:
raise ModelError(
f"Model {owner.get_name()} has not updated "
f"ForwardRefs. \nBefore using the model you "
f"need to call update_forward_refs()."
)
owner = cast(Type["Model"], owner)
return self.__class__(model_cls=owner)
return self.__class__() # pragma: no cover
@property
def model_meta(self) -> "ModelMeta":
@ -91,7 +79,7 @@ class QuerySet:
return self.model_cls.Meta
@property
def model(self) -> Type["Model"]:
def model(self) -> Type["T"]:
"""
Shortcut to model class set on QuerySet.
@ -148,8 +136,8 @@ class QuerySet:
)
async def _prefetch_related_models(
self, models: List[Optional["Model"]], rows: List
) -> List[Optional["Model"]]:
self, models: List[Optional["T"]], rows: List
) -> List[Optional["T"]]:
"""
Performs prefetch query for selected models names.
@ -169,7 +157,7 @@ class QuerySet:
)
return await query.prefetch_related(models=models, rows=rows) # type: ignore
def _process_query_result_rows(self, rows: List) -> List[Optional["Model"]]:
def _process_query_result_rows(self, rows: List) -> List[Optional["T"]]:
"""
Process database rows and initialize ormar Model from each of the rows.
@ -190,7 +178,7 @@ class QuerySet:
]
if result_rows:
return self.model.merge_instances_list(result_rows) # type: ignore
return result_rows
return cast(List[Optional["T"]], result_rows)
def _resolve_filter_groups(self, groups: Any) -> List[FilterGroup]:
"""
@ -221,7 +209,7 @@ class QuerySet:
return filter_groups
@staticmethod
def check_single_result_rows_count(rows: Sequence[Optional["Model"]]) -> None:
def check_single_result_rows_count(rows: Sequence[Optional["T"]]) -> None:
"""
Verifies if the result has one and only one row.
@ -286,7 +274,7 @@ class QuerySet:
def filter( # noqa: A003
self, *args: Any, _exclude: bool = False, **kwargs: Any
) -> "QuerySet":
) -> "QuerySet[T]":
"""
Allows you to filter by any `Model` attribute/field
as well as to fetch instances, with a filter across an FK relationship.
@ -337,7 +325,7 @@ class QuerySet:
select_related=select_related,
)
def exclude(self, *args: Any, **kwargs: Any) -> "QuerySet": # noqa: A003
def exclude(self, *args: Any, **kwargs: Any) -> "QuerySet[T]": # noqa: A003
"""
Works exactly the same as filter and all modifiers (suffixes) are the same,
but returns a *not* condition.
@ -358,7 +346,7 @@ class QuerySet:
"""
return self.filter(_exclude=True, *args, **kwargs)
def select_related(self, related: Union[List, str]) -> "QuerySet":
def select_related(self, related: Union[List, str]) -> "QuerySet[T]":
"""
Allows to prefetch related models during the same query.
@ -381,7 +369,7 @@ class QuerySet:
related = sorted(list(set(list(self._select_related) + related)))
return self.rebuild_self(select_related=related,)
def prefetch_related(self, related: Union[List, str]) -> "QuerySet":
def prefetch_related(self, related: Union[List, str]) -> "QuerySet[T]":
"""
Allows to prefetch related models during query - but opposite to
`select_related` each subsequent model is fetched in a separate database query.
@ -407,7 +395,7 @@ class QuerySet:
def fields(
self, columns: Union[List, str, Set, Dict], _is_exclude: bool = False
) -> "QuerySet":
) -> "QuerySet[T]":
"""
With `fields()` you can select subset of model columns to limit the data load.
@ -461,7 +449,7 @@ class QuerySet:
return self.rebuild_self(excludable=excludable,)
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet":
def exclude_fields(self, columns: Union[List, str, Set, Dict]) -> "QuerySet[T]":
"""
With `exclude_fields()` you can select subset of model columns that will
be excluded to limit the data load.
@ -490,7 +478,7 @@ class QuerySet:
"""
return self.fields(columns=columns, _is_exclude=True)
def order_by(self, columns: Union[List, str]) -> "QuerySet":
def order_by(self, columns: Union[List, str]) -> "QuerySet[T]":
"""
With `order_by()` you can order the results from database based on your
choice of fields.
@ -680,7 +668,7 @@ class QuerySet:
)
return await self.database.execute(expr)
def paginate(self, page: int, page_size: int = 20) -> "QuerySet":
def paginate(self, page: int, page_size: int = 20) -> "QuerySet[T]":
"""
You can paginate the result which is a combination of offset and limit clauses.
Limit is set to page size and offset is set to (page-1) * page_size.
@ -699,7 +687,7 @@ class QuerySet:
query_offset = (page - 1) * page_size
return self.rebuild_self(limit_count=limit_count, offset=query_offset,)
def limit(self, limit_count: int, limit_raw_sql: bool = None) -> "QuerySet":
def limit(self, limit_count: int, limit_raw_sql: bool = None) -> "QuerySet[T]":
"""
You can limit the results to desired number of parent models.
@ -716,7 +704,7 @@ class QuerySet:
limit_raw_sql = self.limit_sql_raw if limit_raw_sql is None else limit_raw_sql
return self.rebuild_self(limit_count=limit_count, limit_raw_sql=limit_raw_sql,)
def offset(self, offset: int, limit_raw_sql: bool = None) -> "QuerySet":
def offset(self, offset: int, limit_raw_sql: bool = None) -> "QuerySet[T]":
"""
You can also offset the results by desired number of main models.
@ -733,7 +721,7 @@ class QuerySet:
limit_raw_sql = self.limit_sql_raw if limit_raw_sql is None else limit_raw_sql
return self.rebuild_self(offset=offset, limit_raw_sql=limit_raw_sql,)
async def first(self, **kwargs: Any) -> "Model":
async def first(self, **kwargs: Any) -> "T":
"""
Gets the first row from the db ordered by primary key column ascending.
@ -764,7 +752,7 @@ class QuerySet:
self.check_single_result_rows_count(processed_rows)
return processed_rows[0] # type: ignore
async def get(self, **kwargs: Any) -> "Model":
async def get(self, **kwargs: Any) -> "T":
"""
Get's the first row from the db meeting the criteria set by kwargs.
@ -803,7 +791,7 @@ class QuerySet:
self.check_single_result_rows_count(processed_rows)
return processed_rows[0] # type: ignore
async def get_or_create(self, **kwargs: Any) -> "Model":
async def get_or_create(self, **kwargs: Any) -> "T":
"""
Combination of create and get methods.
@ -821,7 +809,7 @@ class QuerySet:
except NoMatch:
return await self.create(**kwargs)
async def update_or_create(self, **kwargs: Any) -> "Model":
async def update_or_create(self, **kwargs: Any) -> "T":
"""
Updates the model, or in case there is no match in database creates a new one.
@ -838,7 +826,7 @@ class QuerySet:
model = await self.get(pk=kwargs[pk_name])
return await model.update(**kwargs)
async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003
async def all(self, **kwargs: Any) -> List[Optional["T"]]: # noqa: A003
"""
Returns all rows from a database for given model for set filter options.
@ -862,7 +850,7 @@ class QuerySet:
return result_rows
async def create(self, **kwargs: Any) -> "Model":
async def create(self, **kwargs: Any) -> "T":
"""
Creates the model instance, saves it in a database and returns the updates model
(with pk populated if not passed and autoincrement is set).
@ -905,7 +893,7 @@ class QuerySet:
)
return instance
async def bulk_create(self, objects: List["Model"]) -> None:
async def bulk_create(self, objects: List["T"]) -> None:
"""
Performs a bulk update in one database session to speed up the process.
@ -931,7 +919,7 @@ class QuerySet:
objt.set_save_status(True)
async def bulk_update( # noqa: CCR001
self, objects: List["Model"], columns: List[str] = None
self, objects: List["T"], columns: List[str] = None
) -> None:
"""
Performs bulk update in one database session to speed up the process.

View File

@ -50,7 +50,7 @@ class AliasTest(ormar.Model):
id: int = ormar.Integer(name="alias_id", primary_key=True)
name: str = ormar.String(name="alias_name", max_length=100)
nested: str = ormar.ForeignKey(AliasNested, name="nested_alias")
nested = ormar.ForeignKey(AliasNested, name="nested_alias")
class Toy(ormar.Model):

108
tests/test_types.py Normal file
View File

@ -0,0 +1,108 @@
from typing import Any, Optional, TYPE_CHECKING
import databases
import pytest
import sqlalchemy
import ormar
from ormar.relations.querysetproxy import QuerysetProxy
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL)
metadata = sqlalchemy.MetaData()
class BaseMeta(ormar.ModelMeta):
metadata = metadata
database = database
class Publisher(ormar.Model):
class Meta(BaseMeta):
tablename = "publishers"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class Author(ormar.Model):
class Meta(BaseMeta):
tablename = "authors"
order_by = ["-name"]
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
publishers = ormar.ManyToMany(Publisher)
class Book(ormar.Model):
class Meta(BaseMeta):
tablename = "books"
order_by = ["year", "-ranking"]
id: int = ormar.Integer(primary_key=True)
author = ormar.ForeignKey(Author)
title: str = ormar.String(max_length=100)
year: int = ormar.Integer(nullable=True)
ranking: int = ormar.Integer(nullable=True)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
def assert_type(book: Book):
print(book)
@pytest.mark.asyncio
async def test_types() -> None:
async with database:
query = Book.objects
publisher = await Publisher(name="Test publisher").save()
author = await Author.objects.create(name="Test Author")
await author.publishers.add(publisher)
author2 = await Author.objects.select_related("publishers").get()
publishers = author2.publishers
publisher2 = await Publisher.objects.select_related("authors").get()
authors = publisher2.authors
assert authors[0] == author
for author in authors:
if TYPE_CHECKING: # pragma: no cover
reveal_type(author) # iter of relation proxy
book = await Book.objects.create(title="Test", author=author)
book2 = await Book.objects.select_related("author").get()
books = await Book.objects.select_related("author").all()
author_books = await author.books.all()
assert book.author.name == "Test Author"
assert book2.author.name == "Test Author"
if TYPE_CHECKING: # pragma: no cover
reveal_type(publisher) # model method
reveal_type(publishers) # many to many
reveal_type(publishers[0]) # item in m2m list
# getting relation without __getattribute__
to_model = Publisher.Meta.model_fields['authors'].to
reveal_type(
publisher2._extract_related_model_instead_of_field("authors")
) # TODO: wrong
reveal_type(publisher.Meta.model_fields['authors'].to) # TODO: wrong
reveal_type(authors) # reverse many to many # TODO: wrong
reveal_type(book2) # queryset get
reveal_type(books) # queryset all
reveal_type(book) # queryset - create
reveal_type(query) # queryset itself
reveal_type(book.author) # fk
reveal_type(
author.books.queryset_proxy
) # queryset in querysetproxy # TODO: wrong
reveal_type(author.books) # reverse fk # TODO: wrong
reveal_type(author) # another test for queryset get different model
reveal_type(book.author.name) # field on related model
reveal_type(author_books) # querysetproxy for fk # TODO: wrong
reveal_type(author_books[0]) # item i qs proxy for fk # TODO: wrong
assert_type(book)