diff --git a/.coverage b/.coverage index 6b4bedb..0de5a48 100644 Binary files a/.coverage and b/.coverage differ diff --git a/README.md b/README.md index 5d99739..69758d9 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,9 @@ The `async-orm` package is an async ORM for Python, with support for Postgres, MySQL, and SQLite. ORM is built with: -* [`SQLAlchemy core`][sqlalchemy-core] for query building. -* [`databases`][databases] for cross-database async support. -* [`pydantic`][pydantic] for data validation. + * [`SQLAlchemy core`][sqlalchemy-core] for query building. + * [`databases`][databases] for cross-database async support. + * [`pydantic`][pydantic] for data validation. Because ORM is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide database migrations. @@ -26,7 +26,7 @@ The goal was to create a simple orm that can be used directly with [`fastapi`][f Initial work was inspired by [`encode/orm`][encode/orm]. The encode package was too simple (i.e. no ability to join two times to the same table) and used typesystem for data checks. -**aysn-orm is still under development: We recommend pinning any dependencies with `async-orm~=0.1`** +**aysn-orm is still under development:** We recommend pinning any dependencies with `aorm~=0.0.1` **Note**: Use `ipython` to try this from the console, since it supports `await`. @@ -155,33 +155,33 @@ assert len(tracks) == 1 The following keyword arguments are supported on all field types. -* `primary_key` -* `nullable` -* `default` -* `server_default` -* `index` -* `unique` + * `primary_key` + * `nullable` + * `default` + * `server_default` + * `index` + * `unique` All fields are required unless one of the following is set: -* `nullable` - Creates a nullable column. Sets the default to `None`. -* `default` - Set a default value for the field. -* `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`). -* `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column. + * `nullable` - Creates a nullable column. Sets the default to `None`. + * `default` - Set a default value for the field. + * `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`). + * `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column. Autoincrement is set by default on int primary keys. Available Model Fields: -* `orm.String(length)` -* `orm.Text()` -* `orm.Boolean()` -* `orm.Integer()` -* `orm.Float()` -* `orm.Date()` -* `orm.Time()` -* `orm.DateTime()` -* `orm.JSON()` -* `orm.BigInteger()` -* `orm.Decimal(lenght, precision)` + * `orm.String(length)` + * `orm.Text()` + * `orm.Boolean()` + * `orm.Integer()` + * `orm.Float()` + * `orm.Date()` + * `orm.Time()` + * `orm.DateTime()` + * `orm.JSON()` + * `orm.BigInteger()` + * `orm.Decimal(lenght, precision)` [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ [databases]: https://github.com/encode/databases diff --git a/orm/queryset.py b/orm/queryset.py index 39a187c..d66d50a 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -68,13 +68,15 @@ class QuerySet: def table(self) -> sqlalchemy.Table: return self.model_cls.__table__ - def prefixed_columns(self, alias: str, table: sqlalchemy.Table) -> List[text]: + @staticmethod + def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: return [ text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}") for column in table.columns ] - def prefixed_table_name(self, alias: str, name: str) -> text: + @staticmethod + def prefixed_table_name(alias: str, name: str) -> text: return text(f"{name} {alias}_{name}") def on_clause( @@ -91,7 +93,7 @@ class QuerySet: f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}' ) - def build_join_parameters( + def _build_join_parameters( self, part: str, join_params: JoinParameters ) -> JoinParameters: model_cls = join_params.model_cls.__model_fields__[part].to @@ -137,12 +139,12 @@ class QuerySet: return JoinParameters(prev_model, previous_alias, from_table, model_cls) @staticmethod - def field_is_a_foreign_key_and_no_circular_reference( + def _field_is_a_foreign_key_and_no_circular_reference( field: BaseField, field_name: str, rel_part: str ) -> bool: return isinstance(field, ForeignKey) and field_name not in rel_part - def field_qualifies_to_deeper_search( + def _field_qualifies_to_deeper_search( self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str ) -> bool: prev_part_of_related = "__".join(rel_part.split("__")[:-1]) @@ -155,7 +157,7 @@ class QuerySet: or (partial_match and not already_checked) ) or not nested - def extract_auto_required_relations( + def _extract_auto_required_relations( self, join_params: JoinParameters, rel_part: str = "", @@ -163,7 +165,7 @@ class QuerySet: parent_virtual: bool = False, ) -> None: for field_name, field in join_params.prev_model.__model_fields__.items(): - if self.field_is_a_foreign_key_and_no_circular_reference( + if self._field_is_a_foreign_key_and_no_circular_reference( field, field_name, rel_part ): rel_part = field_name if not rel_part else rel_part + "__" + field_name @@ -171,7 +173,7 @@ class QuerySet: if rel_part not in self._select_related: self.auto_related.append("__".join(rel_part.split("__")[:-1])) rel_part = "" - elif self.field_qualifies_to_deeper_search( + elif self._field_qualifies_to_deeper_search( field, parent_virtual, nested, rel_part ): join_params = JoinParameters( @@ -180,7 +182,7 @@ class QuerySet: join_params.from_table, join_params.prev_model, ) - self.extract_auto_required_relations( + self._extract_auto_required_relations( join_params=join_params, rel_part=rel_part, nested=True, @@ -189,6 +191,41 @@ class QuerySet: else: rel_part = "" + def _include_auto_related_models(self) -> None: + if self.auto_related: + new_joins = [] + for join in self._select_related: + if not any([x.startswith(join) for x in self.auto_related]): + new_joins.append(join) + self._select_related = new_joins + self.auto_related + + def _apply_expression_modifiers( + self, expr: sqlalchemy.sql.select + ) -> sqlalchemy.sql.select: + if self.filter_clauses: + if len(self.filter_clauses) == 1: + clause = self.filter_clauses[0] + else: + clause = sqlalchemy.sql.and_(*self.filter_clauses) + expr = expr.where(clause) + + if self.limit_count: + expr = expr.limit(self.limit_count) + + if self.query_offset: + expr = expr.offset(self.query_offset) + + for order in self.order_bys: + expr = expr.order_by(order) + return expr + + def _reset_query_parameters(self) -> None: + self.select_from = None + self.columns = None + self.order_bys = None + self.auto_related = [] + self.used_aliases = [] + def build_select_expression(self) -> sqlalchemy.sql.select: self.columns = list(self.table.columns) self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] @@ -207,14 +244,9 @@ class QuerySet: start_params = JoinParameters( self.model_cls, "", self.table.name, self.model_cls ) - self.extract_auto_required_relations(start_params) - if self.auto_related: - new_joins = [] - for join in self._select_related: - if not any([x.startswith(join) for x in self.auto_related]): - new_joins.append(join) - self._select_related = new_joins + self.auto_related - self._select_related.sort(key=lambda item: (-len(item), item)) + self._extract_auto_required_relations(start_params) + self._include_auto_related_models() + self._select_related.sort(key=lambda item: (-len(item), item)) for item in self._select_related: join_parameters = JoinParameters( @@ -222,34 +254,15 @@ class QuerySet: ) for part in item.split("__"): - join_parameters = self.build_join_parameters(part, join_parameters) + join_parameters = self._build_join_parameters(part, join_parameters) expr = sqlalchemy.sql.select(self.columns) expr = expr.select_from(self.select_from) - if self.filter_clauses: - if len(self.filter_clauses) == 1: - clause = self.filter_clauses[0] - else: - clause = sqlalchemy.sql.and_(*self.filter_clauses) - expr = expr.where(clause) - - if self.limit_count: - expr = expr.limit(self.limit_count) - - if self.query_offset: - expr = expr.offset(self.query_offset) - - for order in self.order_bys: - expr = expr.order_by(order) + expr = self._apply_expression_modifiers(expr) # print(expr.compile(compile_kwargs={"literal_binds": True})) - - self.select_from = None - self.columns = None - self.order_bys = None - self.auto_related = [] - self.used_aliases = [] + self._reset_query_parameters() return expr @@ -298,7 +311,6 @@ class QuerySet: model_cls = model_cls.__model_fields__[part].to previous_table = current_table - # print(table_prefix) table = model_cls.__table__ column = model_cls.__table__.columns[field_name] diff --git a/orm/relations.py b/orm/relations.py index b5741e1..3232c8c 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -71,43 +71,47 @@ class RelationshipManager: child, parent = parent, proxy(child) else: child = proxy(child) - parents_list = self._relations[ - parent_name.lower().title() + "_" + child_name + "s" - ].setdefault(parent_id, []) + + parent_relation_name = parent_name.lower().title() + "_" + child_name + "s" + parents_list = self._relations[parent_relation_name].setdefault(parent_id, []) self.append_related_model(parents_list, child) - children_list = self._relations[ - child_name.lower().title() + "_" + parent_name - ].setdefault(child_id, []) + + child_relation_name = child_name.lower().title() + "_" + parent_name + children_list = self._relations[child_relation_name].setdefault(child_id, []) self.append_related_model(children_list, parent) - def append_related_model( - self, relations_list: List["Model"], model: "Model" - ) -> None: - for x in relations_list: + @staticmethod + def append_related_model(relations_list: List["Model"], model: "Model") -> None: + for relation_child in relations_list: try: - if x.__same__(model): + if relation_child.__same__(model): return except ReferenceError: continue relations_list.append(model) - def contains(self, relations_key: str, object: "Model") -> bool: + def contains(self, relations_key: str, instance: "Model") -> bool: if relations_key in self._relations: - return object._orm_id in self._relations[relations_key] + return instance._orm_id in self._relations[relations_key] return False - def get(self, relations_key: str, object: "Model") -> Union["Model", List["Model"]]: + def get( + self, relations_key: str, instance: "Model" + ) -> Union["Model", List["Model"]]: if relations_key in self._relations: - if object._orm_id in self._relations[relations_key]: + if instance._orm_id in self._relations[relations_key]: if self._relations[relations_key]["type"] == "primary": - return self._relations[relations_key][object._orm_id][0] - return self._relations[relations_key][object._orm_id] + return self._relations[relations_key][instance._orm_id][0] + return self._relations[relations_key][instance._orm_id] def resolve_relation_join(self, from_table: str, to_table: str) -> str: - for k, v in self._relations.items(): - if v["source_table"] == from_table and v["target_table"] == to_table: - return self._relations[k]["table_alias"] + for relation_name, relation in self._relations.items(): + if ( + relation["source_table"] == from_table + and relation["target_table"] == to_table + ): + return self._relations[relation_name]["table_alias"] return "" def __str__(self) -> str: # pragma no cover diff --git a/requirements.txt b/requirements.txt index fde4a73..807e704 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,9 @@ flake8-black flake8-bugbear flake8-import-order flake8-bandit -flake8-annotations \ No newline at end of file +flake8-annotations +flake8-builtins +flake8-variables-names +flake8-cognitive-complexity +flake8-functions +flake8-expression-complexity \ No newline at end of file diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index ca4e7b7..ea85409 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -67,16 +67,24 @@ def create_test_database(): metadata.drop_all(engine) -@pytest.mark.asyncio -async def test_model_multiple_instances_of_same_table_in_schema(): - async with database: - department = await Department.objects.create(id=1, name='Math Department') - class1 = await SchoolClass.objects.create(name="Math", department=department) - category = await Category.objects.create(name="Foreign") - category2 = await Category.objects.create(name="Domestic") - await Student.objects.create(name="Jane", category=category, schoolclass=class1) - await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) +@pytest.fixture() +async def init_relation(): + department = await Department.objects.create(id=1, name='Math Department') + class1 = await SchoolClass.objects.create(name="Math", department=department) + category = await Category.objects.create(name="Foreign") + category2 = await Category.objects.create(name="Domestic") + await Student.objects.create(name="Jane", category=category, schoolclass=class1) + await Student.objects.create(name="Jack", category=category2, schoolclass=class1) + await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) + yield + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + +@pytest.mark.asyncio +async def test_model_multiple_instances_of_same_table_in_schema(init_relation): + async with database: classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() assert classes[0].name == 'Math' assert classes[0].students[0].name == 'Jane' @@ -92,18 +100,9 @@ async def test_model_multiple_instances_of_same_table_in_schema(): @pytest.mark.asyncio -async def test_right_tables_join(): +async def test_right_tables_join(init_relation): async with database: - department = await Department.objects.create(id=1, name='Math Department') - class1 = await SchoolClass.objects.create(name="Math", department=department) - category = await Category.objects.create(name="Foreign") - category2 = await Category.objects.create(name="Domestic") - await Student.objects.create(name="Jane", category=category, schoolclass=class1) - await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) - classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() - assert classes[0].name == 'Math' - assert classes[0].students[0].name == 'Jane' assert classes[0].teachers[0].category.name == 'Domestic' assert classes[0].students[0].category.name is None @@ -112,17 +111,9 @@ async def test_right_tables_join(): @pytest.mark.asyncio -async def test_multiple_reverse_related_objects(): +async def test_multiple_reverse_related_objects(init_relation): async with database: - department = await Department.objects.create(id=1, name='Math Department') - class1 = await SchoolClass.objects.create(name="Math", department=department) - category = await Category.objects.create(name="Foreign") - category2 = await Category.objects.create(name="Domestic") - await Student.objects.create(name="Jane", category=category, schoolclass=class1) - await Student.objects.create(name="Jack", category=category, schoolclass=class1) - await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) - classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() assert classes[0].name == 'Math' - assert classes[0].students[0].name == 'Jane' + assert classes[0].students[1].name == 'Jack' assert classes[0].teachers[0].category.name == 'Domestic'