diff --git a/.coverage b/.coverage index 3090f87..b014efa 100644 Binary files a/.coverage and b/.coverage differ diff --git a/.gitignore b/.gitignore index f58fa36..54536f4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ p38venv .idea .pytest_cache +.mypy_cache *.pyc *.log test.db diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..d9b0283 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,5 @@ +[mypy] +python_version = 3.8 + +[mypy-sqlalchemy.*] +ignore_missing_imports = True \ No newline at end of file diff --git a/ormar/fields/base.py b/ormar/fields/base.py index 1be4b63..522d708 100644 --- a/ormar/fields/base.py +++ b/ormar/fields/base.py @@ -1,7 +1,8 @@ -from typing import Any, List, Optional, TYPE_CHECKING, Union +from typing import Any, List, Optional, TYPE_CHECKING, Union, Type import sqlalchemy from pydantic import Field, typing +from pydantic.fields import FieldInfo from ormar import ModelDefinitionError # noqa I101 @@ -15,6 +16,7 @@ class BaseField: column_type: sqlalchemy.Column constraints: List = [] + name: str primary_key: bool autoincrement: bool @@ -24,12 +26,14 @@ class BaseField: pydantic_only: bool virtual: bool = False choices: typing.Sequence + to: Type["Model"] + through: Type["Model"] default: Any server_default: Any @classmethod - def default_value(cls) -> Optional[Field]: + def default_value(cls) -> Optional[FieldInfo]: if cls.is_auto_primary_key(): return Field(default=None) if cls.has_default(): diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 1fac75e..176050f 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional, TYPE_CHECKING, Type, Union +from typing import Any, List, Optional, TYPE_CHECKING, Type, Union, Generator import sqlalchemy @@ -7,7 +7,7 @@ from ormar.exceptions import RelationshipInstanceError from ormar.fields.base import BaseField if TYPE_CHECKING: # pragma no cover - from ormar.models import Model + from ormar.models import Model, NewBaseModel def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": @@ -23,16 +23,16 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": def ForeignKey( # noqa CFQ002 - to: Type["Model"], - *, - name: str = None, - unique: bool = False, - nullable: bool = True, - related_name: str = None, - virtual: bool = False, - onupdate: str = None, - ondelete: str = None, -) -> Type[object]: + to: Type["Model"], + *, + name: str = None, + unique: bool = False, + nullable: bool = True, + related_name: str = None, + virtual: bool = False, + onupdate: str = None, + ondelete: str = None, +) -> Type["ForeignKeyField"]: fk_string = to.Meta.tablename + "." + to.Meta.pkname to_field = to.__fields__[to.Meta.pkname] namespace = dict( @@ -65,7 +65,7 @@ class ForeignKeyField(BaseField): virtual: bool @classmethod - def __get_validators__(cls) -> Callable: + def __get_validators__(cls) -> Generator: yield cls.validate @classmethod @@ -74,13 +74,13 @@ class ForeignKeyField(BaseField): @classmethod def _extract_model_from_sequence( - cls, value: List, child: "Model", to_register: bool - ) -> Union["Model", List["Model"]]: - return [cls.expand_relationship(val, child, to_register) for val in value] + cls, value: List, child: "Model", to_register: bool + ) -> List["Model"]: + return [cls.expand_relationship(val, child, to_register) for val in value] # type: ignore @classmethod def _register_existing_model( - cls, value: "Model", child: "Model", to_register: bool + cls, value: "Model", child: "Model", to_register: bool ) -> "Model": if to_register: cls.register_relation(value, child) @@ -88,7 +88,7 @@ class ForeignKeyField(BaseField): @classmethod def _construct_model_from_dict( - cls, value: dict, child: "Model", to_register: bool + cls, value: dict, child: "Model", to_register: bool ) -> "Model": if len(value.keys()) == 1 and list(value.keys())[0] == cls.to.Meta.pkname: value["__pk_only__"] = True @@ -99,7 +99,7 @@ class ForeignKeyField(BaseField): @classmethod def _construct_model_from_pk( - cls, value: Any, child: "Model", to_register: bool + cls, value: Any, child: "Model", to_register: bool ) -> "Model": if not isinstance(value, cls.to.pk_type()): raise RelationshipInstanceError( @@ -120,7 +120,7 @@ class ForeignKeyField(BaseField): @classmethod def expand_relationship( - cls, value: Any, child: "Model", to_register: bool = True + cls, value: Any, child: Union["Model", "NewBaseModel"], to_register: bool = True ) -> Optional[Union["Model", List["Model"]]]: if value is None: return None if not cls.virtual else [] @@ -131,7 +131,7 @@ class ForeignKeyField(BaseField): "list": cls._extract_model_from_sequence, } - model = constructors.get( + model = constructors.get( # type: ignore value.__class__.__name__, cls._construct_model_from_pk )(value, child, to_register) return model diff --git a/ormar/fields/many_to_many.py b/ormar/fields/many_to_many.py index 89d5f62..0fed9e8 100644 --- a/ormar/fields/many_to_many.py +++ b/ormar/fields/many_to_many.py @@ -15,7 +15,7 @@ def ManyToMany( unique: bool = False, related_name: str = None, virtual: bool = False, -) -> Type[object]: +) -> Type["ManyToManyField"]: to_field = to.__fields__[to.Meta.pkname] namespace = dict( to=to, diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index 388f229..c6627a7 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -18,10 +18,10 @@ def is_field_nullable( class ModelFieldFactory: - _bases = BaseField - _type = None + _bases: Any = BaseField + _type: Any = None - def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: + def __new__(cls, *args: Any, **kwargs: Any) -> Type[BaseField]: # type: ignore cls.validate(**kwargs) default = kwargs.pop("default", None) @@ -58,7 +58,7 @@ class String(ModelFieldFactory): _bases = (pydantic.ConstrainedStr, BaseField) _type = str - def __new__( # noqa CFQ002 + def __new__( # type: ignore # noqa CFQ002 cls, *, allow_blank: bool = False, @@ -68,7 +68,7 @@ class String(ModelFieldFactory): curtail_length: int = None, regex: str = None, **kwargs: Any - ) -> Type[str]: + ) -> Type[BaseField]: # type: ignore kwargs = { **kwargs, **{ @@ -96,14 +96,14 @@ class Integer(ModelFieldFactory): _bases = (pydantic.ConstrainedInt, BaseField) _type = int - def __new__( + def __new__( # type: ignore cls, *, minimum: int = None, maximum: int = None, multiple_of: int = None, **kwargs: Any - ) -> Type[int]: + ) -> Type[BaseField]: autoincrement = kwargs.pop("autoincrement", None) autoincrement = ( autoincrement @@ -131,9 +131,9 @@ class Text(ModelFieldFactory): _bases = (pydantic.ConstrainedStr, BaseField) _type = str - def __new__( + def __new__( # type: ignore cls, *, allow_blank: bool = False, strip_whitespace: bool = False, **kwargs: Any - ) -> Type[str]: + ) -> Type[BaseField]: kwargs = { **kwargs, **{ @@ -153,14 +153,14 @@ class Float(ModelFieldFactory): _bases = (pydantic.ConstrainedFloat, BaseField) _type = float - def __new__( + def __new__( # type: ignore cls, *, minimum: float = None, maximum: float = None, multiple_of: int = None, **kwargs: Any - ) -> Type[int]: + ) -> Type[BaseField]: kwargs = { **kwargs, **{ @@ -236,7 +236,7 @@ class Decimal(ModelFieldFactory): _bases = (pydantic.ConstrainedDecimal, BaseField) _type = decimal.Decimal - def __new__( # noqa CFQ002 + def __new__( # type: ignore # noqa CFQ002 cls, *, minimum: float = None, @@ -247,7 +247,7 @@ class Decimal(ModelFieldFactory): max_digits: int = None, decimal_places: int = None, **kwargs: Any - ) -> Type[decimal.Decimal]: + ) -> Type[BaseField]: kwargs = { **kwargs, **{ diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 16d3b84..4c45bc2 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -28,15 +28,19 @@ class ModelMeta: database: databases.Database columns: List[sqlalchemy.Column] pkname: str - model_fields: Dict[str, Union[BaseField, ForeignKey]] + model_fields: Dict[ + str, Union[Type[BaseField], Type[ForeignKeyField], Type[ManyToManyField]] + ] alias_manager: AliasManager -def register_relation_on_build(table_name: str, field: ForeignKey) -> None: +def register_relation_on_build(table_name: str, field: Type[ForeignKeyField]) -> None: alias_manager.add_relation_type(field.to.Meta.tablename, table_name) -def register_many_to_many_relation_on_build(table_name: str, field: ManyToMany) -> None: +def register_many_to_many_relation_on_build( + table_name: str, field: Type[ManyToManyField] +) -> None: alias_manager.add_relation_type(field.through.Meta.tablename, table_name) alias_manager.add_relation_type( field.through.Meta.tablename, field.to.Meta.tablename @@ -106,7 +110,7 @@ def create_pydantic_field( ) -> None: model_field.through.__fields__[field_name] = ModelField( name=field_name, - type_=Optional[model], + type_=model, model_config=model.__config__, required=False, class_validators={}, @@ -130,7 +134,7 @@ def create_and_append_m2m_fk( def check_pk_column_validity( - field_name: str, field: BaseField, pkname: str + field_name: str, field: BaseField, pkname: Optional[str] ) -> Optional[str]: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") @@ -218,6 +222,7 @@ def populate_meta_tablename_columns_and_pk( ) -> Type["Model"]: tablename = name.lower() + "s" new_model.Meta.tablename = new_model.Meta.tablename or tablename + pkname: Optional[str] if hasattr(new_model.Meta, "columns"): columns = new_model.Meta.table.columns @@ -226,12 +231,13 @@ def populate_meta_tablename_columns_and_pk( pkname, columns = sqlalchemy_columns_from_model_fields( new_model.Meta.model_fields, new_model.Meta.tablename ) + + if pkname is None: + raise ModelDefinitionError("Table has to have a primary key.") + new_model.Meta.columns = columns new_model.Meta.pkname = pkname - if not new_model.Meta.pkname: - raise ModelDefinitionError("Table has to have a primary key.") - return new_model @@ -253,8 +259,8 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]: return Config -def check_if_field_has_choices(field: BaseField) -> bool: - return hasattr(field, "choices") and field.choices +def check_if_field_has_choices(field: Type[BaseField]) -> bool: + return hasattr(field, "choices") and bool(field.choices) def model_initialized_and_has_model_fields(model: Type["Model"]) -> bool: @@ -287,7 +293,7 @@ def populate_choices_validators( # noqa CCR001 class ModelMetaclass(pydantic.main.ModelMetaclass): - def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: + def __new__(mcs: "ModelMetaclass", name: str, bases: Any, attrs: dict) -> "ModelMetaclass": # type: ignore attrs["Config"] = get_pydantic_base_orm_config() attrs["__name__"] = name attrs = extract_annotations_and_default_vals(attrs, bases) @@ -306,7 +312,7 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): field_name = new_model.Meta.pkname field = Integer(name=field_name, primary_key=True) attrs["__annotations__"][field_name] = field - populate_default_pydantic_field_value(field, field_name, attrs) + populate_default_pydantic_field_value(field, field_name, attrs) # type: ignore new_model = super().__new__( # type: ignore mcs, name, bases, attrs diff --git a/ormar/models/model.py b/ormar/models/model.py index ae2a590..1222d1e 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -1,5 +1,5 @@ import itertools -from typing import Any, List, Tuple, Union +from typing import Any, List, Dict, Optional import sqlalchemy from databases.backends.postgres import Record @@ -9,8 +9,8 @@ from ormar.fields.many_to_many import ManyToManyField from ormar.models import NewBaseModel # noqa I100 -def group_related_list(list_: List) -> dict: - test_dict = dict() +def group_related_list(list_: List) -> Dict: + test_dict: Dict[str, Any] = dict() grouped = itertools.groupby(list_, key=lambda x: x.split("__")[0]) for key, group in grouped: group_list = list(group) @@ -29,14 +29,14 @@ class Model(NewBaseModel): @classmethod def from_row( - cls, - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - related_models: Any = None, - previous_table: str = None, - ) -> Union["Model", Tuple["Model", dict]]: + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + related_models: Any = None, + previous_table: str = None, + ) -> Optional["Model"]: - item = {} + item: Dict[str, Any] = {} select_related = select_related or [] related_models = related_models or [] if select_related: @@ -44,17 +44,20 @@ class Model(NewBaseModel): # breakpoint() if ( - previous_table - and previous_table in cls.Meta.model_fields - and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField) + previous_table + and previous_table in cls.Meta.model_fields + and issubclass(cls.Meta.model_fields[previous_table], ManyToManyField) ): previous_table = cls.Meta.model_fields[ previous_table ].through.Meta.tablename - table_prefix = cls.Meta.alias_manager.resolve_relation_join( - previous_table, cls.Meta.table.name - ) + if previous_table: + table_prefix = cls.Meta.alias_manager.resolve_relation_join( + previous_table, cls.Meta.table.name + ) + else: + table_prefix = '' previous_table = cls.Meta.table.name item = cls.populate_nested_models_from_row( @@ -67,11 +70,11 @@ class Model(NewBaseModel): @classmethod def populate_nested_models_from_row( - cls, - item: dict, - row: sqlalchemy.engine.ResultProxy, - related_models: Any, - previous_table: sqlalchemy.Table, + cls, + item: dict, + row: sqlalchemy.engine.ResultProxy, + related_models: Any, + previous_table: sqlalchemy.Table, ) -> dict: for related in related_models: if isinstance(related_models, dict) and related_models[related]: @@ -90,7 +93,7 @@ class Model(NewBaseModel): @classmethod def extract_prefixed_table_columns( # noqa CCR001 - cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str + cls, item: dict, row: sqlalchemy.engine.result.ResultProxy, table_prefix: str ) -> dict: for column in cls.Meta.table.columns: if column.name not in item: @@ -106,7 +109,7 @@ class Model(NewBaseModel): async def save(self) -> "Model": self_fields = self._extract_model_db_fields() - if not self.pk and self.Meta.model_fields.get(self.Meta.pkname).autoincrement: + if not self.pk and self.Meta.model_fields[self.Meta.pkname].autoincrement: self_fields.pop(self.Meta.pkname, None) self_fields = self.objects._populate_default_values(self_fields) expr = self.Meta.table.insert() @@ -138,5 +141,7 @@ class Model(NewBaseModel): async def load(self) -> "Model": expr = self.Meta.table.select().where(self.pk_column == self.pk) row = await self.Meta.database.fetch_one(expr) + if not row: # pragma nocover + raise ValueError('Instance was deleted from database and cannot be refreshed') self.from_dict(dict(row)) return self diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 470ac57..cf4d784 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -1,5 +1,5 @@ import inspect -from typing import List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union +from typing import List, Optional, Set, TYPE_CHECKING, Type, TypeVar, Union, Dict import ormar from ormar.exceptions import RelationshipInstanceError @@ -9,6 +9,7 @@ from ormar.models.metaclass import ModelMeta if TYPE_CHECKING: # pragma no cover from ormar import Model + from ormar.models import NewBaseModel Field = TypeVar("Field", bound=BaseField) @@ -17,10 +18,10 @@ class ModelTableProxy: if TYPE_CHECKING: # pragma no cover Meta: ModelMeta - def dict(): # noqa A003 + def dict(self): # noqa A003 raise NotImplementedError # pragma no cover - def _extract_own_model_fields(self) -> dict: + def _extract_own_model_fields(self) -> Dict: related_names = self.extract_related_names() self_fields = {k: v for k, v in self.dict().items() if k not in related_names} return self_fields @@ -34,7 +35,7 @@ class ModelTableProxy: return self_fields @classmethod - def substitute_models_with_pks(cls, model_dict: dict) -> dict: + def substitute_models_with_pks(cls, model_dict: Dict) -> Dict: for field in cls.extract_related_names(): field_value = model_dict.get(field, None) if field_value is not None: @@ -80,7 +81,7 @@ class ModelTableProxy: related_names.add(name) return related_names - def _extract_model_db_fields(self) -> dict: + def _extract_model_db_fields(self) -> Dict: self_fields = self._extract_own_model_fields() self_fields = { k: v for k, v in self_fields.items() if k in self.Meta.table.columns @@ -92,7 +93,9 @@ class ModelTableProxy: return self_fields @staticmethod - def resolve_relation_name(item: "Model", related: "Model") -> Optional[str]: + def resolve_relation_name( + item: Union["NewBaseModel", Type["NewBaseModel"]], related: Union["NewBaseModel", Type["NewBaseModel"]] + ) -> str: for name, field in item.Meta.model_fields.items(): if issubclass(field, ForeignKeyField): # fastapi is creating clones of response model @@ -100,11 +103,14 @@ class ModelTableProxy: # so we need to compare Meta too as this one is copied as is if field.to == related.__class__ or field.to.Meta == related.Meta: return name + raise ValueError( + f"No relation between {item.get_name()} and {related.get_name()}" + ) # pragma nocover @staticmethod def resolve_relation_field( item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]] - ) -> Type[Field]: + ) -> Union[Type[BaseField], Type[ForeignKeyField]]: name = ModelTableProxy.resolve_relation_name(item, related) to_field = item.Meta.model_fields.get(name) if not to_field: # pragma no cover @@ -116,7 +122,7 @@ class ModelTableProxy: @classmethod def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: - merged_rows = [] + merged_rows: List["Model"] = [] for index, model in enumerate(result_rows): if index > 0 and model.pk == merged_rows[-1].pk: merged_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) diff --git a/ormar/models/newbasemodel.py b/ormar/models/newbasemodel.py index 08294e8..af00c13 100644 --- a/ormar/models/newbasemodel.py +++ b/ormar/models/newbasemodel.py @@ -3,13 +3,13 @@ import uuid from typing import ( AbstractSet, Any, + Callable, Dict, List, Mapping, Optional, TYPE_CHECKING, Type, - TypeVar, Union, ) @@ -39,7 +39,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass __slots__ = ("_orm_id", "_orm_saved", "_orm") if TYPE_CHECKING: # pragma no cover - __model_fields__: Dict[str, TypeVar[BaseField]] + __model_fields__: Dict[str, Type[BaseField]] __table__: sqlalchemy.Table __fields__: Dict[str, pydantic.fields.ModelField] __pydantic_model__: Type[BaseModel] @@ -84,7 +84,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass for k, v in kwargs.items() } - values, fields_set, validation_error = pydantic.validate_model(self, kwargs) + values, fields_set, validation_error = pydantic.validate_model(self, kwargs) # type: ignore if validation_error and not pk_only: raise validation_error @@ -134,13 +134,14 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass ) -> Optional[Union["Model", List["Model"]]]: if item in self._orm: return self._orm.get(item) + return None - def __eq__(self, other: "Model") -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, NewBaseModel): return self.__same__(other) return super().__eq__(other) # pragma no cover - def __same__(self, other: "Model") -> bool: + def __same__(self, other: "NewBaseModel") -> bool: return ( self._orm_id == other._orm_id or self.dict() == other.dict() @@ -205,19 +206,19 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass dict_instance[field] = None return dict_instance - def from_dict(self, value_dict: Dict) -> "Model": + def from_dict(self, value_dict: Dict) -> "NewBaseModel": for key, value in value_dict.items(): setattr(self, key, value) return self - def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]: + def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, Dict]: if not self._is_conversion_to_json_needed(column_name): return value condition = ( isinstance(value, str) if op == "loads" else not isinstance(value, str) ) - operand = json.loads if op == "loads" else json.dumps + operand: Callable[[Any], Any] = json.loads if op == "loads" else json.dumps if condition: try: @@ -227,4 +228,4 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass return value def _is_conversion_to_json_needed(self, column_name: str) -> bool: - return self.Meta.model_fields.get(column_name).__type__ == pydantic.Json + return self.Meta.model_fields[column_name].__type__ == pydantic.Json diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index eae086c..6b9d5a4 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type import sqlalchemy from sqlalchemy import text @@ -118,7 +118,7 @@ class QueryClause: def _determine_filter_target_table( self, related_parts: List[str], select_related: List[str] - ) -> Tuple[List[str], str, "Model"]: + ) -> Tuple[List[str], str, Type["Model"]]: table_prefix = "" model_cls = self.model_cls @@ -168,9 +168,7 @@ class QueryClause: return clause @staticmethod - def _escape_characters_in_clause( - op: str, value: Union[str, "Model"] - ) -> Tuple[str, bool]: + def _escape_characters_in_clause(op: str, value: Any) -> Tuple[Any, bool]: has_escaped_character = False if op not in [ diff --git a/ormar/queryset/join.py b/ormar/queryset/join.py index fa6ed74..4b70636 100644 --- a/ormar/queryset/join.py +++ b/ormar/queryset/join.py @@ -22,8 +22,8 @@ class SqlJoin: self, used_aliases: List, select_from: sqlalchemy.sql.select, - order_bys: List, - columns: List, + order_bys: List[sqlalchemy.sql.elements.TextClause], + columns: List[sqlalchemy.Column], ) -> None: self.used_aliases = used_aliases self.select_from = select_from diff --git a/ormar/queryset/limit_query.py b/ormar/queryset/limit_query.py index 2de7950..af59326 100644 --- a/ormar/queryset/limit_query.py +++ b/ormar/queryset/limit_query.py @@ -1,8 +1,10 @@ +from typing import Optional + import sqlalchemy class LimitQuery: - def __init__(self, limit_count: int) -> None: + def __init__(self, limit_count: Optional[int]) -> None: self.limit_count = limit_count def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: diff --git a/ormar/queryset/offset_query.py b/ormar/queryset/offset_query.py index bca365b..ce87296 100644 --- a/ormar/queryset/offset_query.py +++ b/ormar/queryset/offset_query.py @@ -1,8 +1,10 @@ +from typing import Optional + import sqlalchemy class OffsetQuery: - def __init__(self, query_offset: int) -> None: + def __init__(self, query_offset: Optional[int]) -> None: self.query_offset = query_offset def apply(self, expr: sqlalchemy.sql.select) -> sqlalchemy.sql.select: diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index b07612a..1900323 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -1,4 +1,4 @@ -from typing import List, TYPE_CHECKING, Tuple, Type +from typing import List, TYPE_CHECKING, Tuple, Type, Optional import sqlalchemy from sqlalchemy import text @@ -18,8 +18,8 @@ class Query: filter_clauses: List, exclude_clauses: List, select_related: List, - limit_count: int, - offset: int, + limit_count: Optional[int], + offset: Optional[int], ) -> None: self.query_offset = offset self.limit_count = limit_count @@ -30,11 +30,11 @@ class Query: self.model_cls = model_cls self.table = self.model_cls.Meta.table - self.used_aliases = [] + self.used_aliases: List[str] = [] - self.select_from = None - self.columns = None - self.order_bys = None + self.select_from: List[str] = [] + self.columns = [sqlalchemy.Column] + self.order_bys: List[sqlalchemy.sql.elements.TextClause] = [] @property def prefixed_pk_name(self) -> str: @@ -89,7 +89,7 @@ class Query: return expr def _reset_query_parameters(self) -> None: - self.select_from = None - self.columns = None - self.order_bys = None + self.select_from = [] + self.columns = [] + self.order_bys = [] self.used_aliases = [] diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index c3c1b7d..b6c17f5 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -1,4 +1,4 @@ -from typing import Any, List, Mapping, TYPE_CHECKING, Tuple, Type, Union +from typing import Any, List, Mapping, TYPE_CHECKING, Type, Union, Optional import databases import sqlalchemy @@ -13,17 +13,18 @@ from ormar.queryset.query import Query if TYPE_CHECKING: # pragma no cover from ormar import Model + from ormar.models.metaclass import ModelMeta class QuerySet: def __init__( # noqa CFQ002 - self, - model_cls: Type["Model"] = None, - filter_clauses: List = None, - exclude_clauses: List = None, - select_related: List = None, - limit_count: int = None, - offset: int = None, + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + exclude_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses @@ -36,47 +37,60 @@ class QuerySet: def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": return self.__class__(model_cls=owner) - def _process_query_result_rows(self, rows: List[Mapping]) -> List["Model"]: + @property + def model_meta(self) -> "ModelMeta": + if not self.model_cls: # pragma nocover + raise ValueError("Model class of QuerySet is not initialized") + return self.model_cls.Meta + + @property + def model(self) -> Type["Model"]: + if not self.model_cls: # pragma nocover + raise ValueError("Model class of QuerySet is not initialized") + return self.model_cls + + def _process_query_result_rows(self, rows: List) -> List[Optional["Model"]]: result_rows = [ - self.model_cls.from_row(row, select_related=self._select_related) + self.model.from_row(row, select_related=self._select_related) for row in rows ] - rows = self.model_cls.merge_instances_list(result_rows) - return rows + if result_rows: + return self.model.merge_instances_list(result_rows) # type: ignore + return result_rows def _populate_default_values(self, new_kwargs: dict) -> dict: - for field_name, field in self.model_cls.Meta.model_fields.items(): + for field_name, field in self.model_meta.model_fields.items(): if field_name not in new_kwargs and field.has_default(): new_kwargs[field_name] = field.get_default() return new_kwargs def _remove_pk_from_kwargs(self, new_kwargs: dict) -> dict: - pkname = self.model_cls.Meta.pkname - pk = self.model_cls.Meta.model_fields[pkname] + pkname = self.model_meta.pkname + pk = self.model_meta.model_fields[pkname] if new_kwargs.get(pkname, ormar.Undefined) is None and ( - pk.nullable or pk.autoincrement + pk.nullable or pk.autoincrement ): del new_kwargs[pkname] return new_kwargs @staticmethod - def check_single_result_rows_count(rows: List["Model"]) -> None: - if not rows: + def check_single_result_rows_count(rows: List[Optional["Model"]]) -> None: + if not rows or rows[0] is None: raise NoMatch() if len(rows) > 1: raise MultipleMatches() @property def database(self) -> databases.Database: - return self.model_cls.Meta.database + return self.model_meta.database @property def table(self) -> sqlalchemy.Table: - return self.model_cls.Meta.table + return self.model_meta.table def build_select_expression(self) -> sqlalchemy.sql.select: qry = Query( - model_cls=self.model_cls, + model_cls=self.model, select_related=self._select_related, filter_clauses=self.filter_clauses, exclude_clauses=self.exclude_clauses, @@ -89,7 +103,7 @@ class QuerySet: def filter(self, _exclude: bool = False, **kwargs: Any) -> "QuerySet": # noqa: A003 qryclause = QueryClause( - model_cls=self.model_cls, + model_cls=self.model, select_related=self._select_related, filter_clauses=self.filter_clauses, ) @@ -102,7 +116,7 @@ class QuerySet: filter_clauses = filter_clauses return self.__class__( - model_cls=self.model_cls, + model_cls=self.model, filter_clauses=filter_clauses, exclude_clauses=exclude_clauses, select_related=select_related, @@ -113,13 +127,13 @@ class QuerySet: def exclude(self, **kwargs: Any) -> "QuerySet": # noqa: A003 return self.filter(_exclude=True, **kwargs) - def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet": - if not isinstance(related, (list, tuple)): + def select_related(self, related: Union[List, str]) -> "QuerySet": + if not isinstance(related, list): related = [related] related = list(set(list(self._select_related) + related)) return self.__class__( - model_cls=self.model_cls, + model_cls=self.model, filter_clauses=self.filter_clauses, exclude_clauses=self.exclude_clauses, select_related=related, @@ -138,7 +152,7 @@ class QuerySet: return await self.database.fetch_val(expr) async def update(self, each: bool = False, **kwargs: Any) -> int: - self_fields = self.model_cls.extract_db_own_fields() + self_fields = self.model.extract_db_own_fields() updates = {k: v for k, v in kwargs.items() if k in self_fields} if not each and not self.filter_clauses: raise QueryDefinitionError( @@ -165,7 +179,7 @@ class QuerySet: def limit(self, limit_count: int) -> "QuerySet": return self.__class__( - model_cls=self.model_cls, + model_cls=self.model, filter_clauses=self.filter_clauses, exclude_clauses=self.exclude_clauses, select_related=self._select_related, @@ -175,7 +189,7 @@ class QuerySet: def offset(self, offset: int) -> "QuerySet": return self.__class__( - model_cls=self.model_cls, + model_cls=self.model, filter_clauses=self.filter_clauses, exclude_clauses=self.exclude_clauses, select_related=self._select_related, @@ -189,7 +203,7 @@ class QuerySet: rows = await self.limit(1).all() self.check_single_result_rows_count(rows) - return rows[0] + return rows[0] # type: ignore async def get(self, **kwargs: Any) -> "Model": if kwargs: @@ -200,9 +214,9 @@ class QuerySet: expr = expr.limit(2) rows = await self.database.fetch_all(expr) - rows = self._process_query_result_rows(rows) - self.check_single_result_rows_count(rows) - return rows[0] + processed_rows = self._process_query_result_rows(rows) + self.check_single_result_rows_count(processed_rows) + return processed_rows[0] # type: ignore async def get_or_create(self, **kwargs: Any) -> "Model": try: @@ -211,7 +225,7 @@ class QuerySet: return await self.create(**kwargs) async def update_or_create(self, **kwargs: Any) -> "Model": - pk_name = self.model_cls.Meta.pkname + pk_name = self.model_meta.pkname if "pk" in kwargs: kwargs[pk_name] = kwargs.pop("pk") if pk_name not in kwargs or kwargs.get(pk_name) is None: @@ -219,7 +233,7 @@ class QuerySet: model = await self.get(pk=kwargs[pk_name]) return await model.update(**kwargs) - async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 + async def all(self, **kwargs: Any) -> List[Optional["Model"]]: # noqa: A003 if kwargs: return await self.filter(**kwargs).all() @@ -233,20 +247,20 @@ class QuerySet: new_kwargs = dict(**kwargs) new_kwargs = self._remove_pk_from_kwargs(new_kwargs) - new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs) + new_kwargs = self.model.substitute_models_with_pks(new_kwargs) new_kwargs = self._populate_default_values(new_kwargs) expr = self.table.insert() expr = expr.values(**new_kwargs) # Execute the insert, and return a new model instance. - instance = self.model_cls(**kwargs) + instance = self.model(**kwargs) pk = await self.database.execute(expr) - pk_name = self.model_cls.Meta.pkname + pk_name = self.model_meta.pkname if pk_name not in kwargs and pk_name in new_kwargs: - instance.pk = new_kwargs[self.model_cls.Meta.pkname] - if pk and isinstance(pk, self.model_cls.pk_type()): - setattr(instance, self.model_cls.Meta.pkname, pk) + instance.pk = new_kwargs[self.model_meta.pkname] + if pk and isinstance(pk, self.model.pk_type()): + setattr(instance, self.model_meta.pkname, pk) return instance async def bulk_create(self, objects: List["Model"]) -> None: @@ -254,7 +268,7 @@ class QuerySet: for objt in objects: new_kwargs = objt.dict() new_kwargs = self._remove_pk_from_kwargs(new_kwargs) - new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs) + new_kwargs = self.model.substitute_models_with_pks(new_kwargs) new_kwargs = self._populate_default_values(new_kwargs) ready_objects.append(new_kwargs) @@ -262,13 +276,15 @@ class QuerySet: await self.database.execute_many(expr, ready_objects) async def bulk_update( - self, objects: List["Model"], columns: List[str] = None + self, objects: List["Model"], columns: List[str] = None ) -> None: ready_objects = [] - pk_name = self.model_cls.Meta.pkname + pk_name = self.model_meta.pkname if not columns: - columns = self.model_cls.extract_db_own_fields().union( - self.model_cls.extract_related_names() + columns = list( + self.model.extract_db_own_fields().union( + self.model.extract_related_names() + ) ) if pk_name not in columns: @@ -279,13 +295,13 @@ class QuerySet: if pk_name not in new_kwargs or new_kwargs.get(pk_name) is None: raise QueryDefinitionError( "You cannot update unsaved objects. " - f"{self.model_cls.__name__} has to have {pk_name} filled." + f"{self.model.__name__} has to have {pk_name} filled." ) - new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs) + new_kwargs = self.model.substitute_models_with_pks(new_kwargs) new_kwargs = {"new_" + k: v for k, v in new_kwargs.items() if k in columns} ready_objects.append(new_kwargs) - pk_column = self.model_cls.Meta.table.c.get(pk_name) + pk_column = self.model_meta.table.c.get(pk_name) expr = self.table.update().where(pk_column == bindparam("new_" + pk_name)) expr = expr.values( **{k: bindparam("new_" + k) for k in columns if k != pk_name} diff --git a/ormar/relations/alias_manager.py b/ormar/relations/alias_manager.py index 64f7261..1b0dc67 100644 --- a/ormar/relations/alias_manager.py +++ b/ormar/relations/alias_manager.py @@ -1,7 +1,7 @@ import string import uuid from random import choices -from typing import List +from typing import List, Dict import sqlalchemy from sqlalchemy import text @@ -14,7 +14,7 @@ def get_table_alias() -> str: class AliasManager: def __init__(self) -> None: - self._aliases = dict() + self._aliases: Dict[str, str] = dict() @staticmethod def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: diff --git a/ormar/relations/querysetproxy.py b/ormar/relations/querysetproxy.py index 01066f0..8efa5d9 100644 --- a/ormar/relations/querysetproxy.py +++ b/ormar/relations/querysetproxy.py @@ -13,8 +13,8 @@ class QuerysetProxy: relation: "Relation" def __init__(self, relation: "Relation") -> None: - self.relation = relation - self.queryset = None + self.relation: Relation = relation + self.queryset: "QuerySet" def _assign_child_to_parent(self, child: "Model") -> None: owner = self.relation._owner diff --git a/ormar/relations/relation.py b/ormar/relations/relation.py index e0747fb..c8ce872 100644 --- a/ormar/relations/relation.py +++ b/ormar/relations/relation.py @@ -9,6 +9,7 @@ from ormar.relations.relation_proxy import RelationProxy if TYPE_CHECKING: # pragma no cover from ormar import Model from ormar.relations import RelationsManager + from ormar.models import NewBaseModel class RelationType(Enum): @@ -19,24 +20,26 @@ class RelationType(Enum): class Relation: def __init__( - self, - manager: "RelationsManager", - type_: RelationType, - to: Type["Model"], - through: Type["Model"] = None, + self, + manager: "RelationsManager", + type_: RelationType, + to: Type["Model"], + through: Type["Model"] = None, ) -> None: self.manager = manager - self._owner = manager.owner - self._type = type_ - self.to = to - self.through = through - self.related_models = ( + self._owner: "Model" = manager.owner + self._type: RelationType = type_ + self.to: Type["Model"] = to + self.through: Optional[Type["Model"]] = through + self.related_models: Optional[Union[RelationProxy, "Model"]] = ( RelationProxy(relation=self) if type_ in (RelationType.REVERSE, RelationType.MULTIPLE) else None ) def _find_existing(self, child: "Model") -> Optional[int]: + if not isinstance(self.related_models, RelationProxy): # pragma nocover + raise ValueError("Cannot find existing models in parent relation type") for ind, relation_child in enumerate(self.related_models[:]): try: if relation_child == child: @@ -52,7 +55,7 @@ class Relation: self._owner.__dict__[relation_name] = child else: if self._find_existing(child) is None: - self.related_models.append(child) + self.related_models.append(child) # type: ignore rel = self._owner.__dict__.get(relation_name, []) rel = rel or [] if not isinstance(rel, list): @@ -60,19 +63,19 @@ class Relation: rel.append(child) self._owner.__dict__[relation_name] = rel - def remove(self, child: "Model") -> None: + def remove(self, child: Union["NewBaseModel", Type["NewBaseModel"]]) -> None: relation_name = self._owner.resolve_relation_name(self._owner, child) if self._type == RelationType.PRIMARY: - if self.related_models.__same__(child): + if self.related_models == child: self.related_models = None del self._owner.__dict__[relation_name] else: position = self._find_existing(child) if position is not None: - self.related_models.pop(position) + self.related_models.pop(position) # type: ignore del self._owner.__dict__[relation_name][position] - def get(self) -> Union[List["Model"], "Model"]: + def get(self) -> Optional[Union[List["Model"], "Model"]]: return self.related_models def __repr__(self) -> str: # pragma no cover diff --git a/ormar/relations/relation_manager.py b/ormar/relations/relation_manager.py index eb0d7fe..ce9dd69 100644 --- a/ormar/relations/relation_manager.py +++ b/ormar/relations/relation_manager.py @@ -1,6 +1,7 @@ -from typing import List, Optional, TYPE_CHECKING, Type, Union +from typing import List, Optional, TYPE_CHECKING, Type, Union, Dict from weakref import proxy +from ormar.fields import BaseField from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.many_to_many import ManyToManyField from ormar.relations.relation import Relation, RelationType @@ -11,25 +12,28 @@ from ormar.relations.utils import ( if TYPE_CHECKING: # pragma no cover from ormar import Model + from ormar.models import NewBaseModel class RelationsManager: def __init__( - self, related_fields: List[Type[ForeignKeyField]] = None, owner: "Model" = None + self, + related_fields: List[Type[ForeignKeyField]] = None, + owner: "NewBaseModel" = None, ) -> None: self.owner = proxy(owner) self._related_fields = related_fields or [] self._related_names = [field.name for field in self._related_fields] - self._relations = dict() + self._relations: Dict[str, Relation] = dict() for field in self._related_fields: self._add_relation(field) - def _get_relation_type(self, field: Type[ForeignKeyField]) -> RelationType: + def _get_relation_type(self, field: Type[BaseField]) -> RelationType: if issubclass(field, ManyToManyField): return RelationType.MULTIPLE return RelationType.PRIMARY if not field.virtual else RelationType.REVERSE - def _add_relation(self, field: Type[ForeignKeyField]) -> None: + def _add_relation(self, field: Type[BaseField]) -> None: self._relations[field.name] = Relation( manager=self, type_=self._get_relation_type(field), @@ -44,15 +48,17 @@ class RelationsManager: relation = self._relations.get(name, None) if relation is not None: return relation.get() + return None # pragma nocover def _get(self, name: str) -> Optional[Relation]: relation = self._relations.get(name, None) if relation is not None: return relation + return None @staticmethod def add(parent: "Model", child: "Model", child_name: str, virtual: bool) -> None: - to_field = child.resolve_relation_field(child, parent) + to_field: Type[BaseField] = child.resolve_relation_field(child, parent) (parent, child, child_name, to_name,) = get_relations_sides_and_names( to_field, parent, child, child_name, virtual @@ -61,18 +67,22 @@ class RelationsManager: parent_relation = parent._orm._get(child_name) if not parent_relation: parent_relation = register_missing_relation(parent, child, child_name) - parent_relation.add(child) - child._orm._get(to_name).add(parent) + parent_relation.add(child) # type: ignore - def remove(self, name: str, child: "Model") -> None: + child_relation = child._orm._get(to_name) + if child_relation: + child_relation.add(parent) + + def remove(self, name: str, child: Union["NewBaseModel", Type["NewBaseModel"]]) -> None: relation = self._get(name) - relation.remove(child) + if relation: + relation.remove(child) @staticmethod - def remove_parent(item: "Model", name: Union[str, "Model"]) -> None: + def remove_parent(item: Union["NewBaseModel", Type["NewBaseModel"]], name: "Model") -> None: related_model = name - name = item.resolve_relation_name(item, related_model) - if name in item._orm: + rel_name = item.resolve_relation_name(item, related_model) + if rel_name in item._orm: relation_name = item.resolve_relation_name(related_model, item) - item._orm.remove(name, related_model) + item._orm.remove(rel_name, related_model) related_model._orm.remove(relation_name, item) diff --git a/ormar/relations/relation_proxy.py b/ormar/relations/relation_proxy.py index 8b0ff34..3863679 100644 --- a/ormar/relations/relation_proxy.py +++ b/ormar/relations/relation_proxy.py @@ -13,22 +13,30 @@ if TYPE_CHECKING: # pragma no cover class RelationProxy(list): def __init__(self, relation: "Relation") -> None: super(RelationProxy, self).__init__() - self.relation = relation - self._owner = self.relation.manager.owner + self.relation: Relation = relation + self._owner: "Model" = self.relation.manager.owner self.queryset_proxy = QuerysetProxy(relation=self.relation) def __getattribute__(self, item: str) -> Any: if item in ["count", "clear"]: - if not self.queryset_proxy.queryset: - self.queryset_proxy.queryset = self._set_queryset() + self._initialize_queryset() return getattr(self.queryset_proxy, item) return super().__getattribute__(item) def __getattr__(self, item: str) -> Any: - if not self.queryset_proxy.queryset: - self.queryset_proxy.queryset = self._set_queryset() + self._initialize_queryset() return getattr(self.queryset_proxy, item) + def _initialize_queryset(self) -> None: + if not self._check_if_queryset_is_initialized(): + self.queryset_proxy.queryset = self._set_queryset() + + def _check_if_queryset_is_initialized(self) -> bool: + return ( + hasattr(self.queryset_proxy, "queryset") + and self.queryset_proxy.queryset is not None + ) + def _set_queryset(self) -> "QuerySet": owner_table = self.relation._owner.Meta.tablename pkname = self.relation._owner.Meta.pkname @@ -45,10 +53,15 @@ class RelationProxy(list): ) return queryset - async def remove(self, item: "Model") -> None: + async def remove(self, item: "Model") -> None: # type: ignore super().remove(item) rel_name = item.resolve_relation_name(item, self._owner) - item._orm._get(rel_name).remove(self._owner) + relation = item._orm._get(rel_name) + if relation is None: # pragma nocover + raise ValueError( + f"{self._owner.get_name()} does not have relation {rel_name}" + ) + relation.remove(self._owner) if self.relation._type == ormar.RelationType.MULTIPLE: await self.queryset_proxy.delete_through_instance(item) diff --git a/ormar/relations/utils.py b/ormar/relations/utils.py index a182fec..ad30a28 100644 --- a/ormar/relations/utils.py +++ b/ormar/relations/utils.py @@ -1,8 +1,8 @@ -from typing import TYPE_CHECKING, Tuple, Type +from typing import TYPE_CHECKING, Tuple, Type, Optional from weakref import proxy import ormar -from ormar.fields.foreign_key import ForeignKeyField +from ormar.fields import BaseField from ormar.fields.many_to_many import ManyToManyField from ormar.relations import Relation @@ -12,7 +12,7 @@ if TYPE_CHECKING: # pragma no cover def register_missing_relation( parent: "Model", child: "Model", child_name: str -) -> Relation: +) -> Optional[Relation]: ormar.models.expand_reverse_relationships(child.__class__) name = parent.resolve_relation_name(parent, child) field = parent.Meta.model_fields[name] @@ -22,7 +22,7 @@ def register_missing_relation( def get_relations_sides_and_names( - to_field: Type[ForeignKeyField], + to_field: Type[BaseField], parent: "Model", child: "Model", child_name: str,