some further cleanup and optim

This commit is contained in:
collerek
2020-11-12 08:56:16 +01:00
parent e743286008
commit d8391851fa
7 changed files with 56 additions and 48 deletions

View File

@ -5,6 +5,7 @@ import sqlalchemy
from pydantic import Field, typing from pydantic import Field, typing
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
import ormar # noqa I101
from ormar import ModelDefinitionError # noqa I101 from ormar import ModelDefinitionError # noqa I101
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -34,6 +35,10 @@ class BaseField(FieldInfo):
default: Any default: Any
server_default: Any server_default: Any
@classmethod
def is_valid_uni_relation(cls) -> bool:
return not issubclass(cls, ormar.fields.ManyToManyField) and not cls.virtual
@classmethod @classmethod
def get_alias(cls) -> str: def get_alias(cls) -> str:
return cls.alias if cls.alias else cls.name return cls.alias if cls.alias else cls.name

View File

@ -17,10 +17,10 @@ from ormar.exceptions import RelationshipInstanceError
try: try:
import orjson as json import orjson as json
except ImportError: # pragma: nocover except ImportError: # pragma: nocover
import json # type: ignore import json # type: ignore
import ormar import ormar # noqa: I100
from ormar.fields import BaseField, ManyToManyField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta from ormar.models.metaclass import ModelMeta
@ -111,29 +111,22 @@ class ModelTableProxy:
@classmethod @classmethod
def _extract_db_related_names(cls) -> Set: def _extract_db_related_names(cls) -> Set:
related_names = set() related_names = cls.extract_related_names()
for name, field in cls.Meta.model_fields.items(): related_names = {
if ( name
inspect.isclass(field) for name in related_names
and issubclass(field, ForeignKeyField) if cls.Meta.model_fields[name].is_valid_uni_relation()
and not issubclass(field, ManyToManyField) }
and not field.virtual
):
related_names.add(name)
return related_names return related_names
@classmethod @classmethod
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set: def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
if nested: if nested:
return cls.extract_related_names() return cls.extract_related_names()
related_names = set() related_names = cls.extract_related_names()
for name, field in cls.Meta.model_fields.items(): related_names = {
if ( name for name in related_names if cls.Meta.model_fields[name].nullable
inspect.isclass(field) }
and issubclass(field, ForeignKeyField)
and field.nullable
):
related_names.add(name)
return related_names return related_names
def _extract_model_db_fields(self) -> Dict: def _extract_model_db_fields(self) -> Dict:
@ -151,8 +144,8 @@ class ModelTableProxy:
@staticmethod @staticmethod
def resolve_relation_name( # noqa CCR001 def resolve_relation_name( # noqa CCR001
item: Union["NewBaseModel", Type["NewBaseModel"]], item: Union["NewBaseModel", Type["NewBaseModel"]],
related: Union["NewBaseModel", Type["NewBaseModel"]] related: Union["NewBaseModel", Type["NewBaseModel"]],
) -> str: ) -> str:
for name, field in item.Meta.model_fields.items(): for name, field in item.Meta.model_fields.items():
if issubclass(field, ForeignKeyField): if issubclass(field, ForeignKeyField):
@ -168,7 +161,7 @@ class ModelTableProxy:
@staticmethod @staticmethod
def resolve_relation_field( def resolve_relation_field(
item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]] item: Union["Model", Type["Model"]], related: Union["Model", Type["Model"]]
) -> Type[BaseField]: ) -> Type[BaseField]:
name = ModelTableProxy.resolve_relation_name(item, related) name = ModelTableProxy.resolve_relation_name(item, related)
to_field = item.Meta.model_fields.get(name) to_field = item.Meta.model_fields.get(name)
@ -215,12 +208,12 @@ class ModelTableProxy:
for field in one.Meta.model_fields.keys(): for field in one.Meta.model_fields.keys():
current_field = getattr(one, field) current_field = getattr(one, field)
if isinstance(current_field, list) and not isinstance( if isinstance(current_field, list) and not isinstance(
current_field, ormar.Model current_field, ormar.Model
): ):
setattr(other, field, current_field + getattr(other, field)) setattr(other, field, current_field + getattr(other, field))
elif ( elif (
isinstance(current_field, ormar.Model) isinstance(current_field, ormar.Model)
and current_field.pk == getattr(other, field).pk and current_field.pk == getattr(other, field).pk
): ):
setattr( setattr(
other, other,
@ -231,7 +224,7 @@ class ModelTableProxy:
@staticmethod @staticmethod
def _populate_pk_column( def _populate_pk_column(
model: Type["Model"], columns: List[str], use_alias: bool = False, model: Type["Model"], columns: List[str], use_alias: bool = False,
) -> List[str]: ) -> List[str]:
pk_alias = ( pk_alias = (
model.get_column_alias(model.Meta.pkname) model.get_column_alias(model.Meta.pkname)
@ -244,10 +237,10 @@ class ModelTableProxy:
@staticmethod @staticmethod
def own_table_columns( def own_table_columns(
model: Type["Model"], model: Type["Model"],
fields: Optional[Union[Set, Dict]], fields: Optional[Union[Set, Dict]],
exclude_fields: Optional[Union[Set, Dict]], exclude_fields: Optional[Union[Set, Dict]],
use_alias: bool = False, use_alias: bool = False,
) -> List[str]: ) -> List[str]:
columns = [ columns = [
model.get_column_name_from_alias(col.name) if not use_alias else col.name model.get_column_name_from_alias(col.name) if not use_alias else col.name

View File

@ -9,7 +9,8 @@ from typing import (
Mapping, Mapping,
Optional, Optional,
Sequence, Sequence,
Set, TYPE_CHECKING, Set,
TYPE_CHECKING,
Type, Type,
TypeVar, TypeVar,
Union, Union,
@ -43,7 +44,14 @@ if TYPE_CHECKING: # pragma no cover
class NewBaseModel( class NewBaseModel(
pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass pydantic.BaseModel, ModelTableProxy, Excludable, metaclass=ModelMetaclass
): ):
__slots__ = ("_orm_id", "_orm_saved", "_orm", "_related_names", "_related_names_hash", "_props") __slots__ = (
"_orm_id",
"_orm_saved",
"_orm",
"_related_names",
"_related_names_hash",
"_props",
)
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, Type[BaseField]] __model_fields__: Dict[str, Type[BaseField]]
@ -130,7 +138,14 @@ class NewBaseModel(
super().__setattr__(name, value) super().__setattr__(name, value)
def __getattribute__(self, item: str) -> Any: def __getattribute__(self, item: str) -> Any:
if item in ("_orm_id", "_orm_saved", "_orm", "__fields__", "_related_names", "_props"): if item in (
"_orm_id",
"_orm_saved",
"_orm",
"__fields__",
"_related_names",
"_props",
):
return object.__getattribute__(self, item) return object.__getattribute__(self, item)
if item == "pk": if item == "pk":
return self.__dict__.get(self.Meta.pkname, None) return self.__dict__.get(self.Meta.pkname, None)

View File

@ -1,4 +1,3 @@
import copy
from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Type, Union
import databases import databases
@ -175,7 +174,7 @@ class QuerySet:
if isinstance(columns, str): if isinstance(columns, str):
columns = [columns] columns = [columns]
current_excluded = copy.deepcopy(self._exclude_columns) current_excluded = self._exclude_columns
if not isinstance(columns, dict): if not isinstance(columns, dict):
current_excluded = update_dict_from_list(current_excluded, columns) current_excluded = update_dict_from_list(current_excluded, columns)
else: else:
@ -197,7 +196,7 @@ class QuerySet:
if isinstance(columns, str): if isinstance(columns, str):
columns = [columns] columns = [columns]
current_included = copy.deepcopy(self._exclude_columns) current_included = self._exclude_columns
if not isinstance(columns, dict): if not isinstance(columns, dict):
current_included = update_dict_from_list(current_included, columns) current_included = update_dict_from_list(current_included, columns)
else: else:

View File

@ -1,9 +1,7 @@
from ormar.relations.alias_manager import AliasManager from ormar.relations.alias_manager import AliasManager
from ormar.relations.relation import Relation, RelationType from ormar.relations.relation import Relation, RelationType
from ormar.relations.relation_manager import RelationsManager from ormar.relations.relation_manager import RelationsManager
from ormar.relations.utils import ( from ormar.relations.utils import get_relations_sides_and_names
get_relations_sides_and_names,
)
__all__ = [ __all__ = [
"AliasManager", "AliasManager",

View File

@ -5,9 +5,7 @@ from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.fields.many_to_many import ManyToManyField from ormar.fields.many_to_many import ManyToManyField
from ormar.relations.relation import Relation, RelationType from ormar.relations.relation import Relation, RelationType
from ormar.relations.utils import ( from ormar.relations.utils import get_relations_sides_and_names
get_relations_sides_and_names,
)
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model

View File

@ -9,11 +9,11 @@ if TYPE_CHECKING: # pragma no cover
def get_relations_sides_and_names( def get_relations_sides_and_names(
to_field: Type[BaseField], to_field: Type[BaseField],
parent: "Model", parent: "Model",
child: "Model", child: "Model",
child_name: str, child_name: str,
virtual: bool, virtual: bool,
) -> Tuple["Model", "Model", str, str]: ) -> Tuple["Model", "Model", str, str]:
to_name = to_field.name to_name = to_field.name
if issubclass(to_field, ManyToManyField): if issubclass(to_field, ManyToManyField):