refactor and cleanup

This commit is contained in:
collerek
2020-08-23 17:50:40 +02:00
parent f73a97e560
commit 0f72bf36eb
13 changed files with 248 additions and 216 deletions

BIN
.coverage

Binary file not shown.

View File

@ -31,12 +31,11 @@ Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to p
database migrations. database migrations.
The goal was to create a simple ORM that can be used directly with [`fastapi`][fastapi] that bases it's data validation on pydantic. The goal was to create a simple ORM that can be used directly with [`fastapi`][fastapi] that bases it's data validation on pydantic.
Initial work was inspired by [`encode/orm`][encode/orm]. Initial work was inspired by [`encode/orm`][encode/orm], later I found `ormantic` and used it as a further inspiration.
The encode package was too simple (i.e. no ability to join two times to the same table) and used typesystem for data checks. The encode package was too simple (i.e. no ability to join two times to the same table) and used typesystem for data checks.
To avoid too high coupling with pydantic and sqlalchemy ormar uses them by **composition** rather than by **inheritance**.
**ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.1.1` **ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.2.0`
**Note**: Use `ipython` to try this from the console, since it supports `await`. **Note**: Use `ipython` to try this from the console, since it supports `await`.
@ -50,14 +49,15 @@ metadata = sqlalchemy.MetaData()
class Note(ormar.Model): class Note(ormar.Model):
__tablename__ = "notes" class Meta:
__database__ = database tablename = "notes"
__metadata__ = metadata database = database
metadata = metadata
# primary keys of type int by dafault are set to autoincrement # primary keys of type int by dafault are set to autoincrement
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
text = ormar.String(length=100) text: ormar.String(length=100)
completed = ormar.Boolean(default=False) completed: ormar.Boolean(default=False)
# Create the database # Create the database
engine = sqlalchemy.create_engine(str(database.url)) engine = sqlalchemy.create_engine(str(database.url))
@ -103,23 +103,25 @@ metadata = sqlalchemy.MetaData()
class Album(ormar.Model): class Album(ormar.Model):
__tablename__ = "album" class Meta:
__metadata__ = metadata tablename = "album"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
name = ormar.String(length=100) name: ormar.String(length=100)
class Track(ormar.Model): class Track(ormar.Model):
__tablename__ = "track" class Meta:
__metadata__ = metadata tablename = "track"
__database__ = database metadata = metadata
database = database
id = ormar.Integer(primary_key=True) id: ormar.Integer(primary_key=True)
album = ormar.ForeignKey(Album) album: ormar.ForeignKey(Album)
title = ormar.String(length=100) title: ormar.String(length=100)
position = ormar.Integer() position: ormar.Integer()
# Create some records to work with. # Create some records to work with.

View File

@ -15,7 +15,7 @@ from ormar.fields import (
) )
from ormar.models import Model from ormar.models import Model
__version__ = "0.1.3" __version__ = "0.2.0"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, TYPE_CHECKING from typing import Any, List, Optional, TYPE_CHECKING, Union
import sqlalchemy import sqlalchemy
from pydantic import Field from pydantic import Field
@ -7,6 +7,7 @@ from ormar import ModelDefinitionError # noqa I101
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models import Model from ormar.models import Model
from ormar.models import FakePydantic
class BaseField: class BaseField:
@ -62,5 +63,7 @@ class BaseField:
) )
@classmethod @classmethod
def expand_relationship(cls, value: Any, child: "Model") -> Any: def expand_relationship(
cls, value: Any, child: Union["Model", "FakePydantic"]
) -> Any:
return value return value

View File

@ -1,4 +1,3 @@
import inspect
import json import json
import uuid import uuid
from typing import ( from typing import (
@ -8,7 +7,6 @@ from typing import (
List, List,
Mapping, Mapping,
Optional, Optional,
Set,
TYPE_CHECKING, TYPE_CHECKING,
Type, Type,
TypeVar, TypeVar,
@ -22,8 +20,8 @@ from pydantic import BaseModel
import ormar # noqa I100 import ormar # noqa I100
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta, ModelMetaclass from ormar.models.metaclass import ModelMeta, ModelMetaclass
from ormar.models.modelproxy import ModelTableProxy
from ormar.relations import RelationshipManager from ormar.relations import RelationshipManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -35,12 +33,8 @@ if TYPE_CHECKING: # pragma no cover
MappingIntStrAny = Mapping[IntStr, Any] MappingIntStrAny = Mapping[IntStr, Any]
class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): class FakePydantic(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass):
# FakePydantic inherits from list in order to be treated as
# request.Body parameter in fastapi routes,
# inheriting from pydantic.BaseModel causes metaclass conflicts
__slots__ = ("_orm_id", "_orm_saved") __slots__ = ("_orm_id", "_orm_saved")
__abstract__ = True
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, TypeVar[BaseField]] __model_fields__: Dict[str, TypeVar[BaseField]]
@ -77,9 +71,6 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
object.__setattr__(self, "__dict__", values) object.__setattr__(self, "__dict__", values)
object.__setattr__(self, "__fields_set__", fields_set) object.__setattr__(self, "__fields_set__", fields_set)
# super().__init__(**kwargs)
# self.values = self.__pydantic_model__(**kwargs)
def __del__(self) -> None: def __del__(self) -> None:
self.Meta._orm_relationship_manager.deregister(self) self.Meta._orm_relationship_manager.deregister(self)
@ -175,9 +166,7 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
if self.Meta.model_fields[field].virtual and nested: if self.Meta.model_fields[field].virtual and nested:
continue continue
if isinstance(nested_model, list) and not isinstance( if isinstance(nested_model, list):
nested_model, ormar.Model
):
dict_instance[field] = [x.dict(nested=True) for x in nested_model] dict_instance[field] = [x.dict(nested=True) for x in nested_model]
elif nested_model is not None: elif nested_model is not None:
dict_instance[field] = nested_model.dict(nested=True) dict_instance[field] = nested_model.dict(nested=True)
@ -206,70 +195,3 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def _is_conversion_to_json_needed(self, column_name: str) -> bool: def _is_conversion_to_json_needed(self, column_name: str) -> bool:
return self.Meta.model_fields.get(column_name).__type__ == pydantic.Json return self.Meta.model_fields.get(column_name).__type__ == pydantic.Json
def _extract_own_model_fields(self) -> Dict:
related_names = self._extract_related_names()
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
return self_fields
@classmethod
def _extract_related_names(cls) -> Set:
related_names = set()
for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field) and issubclass(field, ForeignKeyField):
related_names.add(name)
return related_names
@classmethod
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
if nested:
return cls._extract_related_names()
related_names = set()
for name, field in cls.Meta.model_fields.items():
if (
inspect.isclass(field)
and issubclass(field, ForeignKeyField)
and field.nullable
):
related_names.add(name)
return related_names
def _extract_model_db_fields(self) -> Dict:
self_fields = self._extract_own_model_fields()
self_fields = {
k: v for k, v in self_fields.items() if k in self.Meta.table.columns
}
for field in self._extract_related_names():
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
if getattr(self, field) is not None:
self_fields[field] = getattr(getattr(self, field), target_pk_name)
return self_fields
@classmethod
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
merged_rows = []
for index, model in enumerate(result_rows):
if index > 0 and model.pk == result_rows[index - 1].pk:
result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
else:
merged_rows.append(model)
return merged_rows
@classmethod
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
for field in one.Meta.model_fields.keys():
current_field = getattr(one, field)
if isinstance(current_field, list) and not isinstance(
current_field, ormar.Model
):
setattr(other, field, current_field + getattr(other, field))
elif (
isinstance(current_field, ormar.Model)
and current_field.pk == getattr(other, field).pk
):
setattr(
other,
field,
cls.merge_two_instances(current_field, getattr(other, field)),
)
return other

View File

@ -161,5 +161,4 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
new_model.Meta._orm_relationship_manager = relationship_manager new_model.Meta._orm_relationship_manager = relationship_manager
new_model.objects = QuerySet(new_model) new_model.objects = QuerySet(new_model)
# breakpoint()
return new_model return new_model

View File

@ -9,8 +9,6 @@ from ormar.models import FakePydantic # noqa I100
class Model(FakePydantic): class Model(FakePydantic):
__abstract__ = False __abstract__ = False
# objects = ormar.queryset.QuerySet()
@classmethod @classmethod
def from_row( def from_row(
cls, cls,

View File

@ -0,0 +1,98 @@
import copy
import inspect
from typing import List, Set, TYPE_CHECKING
import ormar
from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMeta
if TYPE_CHECKING: # pragma no cover
from ormar import Model
class ModelTableProxy:
if TYPE_CHECKING: # pragma no cover
Meta: ModelMeta
def dict(): # noqa A003
raise NotImplementedError # pragma no cover
def _extract_own_model_fields(self) -> dict:
related_names = self._extract_related_names()
self_fields = {k: v for k, v in self.dict().items() if k not in related_names}
return self_fields
@classmethod
def substitute_models_with_pks(cls, model_dict: dict) -> dict:
model_dict = copy.deepcopy(model_dict)
for field in cls._extract_related_names():
if field in model_dict and model_dict.get(field) is not None:
target_field = cls.Meta.model_fields[field]
target_pkname = target_field.to.Meta.pkname
if isinstance(model_dict.get(field), ormar.Model):
model_dict[field] = getattr(model_dict.get(field), target_pkname)
else:
model_dict[field] = model_dict.get(field).get(target_pkname)
return model_dict
@classmethod
def _extract_related_names(cls) -> Set:
related_names = set()
for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field) and issubclass(field, ForeignKeyField):
related_names.add(name)
return related_names
@classmethod
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
if nested:
return cls._extract_related_names()
related_names = set()
for name, field in cls.Meta.model_fields.items():
if (
inspect.isclass(field)
and issubclass(field, ForeignKeyField)
and field.nullable
):
related_names.add(name)
return related_names
def _extract_model_db_fields(self) -> dict:
self_fields = self._extract_own_model_fields()
self_fields = {
k: v for k, v in self_fields.items() if k in self.Meta.table.columns
}
for field in self._extract_related_names():
target_pk_name = self.Meta.model_fields[field].to.Meta.pkname
if getattr(self, field) is not None:
self_fields[field] = getattr(getattr(self, field), target_pk_name)
return self_fields
@classmethod
def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]:
merged_rows = []
for index, model in enumerate(result_rows):
if index > 0 and model.pk == result_rows[index - 1].pk:
result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1])
else:
merged_rows.append(model)
return merged_rows
@classmethod
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
for field in one.Meta.model_fields.keys():
current_field = getattr(one, field)
if isinstance(current_field, list) and not isinstance(
current_field, ormar.Model
):
setattr(other, field, current_field + getattr(other, field))
elif (
isinstance(current_field, ormar.Model)
and current_field.pk == getattr(other, field).pk
):
setattr(
other,
field,
cls.merge_two_instances(current_field, getattr(other, field)),
)
return other

View File

@ -4,8 +4,9 @@ import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
import ormar # noqa I100 import ormar # noqa I100
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.queryset.relationship_crawler import RelationshipCrawler
from ormar.relations import RelationshipManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar import Model from ormar import Model
@ -36,34 +37,28 @@ class Query:
self.model_cls = model_cls self.model_cls = model_cls
self.table = self.model_cls.Meta.table self.table = self.model_cls.Meta.table
self.auto_related = []
self.used_aliases = [] self.used_aliases = []
self.already_checked = []
self.select_from = None self.select_from = None
self.columns = None self.columns = None
self.order_bys = None self.order_bys = None
@property
def relation_manager(self) -> RelationshipManager:
return self.model_cls.Meta._orm_relationship_manager
def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]:
self.columns = list(self.table.columns) self.columns = list(self.table.columns)
self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")] self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")]
self.select_from = self.table self.select_from = self.table
# for key in self.model_cls.Meta.model_fields:
# if (
# not self.model_cls.Meta.model_fields[key].nullable
# and isinstance(
# self.model_cls.Meta.model_fields[key], ForeignKeyField,
# )
# and key not in self._select_related
# ):
# self._select_related = [key] + self._select_related
start_params = JoinParameters( start_params = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls self.model_cls, "", self.table.name, self.model_cls
) )
self._extract_auto_required_relations(prev_model=start_params.prev_model)
self._include_auto_related_models() self._select_related = RelationshipCrawler().discover_relations(
self._select_related, prev_model=start_params.prev_model
)
self._select_related.sort(key=lambda item: (-len(item), item)) self._select_related.sort(key=lambda item: (-len(item), item))
for item in self._select_related: for item in self._select_related:
@ -84,38 +79,6 @@ class Query:
return expr, self._select_related return expr, self._select_related
@staticmethod
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
return [
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
for column in table.columns
]
@staticmethod
def prefixed_table_name(alias: str, name: str) -> text:
return text(f"{name} {alias}_{name}")
@staticmethod
def _field_is_a_foreign_key_and_no_circular_reference(
field: Type[BaseField], field_name: str, rel_part: str
) -> bool:
return issubclass(field, ForeignKeyField) and field_name not in rel_part
def _field_qualifies_to_deeper_search(
self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str
) -> bool:
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
partial_match = any(
[x.startswith(prev_part_of_related) for x in self._select_related]
)
already_checked = any(
[x.startswith(rel_part) for x in (self.auto_related + self.already_checked)]
)
return (
(field.virtual and parent_virtual)
or (partial_match and not already_checked)
) or not nested
def on_clause( def on_clause(
self, previous_alias: str, alias: str, from_clause: str, to_clause: str, self, previous_alias: str, alias: str, from_clause: str, to_clause: str,
) -> text: ) -> text:
@ -154,12 +117,14 @@ class Query:
from_clause=f"{join_params.from_table}.{from_key}", from_clause=f"{join_params.from_table}.{from_key}",
to_clause=f"{to_table}.{to_key}", to_clause=f"{to_table}.{to_key}",
) )
target_table = self.prefixed_table_name(alias, to_table) target_table = self.relation_manager.prefixed_table_name(alias, to_table)
self.select_from = sqlalchemy.sql.outerjoin( self.select_from = sqlalchemy.sql.outerjoin(
self.select_from, target_table, on_clause self.select_from, target_table, on_clause
) )
self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}")) self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}"))
self.columns.extend(self.prefixed_columns(alias, model_cls.Meta.table)) self.columns.extend(
self.relation_manager.prefixed_columns(alias, model_cls.Meta.table)
)
self.used_aliases.append(alias) self.used_aliases.append(alias)
previous_alias = alias previous_alias = alias
@ -167,49 +132,6 @@ class Query:
prev_model = model_cls prev_model = model_cls
return JoinParameters(prev_model, previous_alias, from_table, model_cls) return JoinParameters(prev_model, previous_alias, from_table, model_cls)
def _extract_auto_required_relations(
self,
prev_model: Type["Model"],
rel_part: str = "",
nested: bool = False,
parent_virtual: bool = False,
) -> None:
for field_name, field in prev_model.Meta.model_fields.items():
if self._field_is_a_foreign_key_and_no_circular_reference(
field, field_name, rel_part
):
rel_part = field_name if not rel_part else rel_part + "__" + field_name
if not field.nullable:
if rel_part not in self._select_related:
new_related = (
"__".join(rel_part.split("__")[:-1])
if len(rel_part.split("__")) > 1
else rel_part
)
self.auto_related.append(new_related)
rel_part = ""
elif self._field_qualifies_to_deeper_search(
field, parent_virtual, nested, rel_part
):
self._extract_auto_required_relations(
prev_model=field.to,
rel_part=rel_part,
nested=True,
parent_virtual=field.virtual,
)
else:
self.already_checked.append(rel_part)
rel_part = ""
def _include_auto_related_models(self) -> None:
if self.auto_related:
new_joins = []
for join in self._select_related:
if not any([x.startswith(join) for x in self.auto_related]):
new_joins.append(join)
self._select_related = new_joins + self.auto_related
def _apply_expression_modifiers( def _apply_expression_modifiers(
self, expr: sqlalchemy.sql.select self, expr: sqlalchemy.sql.select
) -> sqlalchemy.sql.select: ) -> sqlalchemy.sql.select:
@ -234,6 +156,4 @@ class Query:
self.select_from = None self.select_from = None
self.columns = None self.columns = None
self.order_bys = None self.order_bys = None
self.auto_related = []
self.used_aliases = [] self.used_aliases = []
self.already_checked = []

View File

@ -157,18 +157,7 @@ class QuerySet:
): ):
del new_kwargs[pkname] del new_kwargs[pkname]
# substitute related models with their pk new_kwargs = self.model_cls.substitute_models_with_pks(new_kwargs)
for field in self.model_cls._extract_related_names():
if field in new_kwargs and new_kwargs.get(field) is not None:
if isinstance(new_kwargs.get(field), ormar.Model):
new_kwargs[field] = getattr(
new_kwargs.get(field),
self.model_cls.Meta.model_fields[field].to.Meta.pkname,
)
else:
new_kwargs[field] = new_kwargs.get(field).get(
self.model_cls.Meta.model_fields[field].to.Meta.pkname
)
# Build the insert expression. # Build the insert expression.
expr = self.table.insert() expr = self.table.insert()

View File

@ -0,0 +1,87 @@
from typing import List, TYPE_CHECKING, Type
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
if TYPE_CHECKING: # pragma no cover
from ormar import Model
class RelationshipCrawler:
def __init__(self) -> None:
self._select_related = []
self.auto_related = []
self.already_checked = []
def discover_relations(
self, select_related: List, prev_model: Type["Model"]
) -> List[str]:
self._select_related = select_related
self._extract_auto_required_relations(prev_model=prev_model)
self._include_auto_related_models()
return self._select_related
@staticmethod
def _field_is_a_foreign_key_and_no_circular_reference(
field: Type[BaseField], field_name: str, rel_part: str
) -> bool:
return issubclass(field, ForeignKeyField) and field_name not in rel_part
def _field_qualifies_to_deeper_search(
self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str
) -> bool:
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
partial_match = any(
[x.startswith(prev_part_of_related) for x in self._select_related]
)
already_checked = any(
[x.startswith(rel_part) for x in (self.auto_related + self.already_checked)]
)
return (
(field.virtual and parent_virtual)
or (partial_match and not already_checked)
) or not nested
def _extract_auto_required_relations(
self,
prev_model: Type["Model"],
rel_part: str = "",
nested: bool = False,
parent_virtual: bool = False,
) -> None:
for field_name, field in prev_model.Meta.model_fields.items():
if self._field_is_a_foreign_key_and_no_circular_reference(
field, field_name, rel_part
):
rel_part = field_name if not rel_part else rel_part + "__" + field_name
if not field.nullable:
if rel_part not in self._select_related:
split_tables = rel_part.split("__")
new_related = (
"__".join(split_tables[:-1])
if len(split_tables) > 1
else rel_part
)
self.auto_related.append(new_related)
rel_part = ""
elif self._field_qualifies_to_deeper_search(
field, parent_virtual, nested, rel_part
):
self._extract_auto_required_relations(
prev_model=field.to,
rel_part=rel_part,
nested=True,
parent_virtual=field.virtual,
)
else:
self.already_checked.append(rel_part)
rel_part = ""
def _include_auto_related_models(self) -> None:
if self.auto_related:
new_joins = []
for join in self._select_related:
if not any([x.startswith(join) for x in self.auto_related]):
new_joins.append(join)
self._select_related = new_joins + self.auto_related

View File

@ -5,7 +5,10 @@ from random import choices
from typing import List, TYPE_CHECKING, Union from typing import List, TYPE_CHECKING, Union
from weakref import proxy from weakref import proxy
from ormar.fields.foreign_key import ForeignKeyField import sqlalchemy
from sqlalchemy import text
from ormar.fields.foreign_key import ForeignKeyField # noqa I100
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from ormar.models import FakePydantic, Model from ormar.models import FakePydantic, Model
@ -20,6 +23,17 @@ class RelationshipManager:
self._relations = dict() self._relations = dict()
self._aliases = dict() self._aliases = dict()
@staticmethod
def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]:
return [
text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}")
for column in table.columns
]
@staticmethod
def prefixed_table_name(alias: str, name: str) -> text:
return text(f"{name} {alias}_{name}")
def add_relation_type( def add_relation_type(
self, self,
relations_key: str, relations_key: str,

View File

@ -51,7 +51,7 @@ setup(
packages=get_packages(PACKAGE), packages=get_packages(PACKAGE),
package_data={PACKAGE: ["py.typed"]}, package_data={PACKAGE: ["py.typed"]},
data_files=[("", ["LICENSE.md"])], data_files=[("", ["LICENSE.md"])],
install_requires=["databases", "pydantic", "sqlalchemy"], install_requires=["databases", "pydantic>=1.5", "sqlalchemy"],
classifiers=[ classifiers=[
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
"Environment :: Web Environment", "Environment :: Web Environment",