diff --git a/.coverage b/.coverage index 80aea2e..52c345a 100644 Binary files a/.coverage and b/.coverage differ diff --git a/.flake8 b/.flake8 index af173f2..b804c49 100644 --- a/.flake8 +++ b/.flake8 @@ -2,4 +2,5 @@ ignore = ANN101, ANN102, W503, S101 max-complexity = 8 max-line-length = 88 +import-order-style = pycharm exclude = p38venv,.pytest_cache diff --git a/orm/__init__.py b/orm/__init__.py index 773ab78..bcc1d9c 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -33,4 +33,5 @@ __all__ = [ "ModelNotSet", "MultipleMatches", "NoMatch", + "ForeignKey", ] diff --git a/orm/fields/__init__.py b/orm/fields/__init__.py index 75a1b7d..6355c38 100644 --- a/orm/fields/__init__.py +++ b/orm/fields/__init__.py @@ -1,18 +1,18 @@ +from orm.fields.base import BaseField +from orm.fields.foreign_key import ForeignKey from orm.fields.model_fields import ( BigInteger, Boolean, Date, DateTime, Decimal, - String, - Integer, - Text, Float, - Time, + Integer, JSON, + String, + Text, + Time, ) -from orm.fields.foreign_key import ForeignKey -from orm.fields.base import BaseField __all__ = [ "Decimal", diff --git a/orm/fields/base.py b/orm/fields/base.py index 32c0f13..9d321f2 100644 --- a/orm/fields/base.py +++ b/orm/fields/base.py @@ -1,8 +1,11 @@ -from typing import Type, Any, Dict, Optional, List +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type import sqlalchemy -from orm import ModelDefinitionError +from orm import ModelDefinitionError # noqa I101 + +if TYPE_CHECKING: # pragma no cover + from orm.models import Model class RequiredParams: diff --git a/orm/fields/foreign_key.py b/orm/fields/foreign_key.py index a5e65dd..506a9ac 100644 --- a/orm/fields/foreign_key.py +++ b/orm/fields/foreign_key.py @@ -1,9 +1,9 @@ -from typing import Type, List, Any, Union, TYPE_CHECKING, Optional +from typing import Any, List, Optional, TYPE_CHECKING, Type, Union import sqlalchemy from pydantic import BaseModel -import orm +import orm # noqa I101 from orm.exceptions import RelationshipInstanceError from orm.fields.base import BaseField @@ -25,12 +25,12 @@ def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": class ForeignKey(BaseField): def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual @@ -50,7 +50,7 @@ class ForeignKey(BaseField): return to_column.get_column_type() def expand_relationship( - self, value: Any, child: "Model" + self, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: if value is None: @@ -80,9 +80,7 @@ class ForeignKey(BaseField): model = create_dummy_instance(fk=self.to, pk=value) model._orm_relationship_manager.add_relation( - model, - child, - virtual=self.virtual, + model, child, virtual=self.virtual, ) return model diff --git a/orm/fields/model_fields.py b/orm/fields/model_fields.py index 3ddae06..813f190 100644 --- a/orm/fields/model_fields.py +++ b/orm/fields/model_fields.py @@ -4,7 +4,7 @@ import decimal import sqlalchemy from pydantic import Json -from orm.fields.base import BaseField, RequiredParams +from orm.fields.base import BaseField, RequiredParams # noqa I101 @RequiredParams("length") diff --git a/orm/models.py b/orm/models.py index 9fde6e7..22750eb 100644 --- a/orm/models.py +++ b/orm/models.py @@ -6,18 +6,16 @@ from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar from typing import Callable, Dict, Set import databases - -import orm.queryset as qry -from orm.exceptions import ModelDefinitionError -from orm import ForeignKey -from orm.fields.base import BaseField -from orm.relations import RelationshipManager - import pydantic +import sqlalchemy from pydantic import BaseConfig, BaseModel, create_model from pydantic.fields import ModelField -import sqlalchemy +import orm.queryset as qry # noqa I100 +from orm import ForeignKey +from orm.exceptions import ModelDefinitionError +from orm.fields.base import BaseField +from orm.relations import RelationshipManager relationship_manager = RelationshipManager() diff --git a/orm/queryset.py b/orm/queryset.py index ea6c264..5033f13 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -11,16 +11,15 @@ from typing import ( ) import databases +import sqlalchemy +from sqlalchemy import text -import orm +import orm # noqa I100 import orm.fields.foreign_key from orm import ForeignKey from orm.exceptions import MultipleMatches, NoMatch, QueryDefinitionError from orm.fields.base import BaseField -import sqlalchemy -from sqlalchemy import text - if TYPE_CHECKING: # pragma no cover from orm.models import Model diff --git a/orm/relations.py b/orm/relations.py index fa0dbcf..7cf5ecf 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -16,7 +16,7 @@ def get_table_alias() -> str: def get_relation_config( - relation_type: str, table_name: str, field: ForeignKey + relation_type: str, table_name: str, field: ForeignKey ) -> Dict[str, str]: alias = get_table_alias() config = { @@ -37,7 +37,7 @@ class RelationshipManager: self._relations = dict() def add_relation_type( - self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str + self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str ) -> None: if relations_key not in self._relations: self._relations[relations_key] = get_relation_config( @@ -56,10 +56,7 @@ class RelationshipManager: del self._relations[rel_type][model._orm_id] def add_relation( - self, - parent: "FakePydantic", - child: "FakePydantic", - virtual: bool = False, + self, parent: "FakePydantic", child: "FakePydantic", virtual: bool = False, ) -> None: parent_id = parent._orm_id child_id = child._orm_id @@ -97,7 +94,7 @@ class RelationshipManager: return False def get( - self, relations_key: str, instance: "FakePydantic" + self, relations_key: str, instance: "FakePydantic" ) -> Union["Model", List["Model"]]: if relations_key in self._relations: if instance._orm_id in self._relations[relations_key]: @@ -108,8 +105,8 @@ class RelationshipManager: def resolve_relation_join(self, from_table: str, to_table: str) -> str: for relation_name, relation in self._relations.items(): if ( - relation["source_table"] == from_table - and relation["target_table"] == to_table + relation["source_table"] == from_table + and relation["target_table"] == to_table ): return self._relations[relation_name]["table_alias"] return ""