fix some code smells

This commit is contained in:
collerek
2020-08-09 08:59:36 +02:00
parent fa00f7b011
commit 22c4a0619c
6 changed files with 127 additions and 115 deletions

BIN
.coverage

Binary file not shown.

View File

@ -15,9 +15,9 @@
The `async-orm` package is an async ORM for Python, with support for Postgres, The `async-orm` package is an async ORM for Python, with support for Postgres,
MySQL, and SQLite. ORM is built with: MySQL, and SQLite. ORM is built with:
* [`SQLAlchemy core`][sqlalchemy-core] for query building. * [`SQLAlchemy core`][sqlalchemy-core] for query building.
* [`databases`][databases] for cross-database async support. * [`databases`][databases] for cross-database async support.
* [`pydantic`][pydantic] for data validation. * [`pydantic`][pydantic] for data validation.
Because ORM is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide Because ORM is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide
database migrations. database migrations.
@ -26,7 +26,7 @@ The goal was to create a simple orm that can be used directly with [`fastapi`][f
Initial work was inspired by [`encode/orm`][encode/orm]. Initial work was inspired by [`encode/orm`][encode/orm].
The encode package was too simple (i.e. no ability to join two times to the same table) and used typesystem for data checks. The encode package was too simple (i.e. no ability to join two times to the same table) and used typesystem for data checks.
**aysn-orm is still under development: We recommend pinning any dependencies with `async-orm~=0.1`** **aysn-orm is still under development:** We recommend pinning any dependencies with `aorm~=0.0.1`
**Note**: Use `ipython` to try this from the console, since it supports `await`. **Note**: Use `ipython` to try this from the console, since it supports `await`.
@ -155,33 +155,33 @@ assert len(tracks) == 1
The following keyword arguments are supported on all field types. The following keyword arguments are supported on all field types.
* `primary_key` * `primary_key`
* `nullable` * `nullable`
* `default` * `default`
* `server_default` * `server_default`
* `index` * `index`
* `unique` * `unique`
All fields are required unless one of the following is set: All fields are required unless one of the following is set:
* `nullable` - Creates a nullable column. Sets the default to `None`. * `nullable` - Creates a nullable column. Sets the default to `None`.
* `default` - Set a default value for the field. * `default` - Set a default value for the field.
* `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`). * `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`).
* `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column. * `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column.
Autoincrement is set by default on int primary keys. Autoincrement is set by default on int primary keys.
Available Model Fields: Available Model Fields:
* `orm.String(length)` * `orm.String(length)`
* `orm.Text()` * `orm.Text()`
* `orm.Boolean()` * `orm.Boolean()`
* `orm.Integer()` * `orm.Integer()`
* `orm.Float()` * `orm.Float()`
* `orm.Date()` * `orm.Date()`
* `orm.Time()` * `orm.Time()`
* `orm.DateTime()` * `orm.DateTime()`
* `orm.JSON()` * `orm.JSON()`
* `orm.BigInteger()` * `orm.BigInteger()`
* `orm.Decimal(lenght, precision)` * `orm.Decimal(lenght, precision)`
[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/
[databases]: https://github.com/encode/databases [databases]: https://github.com/encode/databases

View File

@ -68,13 +68,15 @@ class QuerySet:
def table(self) -> sqlalchemy.Table: def table(self) -> sqlalchemy.Table:
return self.model_cls.__table__ return self.model_cls.__table__
def prefixed_columns(self, alias: str, table: sqlalchemy.Table) -> List[text]: @staticmethod
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
return [ return [
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}") text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
for column in table.columns for column in table.columns
] ]
def prefixed_table_name(self, alias: str, name: str) -> text: @staticmethod
def prefixed_table_name(alias: str, name: str) -> text:
return text(f"{name} {alias}_{name}") return text(f"{name} {alias}_{name}")
def on_clause( def on_clause(
@ -91,7 +93,7 @@ class QuerySet:
f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}' f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}'
) )
def build_join_parameters( def _build_join_parameters(
self, part: str, join_params: JoinParameters self, part: str, join_params: JoinParameters
) -> JoinParameters: ) -> JoinParameters:
model_cls = join_params.model_cls.__model_fields__[part].to model_cls = join_params.model_cls.__model_fields__[part].to
@ -137,12 +139,12 @@ class QuerySet:
return JoinParameters(prev_model, previous_alias, from_table, model_cls) return JoinParameters(prev_model, previous_alias, from_table, model_cls)
@staticmethod @staticmethod
def field_is_a_foreign_key_and_no_circular_reference( def _field_is_a_foreign_key_and_no_circular_reference(
field: BaseField, field_name: str, rel_part: str field: BaseField, field_name: str, rel_part: str
) -> bool: ) -> bool:
return isinstance(field, ForeignKey) and field_name not in rel_part return isinstance(field, ForeignKey) and field_name not in rel_part
def field_qualifies_to_deeper_search( def _field_qualifies_to_deeper_search(
self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str
) -> bool: ) -> bool:
prev_part_of_related = "__".join(rel_part.split("__")[:-1]) prev_part_of_related = "__".join(rel_part.split("__")[:-1])
@ -155,7 +157,7 @@ class QuerySet:
or (partial_match and not already_checked) or (partial_match and not already_checked)
) or not nested ) or not nested
def extract_auto_required_relations( def _extract_auto_required_relations(
self, self,
join_params: JoinParameters, join_params: JoinParameters,
rel_part: str = "", rel_part: str = "",
@ -163,7 +165,7 @@ class QuerySet:
parent_virtual: bool = False, parent_virtual: bool = False,
) -> None: ) -> None:
for field_name, field in join_params.prev_model.__model_fields__.items(): for field_name, field in join_params.prev_model.__model_fields__.items():
if self.field_is_a_foreign_key_and_no_circular_reference( if self._field_is_a_foreign_key_and_no_circular_reference(
field, field_name, rel_part field, field_name, rel_part
): ):
rel_part = field_name if not rel_part else rel_part + "__" + field_name rel_part = field_name if not rel_part else rel_part + "__" + field_name
@ -171,7 +173,7 @@ class QuerySet:
if rel_part not in self._select_related: if rel_part not in self._select_related:
self.auto_related.append("__".join(rel_part.split("__")[:-1])) self.auto_related.append("__".join(rel_part.split("__")[:-1]))
rel_part = "" rel_part = ""
elif self.field_qualifies_to_deeper_search( elif self._field_qualifies_to_deeper_search(
field, parent_virtual, nested, rel_part field, parent_virtual, nested, rel_part
): ):
join_params = JoinParameters( join_params = JoinParameters(
@ -180,7 +182,7 @@ class QuerySet:
join_params.from_table, join_params.from_table,
join_params.prev_model, join_params.prev_model,
) )
self.extract_auto_required_relations( self._extract_auto_required_relations(
join_params=join_params, join_params=join_params,
rel_part=rel_part, rel_part=rel_part,
nested=True, nested=True,
@ -189,6 +191,41 @@ class QuerySet:
else: else:
rel_part = "" rel_part = ""
def _include_auto_related_models(self) -> None:
if self.auto_related:
new_joins = []
for join in self._select_related:
if not any([x.startswith(join) for x in self.auto_related]):
new_joins.append(join)
self._select_related = new_joins + self.auto_related
def _apply_expression_modifiers(
self, expr: sqlalchemy.sql.select
) -> sqlalchemy.sql.select:
if self.filter_clauses:
if len(self.filter_clauses) == 1:
clause = self.filter_clauses[0]
else:
clause = sqlalchemy.sql.and_(*self.filter_clauses)
expr = expr.where(clause)
if self.limit_count:
expr = expr.limit(self.limit_count)
if self.query_offset:
expr = expr.offset(self.query_offset)
for order in self.order_bys:
expr = expr.order_by(order)
return expr
def _reset_query_parameters(self) -> None:
self.select_from = None
self.columns = None
self.order_bys = None
self.auto_related = []
self.used_aliases = []
def build_select_expression(self) -> sqlalchemy.sql.select: def build_select_expression(self) -> sqlalchemy.sql.select:
self.columns = list(self.table.columns) self.columns = list(self.table.columns)
self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")]
@ -207,14 +244,9 @@ class QuerySet:
start_params = JoinParameters( start_params = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls self.model_cls, "", self.table.name, self.model_cls
) )
self.extract_auto_required_relations(start_params) self._extract_auto_required_relations(start_params)
if self.auto_related: self._include_auto_related_models()
new_joins = [] self._select_related.sort(key=lambda item: (-len(item), item))
for join in self._select_related:
if not any([x.startswith(join) for x in self.auto_related]):
new_joins.append(join)
self._select_related = new_joins + self.auto_related
self._select_related.sort(key=lambda item: (-len(item), item))
for item in self._select_related: for item in self._select_related:
join_parameters = JoinParameters( join_parameters = JoinParameters(
@ -222,34 +254,15 @@ class QuerySet:
) )
for part in item.split("__"): for part in item.split("__"):
join_parameters = self.build_join_parameters(part, join_parameters) join_parameters = self._build_join_parameters(part, join_parameters)
expr = sqlalchemy.sql.select(self.columns) expr = sqlalchemy.sql.select(self.columns)
expr = expr.select_from(self.select_from) expr = expr.select_from(self.select_from)
if self.filter_clauses: expr = self._apply_expression_modifiers(expr)
if len(self.filter_clauses) == 1:
clause = self.filter_clauses[0]
else:
clause = sqlalchemy.sql.and_(*self.filter_clauses)
expr = expr.where(clause)
if self.limit_count:
expr = expr.limit(self.limit_count)
if self.query_offset:
expr = expr.offset(self.query_offset)
for order in self.order_bys:
expr = expr.order_by(order)
# print(expr.compile(compile_kwargs={"literal_binds": True})) # print(expr.compile(compile_kwargs={"literal_binds": True}))
self._reset_query_parameters()
self.select_from = None
self.columns = None
self.order_bys = None
self.auto_related = []
self.used_aliases = []
return expr return expr
@ -298,7 +311,6 @@ class QuerySet:
model_cls = model_cls.__model_fields__[part].to model_cls = model_cls.__model_fields__[part].to
previous_table = current_table previous_table = current_table
# print(table_prefix)
table = model_cls.__table__ table = model_cls.__table__
column = model_cls.__table__.columns[field_name] column = model_cls.__table__.columns[field_name]

View File

@ -71,43 +71,47 @@ class RelationshipManager:
child, parent = parent, proxy(child) child, parent = parent, proxy(child)
else: else:
child = proxy(child) child = proxy(child)
parents_list = self._relations[
parent_name.lower().title() + "_" + child_name + "s" parent_relation_name = parent_name.lower().title() + "_" + child_name + "s"
].setdefault(parent_id, []) parents_list = self._relations[parent_relation_name].setdefault(parent_id, [])
self.append_related_model(parents_list, child) self.append_related_model(parents_list, child)
children_list = self._relations[
child_name.lower().title() + "_" + parent_name child_relation_name = child_name.lower().title() + "_" + parent_name
].setdefault(child_id, []) children_list = self._relations[child_relation_name].setdefault(child_id, [])
self.append_related_model(children_list, parent) self.append_related_model(children_list, parent)
def append_related_model( @staticmethod
self, relations_list: List["Model"], model: "Model" def append_related_model(relations_list: List["Model"], model: "Model") -> None:
) -> None: for relation_child in relations_list:
for x in relations_list:
try: try:
if x.__same__(model): if relation_child.__same__(model):
return return
except ReferenceError: except ReferenceError:
continue continue
relations_list.append(model) relations_list.append(model)
def contains(self, relations_key: str, object: "Model") -> bool: def contains(self, relations_key: str, instance: "Model") -> bool:
if relations_key in self._relations: if relations_key in self._relations:
return object._orm_id in self._relations[relations_key] return instance._orm_id in self._relations[relations_key]
return False return False
def get(self, relations_key: str, object: "Model") -> Union["Model", List["Model"]]: def get(
self, relations_key: str, instance: "Model"
) -> Union["Model", List["Model"]]:
if relations_key in self._relations: if relations_key in self._relations:
if object._orm_id in self._relations[relations_key]: if instance._orm_id in self._relations[relations_key]:
if self._relations[relations_key]["type"] == "primary": if self._relations[relations_key]["type"] == "primary":
return self._relations[relations_key][object._orm_id][0] return self._relations[relations_key][instance._orm_id][0]
return self._relations[relations_key][object._orm_id] return self._relations[relations_key][instance._orm_id]
def resolve_relation_join(self, from_table: str, to_table: str) -> str: def resolve_relation_join(self, from_table: str, to_table: str) -> str:
for k, v in self._relations.items(): for relation_name, relation in self._relations.items():
if v["source_table"] == from_table and v["target_table"] == to_table: if (
return self._relations[k]["table_alias"] relation["source_table"] == from_table
and relation["target_table"] == to_table
):
return self._relations[relation_name]["table_alias"]
return "" return ""
def __str__(self) -> str: # pragma no cover def __str__(self) -> str: # pragma no cover

View File

@ -14,3 +14,8 @@ flake8-bugbear
flake8-import-order flake8-import-order
flake8-bandit flake8-bandit
flake8-annotations flake8-annotations
flake8-builtins
flake8-variables-names
flake8-cognitive-complexity
flake8-functions
flake8-expression-complexity

View File

@ -67,16 +67,24 @@ def create_test_database():
metadata.drop_all(engine) metadata.drop_all(engine)
@pytest.mark.asyncio @pytest.fixture()
async def test_model_multiple_instances_of_same_table_in_schema(): async def init_relation():
async with database: department = await Department.objects.create(id=1, name='Math Department')
department = await Department.objects.create(id=1, name='Math Department') class1 = await SchoolClass.objects.create(name="Math", department=department)
class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign")
category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic")
category2 = await Category.objects.create(name="Domestic") await Student.objects.create(name="Jane", category=category, schoolclass=class1)
await Student.objects.create(name="Jane", category=category, schoolclass=class1) await Student.objects.create(name="Jack", category=category2, schoolclass=class1)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
yield
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
@pytest.mark.asyncio
async def test_model_multiple_instances_of_same_table_in_schema(init_relation):
async with database:
classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all()
assert classes[0].name == 'Math' assert classes[0].name == 'Math'
assert classes[0].students[0].name == 'Jane' assert classes[0].students[0].name == 'Jane'
@ -92,18 +100,9 @@ async def test_model_multiple_instances_of_same_table_in_schema():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_right_tables_join(): async def test_right_tables_join(init_relation):
async with database: async with database:
department = await Department.objects.create(id=1, name='Math Department')
class1 = await SchoolClass.objects.create(name="Math", department=department)
category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic")
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all()
assert classes[0].name == 'Math'
assert classes[0].students[0].name == 'Jane'
assert classes[0].teachers[0].category.name == 'Domestic' assert classes[0].teachers[0].category.name == 'Domestic'
assert classes[0].students[0].category.name is None assert classes[0].students[0].category.name is None
@ -112,17 +111,9 @@ async def test_right_tables_join():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_reverse_related_objects(): async def test_multiple_reverse_related_objects(init_relation):
async with database: async with database:
department = await Department.objects.create(id=1, name='Math Department')
class1 = await SchoolClass.objects.create(name="Math", department=department)
category = await Category.objects.create(name="Foreign")
category2 = await Category.objects.create(name="Domestic")
await Student.objects.create(name="Jane", category=category, schoolclass=class1)
await Student.objects.create(name="Jack", category=category, schoolclass=class1)
await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1)
classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all()
assert classes[0].name == 'Math' assert classes[0].name == 'Math'
assert classes[0].students[0].name == 'Jane' assert classes[0].students[1].name == 'Jack'
assert classes[0].teachers[0].category.name == 'Domestic' assert classes[0].teachers[0].category.name == 'Domestic'