some refactors to reduce complexity

This commit is contained in:
collerek
2020-08-27 18:56:21 +02:00
parent 279d3966b1
commit 22b42ff6fc
4 changed files with 64 additions and 29 deletions

BIN
.coverage

Binary file not shown.

View File

@ -33,15 +33,23 @@ def register_relation_on_build(table_name: str, field: ForeignKey) -> None:
relationship_manager.add_relation_type(field, table_name) relationship_manager.add_relation_type(field, table_name)
def reverse_field_not_already_registered(
child: Type["Model"], child_model_name: str, parent_model: Type["Model"]
) -> bool:
return (
child_model_name not in parent_model.__fields__
and child.get_name() not in parent_model.__fields__
)
def expand_reverse_relationships(model: Type["Model"]) -> None: def expand_reverse_relationships(model: Type["Model"]) -> None:
for model_field in model.Meta.model_fields.values(): for model_field in model.Meta.model_fields.values():
if issubclass(model_field, ForeignKeyField): if issubclass(model_field, ForeignKeyField):
child_model_name = model_field.related_name or model.get_name() + "s" child_model_name = model_field.related_name or model.get_name() + "s"
parent_model = model_field.to parent_model = model_field.to
child = model child = model
if ( if reverse_field_not_already_registered(
child_model_name not in parent_model.__fields__ child, child_model_name, parent_model
and child.get_name() not in parent_model.__fields__
): ):
register_reverse_model_fields(parent_model, child, child_model_name) register_reverse_model_fields(parent_model, child, child_model_name)
@ -54,6 +62,16 @@ def register_reverse_model_fields(
) )
def check_pk_column_validity(
field_name: str, field: BaseField, pkname: str
) -> Optional[str]:
if pkname is not None:
raise ModelDefinitionError("Only one primary key column is allowed.")
if field.pydantic_only:
raise ModelDefinitionError("Primary key column cannot be pydantic only")
return field_name
def sqlalchemy_columns_from_model_fields( def sqlalchemy_columns_from_model_fields(
model_fields: Dict, table_name: str model_fields: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column]]: ) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
@ -61,11 +79,7 @@ def sqlalchemy_columns_from_model_fields(
pkname = None pkname = None
for field_name, field in model_fields.items(): for field_name, field in model_fields.items():
if field.primary_key: if field.primary_key:
if pkname is not None: pkname = check_pk_column_validity(field_name, field, pkname)
raise ModelDefinitionError("Only one primary key column is allowed.")
if field.pydantic_only:
raise ModelDefinitionError("Primary key column cannot be pydantic only")
pkname = field_name
if not field.pydantic_only: if not field.pydantic_only:
columns.append(field.get_column(field_name)) columns.append(field.get_column(field_name))
if issubclass(field, ForeignKeyField): if issubclass(field, ForeignKeyField):
@ -73,6 +87,7 @@ def sqlalchemy_columns_from_model_fields(
return pkname, columns return pkname, columns
def populate_default
def populate_pydantic_default_values(attrs: Dict) -> Dict: def populate_pydantic_default_values(attrs: Dict) -> Dict:
for field, type_ in attrs["__annotations__"].items(): for field, type_ in attrs["__annotations__"].items():
@ -92,7 +107,6 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
class Config(BaseConfig): class Config(BaseConfig):
orm_mode = True orm_mode = True
arbitrary_types_allowed = True arbitrary_types_allowed = True
# extra = Extra.allow
return Config return Config

View File

@ -24,13 +24,14 @@ class ModelTableProxy:
@classmethod @classmethod
def substitute_models_with_pks(cls, model_dict: dict) -> dict: def substitute_models_with_pks(cls, model_dict: dict) -> dict:
for field in cls._extract_related_names(): for field in cls._extract_related_names():
if field in model_dict and model_dict.get(field) is not None: field_value = model_dict.get(field, None)
if field_value is not None:
target_field = cls.Meta.model_fields[field] target_field = cls.Meta.model_fields[field]
target_pkname = target_field.to.Meta.pkname target_pkname = target_field.to.Meta.pkname
if isinstance(model_dict.get(field), ormar.Model): if isinstance(field_value, ormar.Model):
model_dict[field] = getattr(model_dict.get(field), target_pkname) model_dict[field] = getattr(field_value, target_pkname)
else: else:
model_dict[field] = model_dict.get(field).get(target_pkname) model_dict[field] = field_value.get(target_pkname)
return model_dict return model_dict
@classmethod @classmethod

View File

@ -37,13 +37,21 @@ class QueryClause:
def filter( # noqa: A003 def filter( # noqa: A003
self, **kwargs: Any self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
filter_clauses = self.filter_clauses
select_related = list(self._select_related)
if kwargs.get("pk"): if kwargs.get("pk"):
pk_name = self.model_cls.Meta.pkname pk_name = self.model_cls.Meta.pkname
kwargs[pk_name] = kwargs.pop("pk") kwargs[pk_name] = kwargs.pop("pk")
filter_clauses, select_related = self._populate_filter_clauses(**kwargs)
return filter_clauses, select_related
def _populate_filter_clauses(
self, **kwargs: Any
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
filter_clauses = self.filter_clauses
select_related = list(self._select_related)
for key, value in kwargs.items(): for key, value in kwargs.items():
table_prefix = "" table_prefix = ""
if "__" in key: if "__" in key:
@ -73,6 +81,20 @@ class QueryClause:
column = self.table.columns[key] column = self.table.columns[key]
table = self.table table = self.table
clause = self._process_column_clause_for_operator_and_value(
value, op, column, table, table_prefix
)
filter_clauses.append(clause)
return filter_clauses, select_related
def _process_column_clause_for_operator_and_value(
self,
value: Any,
op: str,
column: sqlalchemy.Column,
table: sqlalchemy.Table,
table_prefix: str,
) -> sqlalchemy.sql.expression.TextClause:
value, has_escaped_character = self._escape_characters_in_clause(op, value) value, has_escaped_character = self._escape_characters_in_clause(op, value)
if isinstance(value, ormar.Model): if isinstance(value, ormar.Model):
@ -87,9 +109,7 @@ class QueryClause:
table_prefix, table_prefix,
modifiers={"escape": "\\" if has_escaped_character else None}, modifiers={"escape": "\\" if has_escaped_character else None},
) )
filter_clauses.append(clause) return clause
return filter_clauses, select_related
def _determine_filter_target_table( def _determine_filter_target_table(
self, related_parts: List[str], select_related: List[str] self, related_parts: List[str], select_related: List[str]