some refactors to reduce complexity
This commit is contained in:
@ -33,15 +33,23 @@ def register_relation_on_build(table_name: str, field: ForeignKey) -> None:
|
|||||||
relationship_manager.add_relation_type(field, table_name)
|
relationship_manager.add_relation_type(field, table_name)
|
||||||
|
|
||||||
|
|
||||||
|
def reverse_field_not_already_registered(
|
||||||
|
child: Type["Model"], child_model_name: str, parent_model: Type["Model"]
|
||||||
|
) -> bool:
|
||||||
|
return (
|
||||||
|
child_model_name not in parent_model.__fields__
|
||||||
|
and child.get_name() not in parent_model.__fields__
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def expand_reverse_relationships(model: Type["Model"]) -> None:
|
def expand_reverse_relationships(model: Type["Model"]) -> None:
|
||||||
for model_field in model.Meta.model_fields.values():
|
for model_field in model.Meta.model_fields.values():
|
||||||
if issubclass(model_field, ForeignKeyField):
|
if issubclass(model_field, ForeignKeyField):
|
||||||
child_model_name = model_field.related_name or model.get_name() + "s"
|
child_model_name = model_field.related_name or model.get_name() + "s"
|
||||||
parent_model = model_field.to
|
parent_model = model_field.to
|
||||||
child = model
|
child = model
|
||||||
if (
|
if reverse_field_not_already_registered(
|
||||||
child_model_name not in parent_model.__fields__
|
child, child_model_name, parent_model
|
||||||
and child.get_name() not in parent_model.__fields__
|
|
||||||
):
|
):
|
||||||
register_reverse_model_fields(parent_model, child, child_model_name)
|
register_reverse_model_fields(parent_model, child, child_model_name)
|
||||||
|
|
||||||
@ -54,6 +62,16 @@ def register_reverse_model_fields(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_pk_column_validity(
|
||||||
|
field_name: str, field: BaseField, pkname: str
|
||||||
|
) -> Optional[str]:
|
||||||
|
if pkname is not None:
|
||||||
|
raise ModelDefinitionError("Only one primary key column is allowed.")
|
||||||
|
if field.pydantic_only:
|
||||||
|
raise ModelDefinitionError("Primary key column cannot be pydantic only")
|
||||||
|
return field_name
|
||||||
|
|
||||||
|
|
||||||
def sqlalchemy_columns_from_model_fields(
|
def sqlalchemy_columns_from_model_fields(
|
||||||
model_fields: Dict, table_name: str
|
model_fields: Dict, table_name: str
|
||||||
) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
|
) -> Tuple[Optional[str], List[sqlalchemy.Column]]:
|
||||||
@ -61,11 +79,7 @@ def sqlalchemy_columns_from_model_fields(
|
|||||||
pkname = None
|
pkname = None
|
||||||
for field_name, field in model_fields.items():
|
for field_name, field in model_fields.items():
|
||||||
if field.primary_key:
|
if field.primary_key:
|
||||||
if pkname is not None:
|
pkname = check_pk_column_validity(field_name, field, pkname)
|
||||||
raise ModelDefinitionError("Only one primary key column is allowed.")
|
|
||||||
if field.pydantic_only:
|
|
||||||
raise ModelDefinitionError("Primary key column cannot be pydantic only")
|
|
||||||
pkname = field_name
|
|
||||||
if not field.pydantic_only:
|
if not field.pydantic_only:
|
||||||
columns.append(field.get_column(field_name))
|
columns.append(field.get_column(field_name))
|
||||||
if issubclass(field, ForeignKeyField):
|
if issubclass(field, ForeignKeyField):
|
||||||
@ -73,6 +87,7 @@ def sqlalchemy_columns_from_model_fields(
|
|||||||
|
|
||||||
return pkname, columns
|
return pkname, columns
|
||||||
|
|
||||||
|
def populate_default
|
||||||
|
|
||||||
def populate_pydantic_default_values(attrs: Dict) -> Dict:
|
def populate_pydantic_default_values(attrs: Dict) -> Dict:
|
||||||
for field, type_ in attrs["__annotations__"].items():
|
for field, type_ in attrs["__annotations__"].items():
|
||||||
@ -92,7 +107,6 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
|
|||||||
class Config(BaseConfig):
|
class Config(BaseConfig):
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
# extra = Extra.allow
|
|
||||||
|
|
||||||
return Config
|
return Config
|
||||||
|
|
||||||
|
|||||||
@ -24,13 +24,14 @@ class ModelTableProxy:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def substitute_models_with_pks(cls, model_dict: dict) -> dict:
|
def substitute_models_with_pks(cls, model_dict: dict) -> dict:
|
||||||
for field in cls._extract_related_names():
|
for field in cls._extract_related_names():
|
||||||
if field in model_dict and model_dict.get(field) is not None:
|
field_value = model_dict.get(field, None)
|
||||||
|
if field_value is not None:
|
||||||
target_field = cls.Meta.model_fields[field]
|
target_field = cls.Meta.model_fields[field]
|
||||||
target_pkname = target_field.to.Meta.pkname
|
target_pkname = target_field.to.Meta.pkname
|
||||||
if isinstance(model_dict.get(field), ormar.Model):
|
if isinstance(field_value, ormar.Model):
|
||||||
model_dict[field] = getattr(model_dict.get(field), target_pkname)
|
model_dict[field] = getattr(field_value, target_pkname)
|
||||||
else:
|
else:
|
||||||
model_dict[field] = model_dict.get(field).get(target_pkname)
|
model_dict[field] = field_value.get(target_pkname)
|
||||||
return model_dict
|
return model_dict
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -37,13 +37,21 @@ class QueryClause:
|
|||||||
def filter( # noqa: A003
|
def filter( # noqa: A003
|
||||||
self, **kwargs: Any
|
self, **kwargs: Any
|
||||||
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
|
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
|
||||||
filter_clauses = self.filter_clauses
|
|
||||||
select_related = list(self._select_related)
|
|
||||||
|
|
||||||
if kwargs.get("pk"):
|
if kwargs.get("pk"):
|
||||||
pk_name = self.model_cls.Meta.pkname
|
pk_name = self.model_cls.Meta.pkname
|
||||||
kwargs[pk_name] = kwargs.pop("pk")
|
kwargs[pk_name] = kwargs.pop("pk")
|
||||||
|
|
||||||
|
filter_clauses, select_related = self._populate_filter_clauses(**kwargs)
|
||||||
|
|
||||||
|
return filter_clauses, select_related
|
||||||
|
|
||||||
|
def _populate_filter_clauses(
|
||||||
|
self, **kwargs: Any
|
||||||
|
) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]:
|
||||||
|
filter_clauses = self.filter_clauses
|
||||||
|
select_related = list(self._select_related)
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
table_prefix = ""
|
table_prefix = ""
|
||||||
if "__" in key:
|
if "__" in key:
|
||||||
@ -73,24 +81,36 @@ class QueryClause:
|
|||||||
column = self.table.columns[key]
|
column = self.table.columns[key]
|
||||||
table = self.table
|
table = self.table
|
||||||
|
|
||||||
value, has_escaped_character = self._escape_characters_in_clause(op, value)
|
clause = self._process_column_clause_for_operator_and_value(
|
||||||
|
value, op, column, table, table_prefix
|
||||||
if isinstance(value, ormar.Model):
|
|
||||||
value = value.pk
|
|
||||||
|
|
||||||
op_attr = FILTER_OPERATORS[op]
|
|
||||||
clause = getattr(column, op_attr)(value)
|
|
||||||
clause = self._compile_clause(
|
|
||||||
clause,
|
|
||||||
column,
|
|
||||||
table,
|
|
||||||
table_prefix,
|
|
||||||
modifiers={"escape": "\\" if has_escaped_character else None},
|
|
||||||
)
|
)
|
||||||
filter_clauses.append(clause)
|
filter_clauses.append(clause)
|
||||||
|
|
||||||
return filter_clauses, select_related
|
return filter_clauses, select_related
|
||||||
|
|
||||||
|
def _process_column_clause_for_operator_and_value(
|
||||||
|
self,
|
||||||
|
value: Any,
|
||||||
|
op: str,
|
||||||
|
column: sqlalchemy.Column,
|
||||||
|
table: sqlalchemy.Table,
|
||||||
|
table_prefix: str,
|
||||||
|
) -> sqlalchemy.sql.expression.TextClause:
|
||||||
|
value, has_escaped_character = self._escape_characters_in_clause(op, value)
|
||||||
|
|
||||||
|
if isinstance(value, ormar.Model):
|
||||||
|
value = value.pk
|
||||||
|
|
||||||
|
op_attr = FILTER_OPERATORS[op]
|
||||||
|
clause = getattr(column, op_attr)(value)
|
||||||
|
clause = self._compile_clause(
|
||||||
|
clause,
|
||||||
|
column,
|
||||||
|
table,
|
||||||
|
table_prefix,
|
||||||
|
modifiers={"escape": "\\" if has_escaped_character else None},
|
||||||
|
)
|
||||||
|
return clause
|
||||||
|
|
||||||
def _determine_filter_target_table(
|
def _determine_filter_target_table(
|
||||||
self, related_parts: List[str], select_related: List[str]
|
self, related_parts: List[str], select_related: List[str]
|
||||||
) -> Tuple[List[str], str, "Model"]:
|
) -> Tuple[List[str], str, "Model"]:
|
||||||
|
|||||||
Reference in New Issue
Block a user