From 22b42ff6fc1b8bc442d73b43d8ee2bfdbc3c0b93 Mon Sep 17 00:00:00 2001 From: collerek Date: Thu, 27 Aug 2020 18:56:21 +0200 Subject: [PATCH] some refactors to reduce complexity --- .coverage | Bin 53248 -> 53248 bytes ormar/models/metaclass.py | 32 ++++++++++++++++------- ormar/models/modelproxy.py | 9 ++++--- ormar/queryset/clause.py | 52 +++++++++++++++++++++++++------------ 4 files changed, 64 insertions(+), 29 deletions(-) diff --git a/.coverage b/.coverage index 350ca4de7553b2794dec37b4a604f3e941d04af5..49b8d33146f27bdd7effa92b008068da74826857 100644 GIT binary patch delta 172 zcmV;d08{^fpaX!Q1F$MD1u{4~G&8d=FV|2uvj7kI59$xz57Q6G54aDY4|fl54@?g) z4<`=^4(Sfg4z3QE4uKAN4r>lM4jc{l4dD&U4YRWm5O)oew~g=%?eE|IxA#5v+wZ-T zvW`0-9u5Qn2|fKt4do5c4Y#uq5PJ=ivyJcz_V;i9+xs5-?e|`j zu8un%91a8l2|5ni^1r<8yZ8Mw*Z$R=<6r-t@7t;C-vyHik0u)#3j_fPGz None: relationship_manager.add_relation_type(field, table_name) +def reverse_field_not_already_registered( + child: Type["Model"], child_model_name: str, parent_model: Type["Model"] +) -> bool: + return ( + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ + ) + + def expand_reverse_relationships(model: Type["Model"]) -> None: for model_field in model.Meta.model_fields.values(): if issubclass(model_field, ForeignKeyField): child_model_name = model_field.related_name or model.get_name() + "s" parent_model = model_field.to child = model - if ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + if reverse_field_not_already_registered( + child, child_model_name, parent_model ): register_reverse_model_fields(parent_model, child, child_model_name) @@ -54,6 +62,16 @@ def register_reverse_model_fields( ) +def check_pk_column_validity( + field_name: str, field: BaseField, pkname: str +) -> Optional[str]: + if pkname is not None: + raise ModelDefinitionError("Only one primary key column is allowed.") + if field.pydantic_only: + raise ModelDefinitionError("Primary key column cannot be pydantic only") + return field_name + + def sqlalchemy_columns_from_model_fields( model_fields: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column]]: @@ -61,11 +79,7 @@ def sqlalchemy_columns_from_model_fields( pkname = None for field_name, field in model_fields.items(): if field.primary_key: - if pkname is not None: - raise ModelDefinitionError("Only one primary key column is allowed.") - if field.pydantic_only: - raise ModelDefinitionError("Primary key column cannot be pydantic only") - pkname = field_name + pkname = check_pk_column_validity(field_name, field, pkname) if not field.pydantic_only: columns.append(field.get_column(field_name)) if issubclass(field, ForeignKeyField): @@ -73,6 +87,7 @@ def sqlalchemy_columns_from_model_fields( return pkname, columns +def populate_default def populate_pydantic_default_values(attrs: Dict) -> Dict: for field, type_ in attrs["__annotations__"].items(): @@ -92,7 +107,6 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]: class Config(BaseConfig): orm_mode = True arbitrary_types_allowed = True - # extra = Extra.allow return Config diff --git a/ormar/models/modelproxy.py b/ormar/models/modelproxy.py index 8a99de8..81e1ea6 100644 --- a/ormar/models/modelproxy.py +++ b/ormar/models/modelproxy.py @@ -24,13 +24,14 @@ class ModelTableProxy: @classmethod def substitute_models_with_pks(cls, model_dict: dict) -> dict: for field in cls._extract_related_names(): - if field in model_dict and model_dict.get(field) is not None: + field_value = model_dict.get(field, None) + if field_value is not None: target_field = cls.Meta.model_fields[field] target_pkname = target_field.to.Meta.pkname - if isinstance(model_dict.get(field), ormar.Model): - model_dict[field] = getattr(model_dict.get(field), target_pkname) + if isinstance(field_value, ormar.Model): + model_dict[field] = getattr(field_value, target_pkname) else: - model_dict[field] = model_dict.get(field).get(target_pkname) + model_dict[field] = field_value.get(target_pkname) return model_dict @classmethod diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index 01a181a..f7436ac 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -37,13 +37,21 @@ class QueryClause: def filter( # noqa: A003 self, **kwargs: Any ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: - filter_clauses = self.filter_clauses - select_related = list(self._select_related) if kwargs.get("pk"): pk_name = self.model_cls.Meta.pkname kwargs[pk_name] = kwargs.pop("pk") + filter_clauses, select_related = self._populate_filter_clauses(**kwargs) + + return filter_clauses, select_related + + def _populate_filter_clauses( + self, **kwargs: Any + ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: + filter_clauses = self.filter_clauses + select_related = list(self._select_related) + for key, value in kwargs.items(): table_prefix = "" if "__" in key: @@ -73,24 +81,36 @@ class QueryClause: column = self.table.columns[key] table = self.table - value, has_escaped_character = self._escape_characters_in_clause(op, value) - - if isinstance(value, ormar.Model): - value = value.pk - - op_attr = FILTER_OPERATORS[op] - clause = getattr(column, op_attr)(value) - clause = self._compile_clause( - clause, - column, - table, - table_prefix, - modifiers={"escape": "\\" if has_escaped_character else None}, + clause = self._process_column_clause_for_operator_and_value( + value, op, column, table, table_prefix ) filter_clauses.append(clause) - return filter_clauses, select_related + def _process_column_clause_for_operator_and_value( + self, + value: Any, + op: str, + column: sqlalchemy.Column, + table: sqlalchemy.Table, + table_prefix: str, + ) -> sqlalchemy.sql.expression.TextClause: + value, has_escaped_character = self._escape_characters_in_clause(op, value) + + if isinstance(value, ormar.Model): + value = value.pk + + op_attr = FILTER_OPERATORS[op] + clause = getattr(column, op_attr)(value) + clause = self._compile_clause( + clause, + column, + table, + table_prefix, + modifiers={"escape": "\\" if has_escaped_character else None}, + ) + return clause + def _determine_filter_target_table( self, related_parts: List[str], select_related: List[str] ) -> Tuple[List[str], str, "Model"]: