some cleanup and tests

collerek
2020-08-23 12:54:58 +02:00
parent 08e251efdb
commit 53384879a9
17 changed files with 370 additions and 396 deletions

.coverage    BIN    Normal file (binary file not shown)


@ -1,6 +1,5 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import Any, List, Optional, TYPE_CHECKING
import pydantic
import sqlalchemy
from pydantic import Field
@ -10,13 +9,6 @@ if TYPE_CHECKING: # pragma no cover
from ormar.models import Model
def prepare_validator(type_):
def validate_model_field(value):
return isinstance(value, type_)
return validate_model_field
class BaseField:
__type__ = None
@ -34,13 +26,7 @@ class BaseField:
server_default: Any
@classmethod
def is_required(cls) -> bool:
return (
not cls.nullable and not cls.has_default() and not cls.is_auto_primary_key()
)
@classmethod
def default_value(cls):
def default_value(cls) -> Optional[Field]:
if cls.is_auto_primary_key():
return Field(default=None)
if cls.has_default():
@ -52,7 +38,7 @@ class BaseField:
return None
@classmethod
def has_default(cls):
def has_default(cls) -> bool:
return cls.default is not None or cls.server_default is not None
@classmethod
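
For reference, the is_required()/has_default() contract shown above reduces to a single boolean rule. The snippet below is a minimal standalone sketch using an invented FieldStub class, not ormar's actual BaseField:

class FieldStub:
    # illustrative flags only: a field is required when it is not nullable,
    # carries no default (client- or server-side) and is not an auto primary key
    nullable = False
    default = None
    server_default = None
    primary_key = False
    autoincrement = False

    @classmethod
    def has_default(cls) -> bool:
        return cls.default is not None or cls.server_default is not None

    @classmethod
    def is_auto_primary_key(cls) -> bool:
        return cls.primary_key and cls.autoincrement

    @classmethod
    def is_required(cls) -> bool:
        return (
            not cls.nullable and not cls.has_default() and not cls.is_auto_primary_key()
        )


assert FieldStub.is_required() is True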


@ -1,27 +0,0 @@
from typing import Any, TYPE_CHECKING, Type
from ormar import ModelDefinitionError
if TYPE_CHECKING: # pragma no cover
from ormar.fields import BaseField
class RequiredParams:
def __init__(self, *args: str) -> None:
self._required = list(args)
def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]:
old_init = model_field_class.__init__
model_field_class._old_init = old_init
def __init__(instance: "BaseField", **kwargs: Any) -> None:
super(instance.__class__, instance).__init__(**kwargs)
for arg in self._required:
if arg not in kwargs:
raise ModelDefinitionError(
f"{instance.__class__.__name__} field requires parameter: {arg}"
)
setattr(instance, arg, kwargs.pop(arg))
model_field_class.__init__ = __init__
return model_field_class
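
The file above (removed in this commit) implemented a class decorator that enforces required keyword arguments at instantiation time. A self-contained sketch of the same pattern, with invented names rather than ormar's, could look like this:

from typing import Any, Type


class DefinitionError(Exception):
    pass


def required_params(*required: str):
    # class decorator: wrap __init__ so missing keyword arguments raise immediately
    def wrapper(cls: Type) -> Type:
        original_init = cls.__init__

        def __init__(self, **kwargs: Any) -> None:
            for arg in required:
                if arg not in kwargs:
                    raise DefinitionError(f"{cls.__name__} requires parameter: {arg}")
            original_init(self, **kwargs)

        cls.__init__ = __init__
        return cls

    return wrapper


@required_params("max_length")
class StringField:
    def __init__(self, **kwargs: Any) -> None:
        self.max_length = kwargs["max_length"]


StringField(max_length=50)   # fine
# StringField()              # would raise DefinitionError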


@ -1,7 +1,6 @@
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union, Callable
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Type, Union
import sqlalchemy
from pydantic import BaseModel
import ormar # noqa I101
from ormar.exceptions import RelationshipInstanceError
@ -13,8 +12,7 @@ if TYPE_CHECKING: # pragma no cover
def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
init_dict = {
**{fk.Meta.pkname: pk or -1,
'__pk_only__': True},
**{fk.Meta.pkname: pk or -1, "__pk_only__": True},
**{
k: create_dummy_instance(v.to)
for k, v in fk.Meta.model_fields.items()
@ -24,10 +22,15 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
return fk(**init_dict)
def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = True,
related_name: str = None,
virtual: bool = False,
) -> Type[object]:
def ForeignKey(
to: "Model",
*,
name: str = None,
unique: bool = False,
nullable: bool = True,
related_name: str = None,
virtual: bool = False,
) -> Type[object]:
fk_string = to.Meta.tablename + "." + to.Meta.pkname
to_field = to.__fields__[to.Meta.pkname]
namespace = dict(
@ -43,7 +46,7 @@ def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = T
index=False,
pydantic_only=False,
default=None,
server_default=None
server_default=None,
)
return type("ForeignKey", (ForeignKeyField, BaseField), namespace)
@ -59,21 +62,21 @@ class ForeignKeyField(BaseField):
yield cls.validate
@classmethod
def validate(cls, v: Any) -> Any:
return v
def validate(cls, value: Any) -> Any:
return value
@property
def __type__(self) -> Type[BaseModel]:
return self.to.__pydantic_model__
# @property
# def __type__(self) -> Type[BaseModel]:
# return self.to.__pydantic_model__
@classmethod
def get_column_type(cls) -> sqlalchemy.Column:
to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname]
return to_column.column_type
# @classmethod
# def get_column_type(cls) -> sqlalchemy.Column:
# to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname]
# return to_column.column_type
@classmethod
def _extract_model_from_sequence(
cls, value: List, child: "Model"
cls, value: List, child: "Model"
) -> Union["Model", List["Model"]]:
return [cls.expand_relationship(val, child) for val in value]
@ -109,7 +112,7 @@ class ForeignKeyField(BaseField):
@classmethod
def expand_relationship(
cls, value: Any, child: "Model"
cls, value: Any, child: "Model"
) -> Optional[Union["Model", List["Model"]]]:
if value is None:
return None
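
The reworked ForeignKey factory above is used inside model definitions in the same annotation style as the tests later in this commit. A rough usage sketch, with table and field names invented for illustration:

import databases
import sqlalchemy
import ormar

database = databases.Database("sqlite:///test.db")
metadata = sqlalchemy.MetaData()


class Album(ormar.Model):
    class Meta:
        tablename = "albums"
        metadata = metadata
        database = database

    id: ormar.Integer(primary_key=True)
    name: ormar.String(max_length=100)


class Track(ormar.Model):
    class Meta:
        tablename = "tracks"
        metadata = metadata
        database = database

    id: ormar.Integer(primary_key=True)
    album: ormar.ForeignKey(Album, related_name="tracks", nullable=True)
    title: ormar.String(max_length=100)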


@ -1,41 +1,43 @@
import datetime
import decimal
import re
from typing import Type, Any, Optional
from typing import Any, Optional, Type
import pydantic
import sqlalchemy
from pydantic import Json
from ormar import ModelDefinitionError
from ormar import ModelDefinitionError # noqa I101
from ormar.fields.base import BaseField # noqa I101
def is_field_nullable(nullable: Optional[bool], default: Any, server_default: Any) -> bool:
def is_field_nullable(
nullable: Optional[bool], default: Any, server_default: Any
) -> bool:
if nullable is None:
return default is not None or server_default is not None
return nullable
def String(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
allow_blank: bool = False,
strip_whitespace: bool = False,
min_length: int = None,
max_length: int = None,
curtail_length: int = None,
regex: str = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
allow_blank: bool = False,
strip_whitespace: bool = False,
min_length: int = None,
max_length: int = None,
curtail_length: int = None,
regex: str = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[str]:
if max_length is None or max_length <= 0:
raise ModelDefinitionError(f'Parameter max_length is required for field String')
raise ModelDefinitionError("Parameter max_length is required for field String")
namespace = dict(
__type__=str,
@ -54,26 +56,26 @@ def String(
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
autoincrement=False,
)
return type("String", (pydantic.ConstrainedStr, BaseField), namespace)
def Integer(
*,
name: str = None,
primary_key: bool = False,
autoincrement: bool = None,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
*,
name: str = None,
primary_key: bool = False,
autoincrement: bool = None,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[int]:
namespace = dict(
__type__=int,
@ -89,23 +91,23 @@ def Integer(
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=autoincrement if autoincrement is not None else primary_key
autoincrement=autoincrement if autoincrement is not None else primary_key,
)
return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace)
def Text(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
allow_blank: bool = False,
strip_whitespace: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
allow_blank: bool = False,
strip_whitespace: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[str]:
namespace = dict(
__type__=str,
@ -120,25 +122,25 @@ def Text(
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
autoincrement=False,
)
return type("Text", (pydantic.ConstrainedStr, BaseField), namespace)
def Float(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[int]:
namespace = dict(
__type__=float,
@ -154,21 +156,21 @@ def Float(
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
autoincrement=False,
)
return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace)
def Boolean(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[bool]:
namespace = dict(
__type__=bool,
@ -181,21 +183,21 @@ def Boolean(
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
autoincrement=False,
)
return type("Boolean", (int, BaseField), namespace)
def DateTime(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[datetime.datetime]:
namespace = dict(
__type__=datetime.datetime,
@ -208,21 +210,21 @@ def DateTime(
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
autoincrement=False,
)
return type("DateTime", (datetime.datetime, BaseField), namespace)
def Date(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[datetime.date]:
namespace = dict(
__type__=datetime.date,
@ -235,21 +237,21 @@ def Date(
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
autoincrement=False,
)
return type("Date", (datetime.date, BaseField), namespace)
def Time(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[datetime.time]:
namespace = dict(
__type__=datetime.time,
@ -262,21 +264,21 @@ def Time(
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
autoincrement=False,
)
return type("Time", (datetime.time, BaseField), namespace)
def JSON(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[Json]:
namespace = dict(
__type__=pydantic.Json,
@ -289,26 +291,26 @@ def JSON(
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
autoincrement=False,
)
return type("JSON", (pydantic.Json, BaseField), namespace)
def BigInteger(
*,
name: str = None,
primary_key: bool = False,
autoincrement: bool = None,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
*,
name: str = None,
primary_key: bool = False,
autoincrement: bool = None,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: int = None,
maximum: int = None,
multiple_of: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[int]:
namespace = dict(
__type__=int,
@ -324,31 +326,33 @@ def BigInteger(
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=autoincrement if autoincrement is not None else primary_key
autoincrement=autoincrement if autoincrement is not None else primary_key,
)
return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace)
def Decimal(
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
precision: int = None,
scale: int = None,
max_digits: int = None,
decimal_places: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None
):
*,
name: str = None,
primary_key: bool = False,
nullable: bool = None,
index: bool = False,
unique: bool = False,
minimum: float = None,
maximum: float = None,
multiple_of: int = None,
precision: int = None,
scale: int = None,
max_digits: int = None,
decimal_places: int = None,
pydantic_only: bool = False,
default: Any = None,
server_default: Any = None,
) -> Type[decimal.Decimal]:
if precision is None or precision < 0 or scale is None or scale < 0:
raise ModelDefinitionError(f'Parameters scale and precision are required for field Decimal')
raise ModelDefinitionError(
"Parameters scale and precision are required for field Decimal"
)
namespace = dict(
__type__=decimal.Decimal,
@ -368,6 +372,6 @@ def Decimal(
pydantic_only=pydantic_only,
default=default,
server_default=server_default,
autoincrement=False
autoincrement=False,
)
return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace)
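
Two of the guards in this file can be exercised directly: String() without max_length and Decimal() without scale/precision both raise ModelDefinitionError. Illustrative calls only, with the error text taken from the code above:

import ormar
from ormar import ModelDefinitionError

name_field = ormar.String(max_length=255)            # ok
price_field = ormar.Decimal(precision=10, scale=2)   # ok

try:
    ormar.String()                                    # missing max_length
except ModelDefinitionError as err:
    print(err)    # Parameter max_length is required for field String

try:
    ormar.Decimal(precision=10)                       # missing scale
except ModelDefinitionError as err:
    print(err)    # Parameters scale and precision are required for field Decimal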


@ -2,16 +2,17 @@ import inspect
import json
import uuid
from typing import (
AbstractSet,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Set,
TYPE_CHECKING,
Type,
TypeVar,
Union, AbstractSet, Mapping,
Union,
)
import databases
@ -20,10 +21,9 @@ import sqlalchemy
from pydantic import BaseModel
import ormar # noqa I100
from ormar import ForeignKey
from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMetaclass, ModelMeta
from ormar.models.metaclass import ModelMeta, ModelMetaclass
from ormar.relations import RelationshipManager
if TYPE_CHECKING: # pragma no cover
@ -39,7 +39,8 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
# FakePydantic inherits from list in order to be treated as
# request.Body parameter in fastapi routes,
# inheriting from pydantic.BaseModel causes metaclass conflicts
__slots__ = ('_orm_id', '_orm_saved')
__slots__ = ("_orm_id", "_orm_saved")
__abstract__ = True
if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, TypeVar[BaseField]]
@ -63,18 +64,18 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
if "pk" in kwargs:
kwargs[self.Meta.pkname] = kwargs.pop("pk")
kwargs = {
k: self.Meta.model_fields[k].expand_relationship(v, self)
k: self._convert_json(
k, self.Meta.model_fields[k].expand_relationship(v, self), "dumps"
)
for k, v in kwargs.items()
}
values, fields_set, validation_error = pydantic.validate_model(
self, kwargs
)
values, fields_set, validation_error = pydantic.validate_model(self, kwargs)
if validation_error and not pk_only:
raise validation_error
object.__setattr__(self, '__dict__', values)
object.__setattr__(self, '__fields_set__', fields_set)
object.__setattr__(self, "__dict__", values)
object.__setattr__(self, "__fields_set__", fields_set)
# super().__init__(**kwargs)
# self.values = self.__pydantic_model__(**kwargs)
@ -82,58 +83,50 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def __del__(self) -> None:
self.Meta._orm_relationship_manager.deregister(self)
def __setattr__(self, name, value):
def __setattr__(self, name: str, value: Any) -> None:
relation_key = self.get_name(title=True) + "_" + name
if name in self.__slots__:
object.__setattr__(self, name, value)
elif name == 'pk':
elif name == "pk":
object.__setattr__(self, self.Meta.pkname, value)
elif self.Meta._orm_relationship_manager.contains(relation_key, self):
self.Meta.model_fields[name].expand_relationship(value, self)
else:
super().__setattr__(name, value)
value = (
self._convert_json(name, value, "dumps")
if name in self.__fields__
else value
)
super().__setattr__(name, value)
def __getattr__(self, item):
def __getattribute__(self, item: str) -> Any:
if item != "__fields__" and item in self.__fields__:
related = self._extract_related_model_instead_of_field(item)
if related:
return related
value = object.__getattribute__(self, item)
value = self._convert_json(item, value, "loads")
return value
return super().__getattribute__(item)
def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]:
return self._extract_related_model_instead_of_field(item)
def _extract_related_model_instead_of_field(
self, item: str
) -> Optional[Union["Model", List["Model"]]]:
relation_key = self.get_name(title=True) + "_" + item
if self.Meta._orm_relationship_manager.contains(relation_key, self):
return self.Meta._orm_relationship_manager.get(relation_key, self)
# def __setattr__(self, key: str, value: Any) -> None:
# if key in ('_orm_id', '_orm_relationship_manager', '_orm_saved', 'objects', '__model_fields__'):
# return setattr(self, key, value)
# # elif key in self._extract_related_names():
# # value = self._convert_json(key, value, op="dumps")
# # value = self.Meta.model_fields[key].expand_relationship(value, self)
# # relation_key = self.get_name(title=True) + "_" + key
# # if not self.Meta._orm_relationship_manager.contains(relation_key, self):
# # setattr(self.values, key, value)
# else:
# super().__setattr__(key, value)
# def __getattribute__(self, key: str) -> Any:
# if key != 'Meta' and key in self.Meta.model_fields:
# relation_key = self.get_name(title=True) + "_" + key
# if self.Meta._orm_relationship_manager.contains(relation_key, self):
# return self.Meta._orm_relationship_manager.get(relation_key, self)
# item = getattr(self.__fields__, key, None)
# item = self._convert_json(key, item, op="loads")
# return item
# return super().__getattribute__(key)
def __same__(self, other: "Model") -> bool:
if self.__class__ != other.__class__: # pragma no cover
return False
return (self._orm_id == other._orm_id or
self.__dict__ == other.__dict__ or
(self.pk == other.pk and self.pk is not None
))
# def __repr__(self) -> str: # pragma no cover
# return self.values.__repr__()
# @classmethod
# def __get_validators__(cls) -> Callable: # pragma no cover
# yield cls.__pydantic_model__.validate
return (
self._orm_id == other._orm_id
or self.__dict__ == other.__dict__
or (self.pk == other.pk and self.pk is not None)
)
@classmethod
def get_name(cls, title: bool = False, lower: bool = True) -> str:
@ -148,10 +141,6 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def pk(self) -> Any:
return getattr(self, self.Meta.pkname)
@pk.setter
def pk(self, value: Any) -> None:
setattr(self, self.Meta.pkname, value)
@property
def pk_column(self) -> sqlalchemy.Column:
return self.Meta.table.primary_key.columns.values()[0]
@ -160,36 +149,38 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def pk_type(cls) -> Any:
return cls.Meta.model_fields[cls.Meta.pkname].__type__
def dict(
self,
*,
include: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None,
exclude: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None,
by_alias: bool = False,
skip_defaults: bool = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
nested: bool = False
) -> 'DictStrAny': # noqa: A003'
dict_instance = super().dict(include=include,
exclude=self._exclude_related_names_not_required(nested),
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none)
def dict( # noqa A003
self,
*,
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
by_alias: bool = False,
skip_defaults: bool = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
nested: bool = False
) -> "DictStrAny": # noqa: A003'
dict_instance = super().dict(
include=include,
exclude=self._exclude_related_names_not_required(nested),
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
for field in self._extract_related_names():
nested_model = getattr(self, field)
if self.Meta.model_fields[field].virtual and nested:
continue
if isinstance(nested_model, list) and not isinstance(
nested_model, ormar.Model
nested_model, ormar.Model
):
dict_instance[field] = [x.dict(nested=True) for x in nested_model]
elif nested_model is not None:
dict_instance[field] = nested_model.dict(nested=True)
dict_instance[field] = nested_model.dict(nested=True)
return dict_instance
def from_dict(self, value_dict: Dict) -> None:
@ -225,19 +216,21 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def _extract_related_names(cls) -> Set:
related_names = set()
for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field) and issubclass(
field, ForeignKeyField
):
if inspect.isclass(field) and issubclass(field, ForeignKeyField):
related_names.add(name)
return related_names
@classmethod
def _exclude_related_names_not_required(cls, nested:bool=False) -> Set:
def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
if nested:
return cls._extract_related_names()
related_names = set()
for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field) and issubclass(field, ForeignKeyField) and field.nullable:
if (
inspect.isclass(field)
and issubclass(field, ForeignKeyField)
and field.nullable
):
related_names.add(name)
return related_names
@ -267,7 +260,7 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
for field in one.Meta.model_fields.keys():
if isinstance(getattr(one, field), list) and not isinstance(
getattr(one, field), ormar.Model
getattr(one, field), ormar.Model
):
setattr(other, field, getattr(one, field) + getattr(other, field))
elif isinstance(getattr(one, field), ormar.Model):
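
The _convert_json calls introduced above dump JSON fields to strings on assignment and load them back on attribute access. The mechanism in isolation, as a minimal sketch that is not ormar's implementation:

import json
from typing import Any


class JsonBackedAttr:
    _json_fields = {"payload"}

    def __setattr__(self, name: str, value: Any) -> None:
        if name in self._json_fields and not isinstance(value, str):
            value = json.dumps(value)          # store as a JSON string
        object.__setattr__(self, name, value)

    def __getattribute__(self, name: str) -> Any:
        value = object.__getattribute__(self, name)
        if name in object.__getattribute__(self, "_json_fields") and isinstance(value, str):
            return json.loads(value)           # hand back the parsed structure
        return value


obj = JsonBackedAttr()
obj.payload = {"aa": 12}
assert obj.payload == {"aa": 12}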


@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
import databases
import pydantic
import sqlalchemy
from pydantic import BaseConfig, create_model, Extra
from pydantic.fields import ModelField, FieldInfo
from pydantic import BaseConfig
from pydantic.fields import FieldInfo
from ormar import ForeignKey, ModelDefinitionError # noqa I100
from ormar.fields import BaseField
@ -29,23 +29,11 @@ class ModelMeta:
_orm_relationship_manager: RelationshipManager
def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]:
pydantic_fields = {
field_name: (
base_field.__type__,
... if base_field.is_required else base_field.default_value,
)
for field_name, base_field in object_dict.items()
if isinstance(base_field, BaseField)
}
return pydantic_fields
def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None:
child_relation_name = (
field.to.get_name(title=True)
+ "_"
+ (field.related_name or (name.lower() + "s"))
field.to.get_name(title=True)
+ "_"
+ (field.related_name or (name.lower() + "s"))
)
reverse_name = child_relation_name
relation_name = name.lower().title() + "_" + field.to.get_name()
@ -61,34 +49,28 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
parent_model = model_field.to
child = model
if (
child_model_name not in parent_model.__fields__
and child.get_name() not in parent_model.__fields__
child_model_name not in parent_model.__fields__
and child.get_name() not in parent_model.__fields__
):
register_reverse_model_fields(parent_model, child, child_model_name)
def register_reverse_model_fields(
model: Type["Model"], child: Type["Model"], child_model_name: str
model: Type["Model"], child: Type["Model"], child_model_name: str
) -> None:
# model.__fields__[child_model_name] = ModelField(
# name=child_model_name,
# type_=Optional[Union[List[child], child]],
# model_config=child.__config__,
# class_validators=child.__validators__,
# )
model.Meta.model_fields[child_model_name] = ForeignKey(
child, name=child_model_name, virtual=True
)
def sqlalchemy_columns_from_model_fields(
name: str, object_dict: Dict, table_name: str
name: str, object_dict: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
columns = []
pkname = None
model_fields = {
field_name: field
for field_name, field in object_dict['__annotations__'].items()
for field_name, field in object_dict["__annotations__"].items()
if issubclass(field, BaseField)
}
for field_name, field in model_fields.items():
@ -96,7 +78,7 @@ def sqlalchemy_columns_from_model_fields(
if pkname is not None:
raise ModelDefinitionError("Only one primary key column is allowed.")
if field.pydantic_only:
raise ModelDefinitionError('Primary key column cannot be pydantic only')
raise ModelDefinitionError("Primary key column cannot be pydantic only")
pkname = field_name
if not field.pydantic_only:
columns.append(field.get_column(field_name))
@ -112,10 +94,10 @@ def populate_pydantic_default_values(attrs: Dict) -> Dict:
if type_.name is None:
type_.name = field
def_value = type_.default_value()
curr_def_value = attrs.get(field, 'NONE')
if curr_def_value == 'NONE' and isinstance(def_value, FieldInfo):
curr_def_value = attrs.get(field, "NONE")
if curr_def_value == "NONE" and isinstance(def_value, FieldInfo):
attrs[field] = def_value
elif curr_def_value == 'NONE' and type_.nullable:
elif curr_def_value == "NONE" and type_.nullable:
attrs[field] = FieldInfo(default=None)
return attrs
@ -132,18 +114,15 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
attrs['Config'] = get_pydantic_base_orm_config()
attrs["Config"] = get_pydantic_base_orm_config()
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs
)
if hasattr(new_model, 'Meta'):
if attrs.get("__abstract__"):
return new_model
if hasattr(new_model, "Meta"):
annotations = attrs.get("__annotations__") or new_model.__annotations__
attrs["__annotations__"]= annotations
attrs["__annotations__"] = annotations
attrs = populate_pydantic_default_values(attrs)
tablename = name.lower() + "s"
@ -152,18 +131,20 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
# sqlalchemy table creation
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
name, attrs, new_model.Meta.tablename
name, attrs, new_model.Meta.tablename
)
if hasattr(new_model.Meta, "model_fields") and not pkname:
model_fields = new_model.Meta.model_fields
for fieldname, field in new_model.Meta.model_fields.items():
if field.primary_key:
pkname=fieldname
columns = new_model.Meta.table.columns
model_fields = new_model.Meta.model_fields
for fieldname, field in new_model.Meta.model_fields.items():
if field.primary_key:
pkname = fieldname
columns = new_model.Meta.table.columns
if not hasattr(new_model.Meta, "table"):
new_model.Meta.table = sqlalchemy.Table(new_model.Meta.tablename, new_model.Meta.metadata, *columns)
new_model.Meta.table = sqlalchemy.Table(
new_model.Meta.tablename, new_model.Meta.metadata, *columns
)
new_model.Meta.columns = columns
new_model.Meta.pkname = pkname
@ -171,12 +152,6 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
if not pkname:
raise ModelDefinitionError("Table has to have a primary key.")
# pydantic model creation
new_model.Meta.pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
new_model.Meta.pydantic_model = create_model(
name, __config__=get_pydantic_base_orm_config(), **new_model.Meta.pydantic_fields
)
new_model.Meta.model_fields = model_fields
new_model = super().__new__( # type: ignore
mcs, name, bases, attrs
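
Two naming defaults are visible in this metaclass: the table name falls back to the lowercased class name plus "s", and a relation registered without an explicit related_name falls back to a lowercased name plus "s" as well. In plain expressions, with the values invented:

name = "Track"
related_name = None

tablename = name.lower() + "s"                                # "tracks"
reverse_relation_name = related_name or (name.lower() + "s")  # "tracks"
print(tablename, reverse_relation_name)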


@ -13,10 +13,10 @@ class Model(FakePydantic):
@classmethod
def from_row(
cls,
row: sqlalchemy.engine.ResultProxy,
select_related: List = None,
previous_table: str = None,
cls,
row: sqlalchemy.engine.ResultProxy,
select_related: List = None,
previous_table: str = None,
) -> "Model":
item = {}
@ -66,8 +66,8 @@ class Model(FakePydantic):
self_fields.pop(self.Meta.pkname)
expr = (
self.Meta.table.update()
.values(**self_fields)
.where(self.pk_column == getattr(self, self.Meta.pkname))
.values(**self_fields)
.where(self.pk_column == getattr(self, self.Meta.pkname))
)
result = await self.Meta.database.execute(expr)
return result
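
The update expression built above follows the standard SQLAlchemy Core update().values().where() chain; a standalone equivalent with an invented table:

import sqlalchemy

metadata = sqlalchemy.MetaData()
tracks = sqlalchemy.Table(
    "tracks",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("title", sqlalchemy.String(100)),
)

expr = tracks.update().values(title="The Bird").where(tracks.c.id == 1)
print(str(expr))   # UPDATE tracks SET title=:title WHERE tracks.id = :id_1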


@ -20,12 +20,12 @@ class JoinParameters(NamedTuple):
class Query:
def __init__(
self,
model_cls: Type["Model"],
filter_clauses: List,
select_related: List,
limit_count: int,
offset: int,
self,
model_cls: Type["Model"],
filter_clauses: List,
select_related: List,
limit_count: int,
offset: int,
) -> None:
self.query_offset = offset
@ -49,15 +49,15 @@ class Query:
self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")]
self.select_from = self.table
for key in self.model_cls.Meta.model_fields:
if (
not self.model_cls.Meta.model_fields[key].nullable
and isinstance(
self.model_cls.Meta.model_fields[key], ForeignKeyField,
)
and key not in self._select_related
):
self._select_related = [key] + self._select_related
# for key in self.model_cls.Meta.model_fields:
# if (
# not self.model_cls.Meta.model_fields[key].nullable
# and isinstance(
# self.model_cls.Meta.model_fields[key], ForeignKeyField,
# )
# and key not in self._select_related
# ):
# self._select_related = [key] + self._select_related
start_params = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls
@ -79,7 +79,7 @@ class Query:
expr = self._apply_expression_modifiers(expr)
#print(expr.compile(compile_kwargs={"literal_binds": True}))
# print(expr.compile(compile_kwargs={"literal_binds": True}))
self._reset_query_parameters()
return expr, self._select_related
@ -97,12 +97,12 @@ class Query:
@staticmethod
def _field_is_a_foreign_key_and_no_circular_reference(
field: BaseField, field_name: str, rel_part: str
field: Type[BaseField], field_name: str, rel_part: str
) -> bool:
return issubclass(field, ForeignKeyField) and field_name not in rel_part
def _field_qualifies_to_deeper_search(
self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str
self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str
) -> bool:
prev_part_of_related = "__".join(rel_part.split("__")[:-1])
partial_match = any(
@ -112,19 +112,19 @@ class Query:
[x.startswith(rel_part) for x in (self.auto_related + self.already_checked)]
)
return (
(field.virtual and parent_virtual)
or (partial_match and not already_checked)
) or not nested
(field.virtual and parent_virtual)
or (partial_match and not already_checked)
) or not nested
def on_clause(
self, previous_alias: str, alias: str, from_clause: str, to_clause: str,
self, previous_alias: str, alias: str, from_clause: str, to_clause: str,
) -> text:
left_part = f"{alias}_{to_clause}"
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
return text(f"{left_part}={right_part}")
def _build_join_parameters(
self, part: str, join_params: JoinParameters
self, part: str, join_params: JoinParameters
) -> JoinParameters:
model_cls = join_params.model_cls.Meta.model_fields[part].to
to_table = model_cls.Meta.table.name
@ -138,7 +138,8 @@ class Query:
(
v
for k, v in model_cls.Meta.model_fields.items()
if issubclass(v, ForeignKeyField) and v.to == join_params.prev_model
if issubclass(v, ForeignKeyField)
and v.to == join_params.prev_model
),
None,
).name
@ -167,25 +168,28 @@ class Query:
return JoinParameters(prev_model, previous_alias, from_table, model_cls)
def _extract_auto_required_relations(
self,
prev_model: Type["Model"],
rel_part: str = "",
nested: bool = False,
parent_virtual: bool = False,
self,
prev_model: Type["Model"],
rel_part: str = "",
nested: bool = False,
parent_virtual: bool = False,
) -> None:
for field_name, field in prev_model.Meta.model_fields.items():
if self._field_is_a_foreign_key_and_no_circular_reference(
field, field_name, rel_part
field, field_name, rel_part
):
rel_part = field_name if not rel_part else rel_part + "__" + field_name
if not field.nullable:
if rel_part not in self._select_related:
new_related = "__".join(rel_part.split("__")[:-1]) if len(
rel_part.split("__")) > 1 else rel_part
new_related = (
"__".join(rel_part.split("__")[:-1])
if len(rel_part.split("__")) > 1
else rel_part
)
self.auto_related.append(new_related)
rel_part = ""
elif self._field_qualifies_to_deeper_search(
field, parent_virtual, nested, rel_part
field, parent_virtual, nested, rel_part
):
self._extract_auto_required_relations(
@ -207,7 +211,7 @@ class Query:
self._select_related = new_joins + self.auto_related
def _apply_expression_modifiers(
self, expr: sqlalchemy.sql.select
self, expr: sqlalchemy.sql.select
) -> sqlalchemy.sql.select:
if self.filter_clauses:
if len(self.filter_clauses) == 1:
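
The on_clause() helper above only formats an aliased equality for the raw JOIN text; its output for made-up aliases and column references looks like this:

previous_alias, alias = "ab12", "cd34"
from_clause, to_clause = "tracks.album", "albums.id"

left_part = f"{alias}_{to_clause}"
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
print(f"{left_part}={right_part}")    # cd34_albums.id=ab12_tracks.album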


@ -14,12 +14,12 @@ if TYPE_CHECKING: # pragma no cover
class QuerySet:
def __init__(
self,
model_cls: Type["Model"] = None,
filter_clauses: List = None,
select_related: List = None,
limit_count: int = None,
offset: int = None,
self,
model_cls: Type["Model"] = None,
filter_clauses: List = None,
select_related: List = None,
limit_count: int = None,
offset: int = None,
) -> None:
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
@ -151,9 +151,9 @@ class QuerySet:
pkname = self.model_cls.Meta.pkname
pk = self.model_cls.Meta.model_fields[pkname]
if (
pkname in new_kwargs
and new_kwargs.get(pkname) is None
and (pk.nullable or pk.autoincrement)
pkname in new_kwargs
and new_kwargs.get(pkname) is None
and (pk.nullable or pk.autoincrement)
):
del new_kwargs[pkname]
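
The pk-dropping rule above removes the primary key from the insert kwargs when it is explicitly None and the column is nullable or autoincrementing. The rule in isolation, with the field flags assumed for the example:

new_kwargs = {"id": None, "title": "The Bird"}
pkname = "id"
pk_nullable, pk_autoincrement = False, True

if (
    pkname in new_kwargs
    and new_kwargs.get(pkname) is None
    and (pk_nullable or pk_autoincrement)
):
    del new_kwargs[pkname]

print(new_kwargs)    # {'title': 'The Bird'}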


@ -5,7 +5,6 @@ from random import choices
from typing import List, TYPE_CHECKING, Union
from weakref import proxy
from ormar import ForeignKey
from ormar.fields.foreign_key import ForeignKeyField
if TYPE_CHECKING: # pragma no cover
@ -22,7 +21,11 @@ class RelationshipManager:
self._aliases = dict()
def add_relation_type(
self, relations_key: str, reverse_key: str, field: ForeignKeyField, table_name: str
self,
relations_key: str,
reverse_key: str,
field: ForeignKeyField,
table_name: str,
) -> None:
if relations_key not in self._relations:
self._relations[relations_key] = {"type": "primary"}

scripts/clean.sh     Normal file → Executable file (0 lines changed)
scripts/publish.sh   Normal file → Executable file (0 lines changed)
scripts/test.sh      Normal file → Executable file (0 lines changed)


@ -5,6 +5,7 @@ from pydantic import ValidationError
import ormar
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
from ormar.fields.foreign_key import ForeignKeyField
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
@ -120,6 +121,7 @@ async def test_model_crud():
track = await Track.objects.get(title="The Bird")
assert track.album.pk == album.pk
assert isinstance(track.album, ormar.Model)
assert track.album.name is None
await track.album.load()
assert track.album.name == "Malibu"


@ -75,9 +75,18 @@ def test_model_attribute_access(example):
example.test = 12
assert example.test == 12
example._orm_saved = True
assert example._orm_saved
def test_model_attribute_json_access(example):
example.test_json = dict(aa=12)
assert example.test_json == dict(aa=12)
def test_non_existing_attr(example):
with pytest.raises(ValueError):
example.new_attr=12
with pytest.raises(ValueError):
example.new_attr = 12
def test_primary_key_access_and_setting(example):


@ -11,6 +11,16 @@ database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class JsonSample(ormar.Model):
class Meta:
tablename = "jsons"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
test_json: ormar.JSON(nullable=True)
class User(ormar.Model):
class Meta:
tablename = "users"
@ -56,6 +66,18 @@ def test_model_pk():
assert user.id == 1
@pytest.mark.asyncio
async def test_json_column():
async with database:
await JsonSample.objects.create(test_json=dict(aa=12))
await JsonSample.objects.create(test_json='{"aa": 12}')
items = await JsonSample.objects.all()
assert len(items) == 2
assert items[0].test_json == dict(aa=12)
assert items[1].test_json == dict(aa=12)
@pytest.mark.asyncio
async def test_model_crud():
async with database: