some cleanup and tests

This commit is contained in:
collerek
2020-08-23 12:54:58 +02:00
parent 08e251efdb
commit 53384879a9
17 changed files with 370 additions and 396 deletions

BIN
.coverage Normal file

Binary file not shown.

View File

@ -1,6 +1,5 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING from typing import Any, List, Optional, TYPE_CHECKING
import pydantic
import sqlalchemy import sqlalchemy
from pydantic import Field from pydantic import Field
@ -10,13 +9,6 @@ if TYPE_CHECKING: # pragma no cover
from ormar.models import Model from ormar.models import Model
def prepare_validator(type_):
def validate_model_field(value):
return isinstance(value, type_)
return validate_model_field
class BaseField: class BaseField:
__type__ = None __type__ = None
@ -34,13 +26,7 @@ class BaseField:
server_default: Any server_default: Any
@classmethod @classmethod
def is_required(cls) -> bool: def default_value(cls) -> Optional[Field]:
return (
not cls.nullable and not cls.has_default() and not cls.is_auto_primary_key()
)
@classmethod
def default_value(cls):
if cls.is_auto_primary_key(): if cls.is_auto_primary_key():
return Field(default=None) return Field(default=None)
if cls.has_default(): if cls.has_default():
@ -52,7 +38,7 @@ class BaseField:
return None return None
@classmethod @classmethod
def has_default(cls): def has_default(cls) -> bool:
return cls.default is not None or cls.server_default is not None return cls.default is not None or cls.server_default is not None
@classmethod @classmethod

View File

@ -1,27 +0,0 @@
from typing import Any, TYPE_CHECKING, Type
from ormar import ModelDefinitionError
if TYPE_CHECKING: # pragma no cover
from ormar.fields import BaseField
class RequiredParams:
def __init__(self, *args: str) -> None:
self._required = list(args)
def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]:
old_init = model_field_class.__init__
model_field_class._old_init = old_init
def __init__(instance: "BaseField", **kwargs: Any) -> None:
super(instance.__class__, instance).__init__(**kwargs)
for arg in self._required:
if arg not in kwargs:
raise ModelDefinitionError(
f"{instance.__class__.__name__} field requires parameter: {arg}"
)
setattr(instance, arg, kwargs.pop(arg))
model_field_class.__init__ = __init__
return model_field_class

View File

@ -1,7 +1,6 @@
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union, Callable from typing import Any, Callable, List, Optional, TYPE_CHECKING, Type, Union
import sqlalchemy import sqlalchemy
from pydantic import BaseModel
import ormar # noqa I101 import ormar # noqa I101
from ormar.exceptions import RelationshipInstanceError from ormar.exceptions import RelationshipInstanceError
@ -13,8 +12,7 @@ if TYPE_CHECKING: # pragma no cover
def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
init_dict = { init_dict = {
**{fk.Meta.pkname: pk or -1, **{fk.Meta.pkname: pk or -1, "__pk_only__": True},
'__pk_only__': True},
**{ **{
k: create_dummy_instance(v.to) k: create_dummy_instance(v.to)
for k, v in fk.Meta.model_fields.items() for k, v in fk.Meta.model_fields.items()
@ -24,7 +22,12 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
return fk(**init_dict) return fk(**init_dict)
def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = True, def ForeignKey(
to: "Model",
*,
name: str = None,
unique: bool = False,
nullable: bool = True,
related_name: str = None, related_name: str = None,
virtual: bool = False, virtual: bool = False,
) -> Type[object]: ) -> Type[object]:
@ -43,7 +46,7 @@ def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = T
index=False, index=False,
pydantic_only=False, pydantic_only=False,
default=None, default=None,
server_default=None server_default=None,
) )
return type("ForeignKey", (ForeignKeyField, BaseField), namespace) return type("ForeignKey", (ForeignKeyField, BaseField), namespace)
@ -59,17 +62,17 @@ class ForeignKeyField(BaseField):
yield cls.validate yield cls.validate
@classmethod @classmethod
def validate(cls, v: Any) -> Any: def validate(cls, value: Any) -> Any:
return v return value
@property # @property
def __type__(self) -> Type[BaseModel]: # def __type__(self) -> Type[BaseModel]:
return self.to.__pydantic_model__ # return self.to.__pydantic_model__
@classmethod # @classmethod
def get_column_type(cls) -> sqlalchemy.Column: # def get_column_type(cls) -> sqlalchemy.Column:
to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname] # to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname]
return to_column.column_type # return to_column.column_type
@classmethod @classmethod
def _extract_model_from_sequence( def _extract_model_from_sequence(

View File

@ -1,17 +1,19 @@
import datetime import datetime
import decimal import decimal
import re import re
from typing import Type, Any, Optional from typing import Any, Optional, Type
import pydantic import pydantic
import sqlalchemy import sqlalchemy
from pydantic import Json from pydantic import Json
from ormar import ModelDefinitionError from ormar import ModelDefinitionError # noqa I101
from ormar.fields.base import BaseField # noqa I101 from ormar.fields.base import BaseField # noqa I101
def is_field_nullable(nullable: Optional[bool], default: Any, server_default: Any) -> bool: def is_field_nullable(
nullable: Optional[bool], default: Any, server_default: Any
) -> bool:
if nullable is None: if nullable is None:
return default is not None or server_default is not None return default is not None or server_default is not None
return nullable return nullable
@ -32,10 +34,10 @@ def String(
regex: str = None, regex: str = None,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[str]: ) -> Type[str]:
if max_length is None or max_length <= 0: if max_length is None or max_length <= 0:
raise ModelDefinitionError(f'Parameter max_length is required for field String') raise ModelDefinitionError("Parameter max_length is required for field String")
namespace = dict( namespace = dict(
__type__=str, __type__=str,
@ -54,7 +56,7 @@ def String(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("String", (pydantic.ConstrainedStr, BaseField), namespace) return type("String", (pydantic.ConstrainedStr, BaseField), namespace)
@ -73,7 +75,7 @@ def Integer(
multiple_of: int = None, multiple_of: int = None,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[int]: ) -> Type[int]:
namespace = dict( namespace = dict(
__type__=int, __type__=int,
@ -89,7 +91,7 @@ def Integer(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=autoincrement if autoincrement is not None else primary_key autoincrement=autoincrement if autoincrement is not None else primary_key,
) )
return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace) return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace)
@ -105,7 +107,7 @@ def Text(
strip_whitespace: bool = False, strip_whitespace: bool = False,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[str]: ) -> Type[str]:
namespace = dict( namespace = dict(
__type__=str, __type__=str,
@ -120,7 +122,7 @@ def Text(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("Text", (pydantic.ConstrainedStr, BaseField), namespace) return type("Text", (pydantic.ConstrainedStr, BaseField), namespace)
@ -138,7 +140,7 @@ def Float(
multiple_of: int = None, multiple_of: int = None,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[int]: ) -> Type[int]:
namespace = dict( namespace = dict(
__type__=float, __type__=float,
@ -154,7 +156,7 @@ def Float(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace) return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace)
@ -168,7 +170,7 @@ def Boolean(
unique: bool = False, unique: bool = False,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[bool]: ) -> Type[bool]:
namespace = dict( namespace = dict(
__type__=bool, __type__=bool,
@ -181,7 +183,7 @@ def Boolean(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("Boolean", (int, BaseField), namespace) return type("Boolean", (int, BaseField), namespace)
@ -195,7 +197,7 @@ def DateTime(
unique: bool = False, unique: bool = False,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[datetime.datetime]: ) -> Type[datetime.datetime]:
namespace = dict( namespace = dict(
__type__=datetime.datetime, __type__=datetime.datetime,
@ -208,7 +210,7 @@ def DateTime(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("DateTime", (datetime.datetime, BaseField), namespace) return type("DateTime", (datetime.datetime, BaseField), namespace)
@ -222,7 +224,7 @@ def Date(
unique: bool = False, unique: bool = False,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[datetime.date]: ) -> Type[datetime.date]:
namespace = dict( namespace = dict(
__type__=datetime.date, __type__=datetime.date,
@ -235,7 +237,7 @@ def Date(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("Date", (datetime.date, BaseField), namespace) return type("Date", (datetime.date, BaseField), namespace)
@ -249,7 +251,7 @@ def Time(
unique: bool = False, unique: bool = False,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[datetime.time]: ) -> Type[datetime.time]:
namespace = dict( namespace = dict(
__type__=datetime.time, __type__=datetime.time,
@ -262,7 +264,7 @@ def Time(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("Time", (datetime.time, BaseField), namespace) return type("Time", (datetime.time, BaseField), namespace)
@ -276,7 +278,7 @@ def JSON(
unique: bool = False, unique: bool = False,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[Json]: ) -> Type[Json]:
namespace = dict( namespace = dict(
__type__=pydantic.Json, __type__=pydantic.Json,
@ -289,7 +291,7 @@ def JSON(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("JSON", (pydantic.Json, BaseField), namespace) return type("JSON", (pydantic.Json, BaseField), namespace)
@ -308,7 +310,7 @@ def BigInteger(
multiple_of: int = None, multiple_of: int = None,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[int]: ) -> Type[int]:
namespace = dict( namespace = dict(
__type__=int, __type__=int,
@ -324,7 +326,7 @@ def BigInteger(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=autoincrement if autoincrement is not None else primary_key autoincrement=autoincrement if autoincrement is not None else primary_key,
) )
return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace) return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace)
@ -345,10 +347,12 @@ def Decimal(
decimal_places: int = None, decimal_places: int = None,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
): ) -> Type[decimal.Decimal]:
if precision is None or precision < 0 or scale is None or scale < 0: if precision is None or precision < 0 or scale is None or scale < 0:
raise ModelDefinitionError(f'Parameters scale and precision are required for field Decimal') raise ModelDefinitionError(
"Parameters scale and precision are required for field Decimal"
)
namespace = dict( namespace = dict(
__type__=decimal.Decimal, __type__=decimal.Decimal,
@ -368,6 +372,6 @@ def Decimal(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace) return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace)

View File

@ -2,16 +2,17 @@ import inspect
import json import json
import uuid import uuid
from typing import ( from typing import (
AbstractSet,
Any, Any,
Callable,
Dict, Dict,
List, List,
Mapping,
Optional, Optional,
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
Type, Type,
TypeVar, TypeVar,
Union, AbstractSet, Mapping, Union,
) )
import databases import databases
@ -20,10 +21,9 @@ import sqlalchemy
from pydantic import BaseModel from pydantic import BaseModel
import ormar # noqa I100 import ormar # noqa I100
from ormar import ForeignKey
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMetaclass, ModelMeta from ormar.models.metaclass import ModelMeta, ModelMetaclass
from ormar.relations import RelationshipManager from ormar.relations import RelationshipManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -39,7 +39,8 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
# FakePydantic inherits from list in order to be treated as # FakePydantic inherits from list in order to be treated as
# request.Body parameter in fastapi routes, # request.Body parameter in fastapi routes,
# inheriting from pydantic.BaseModel causes metaclass conflicts # inheriting from pydantic.BaseModel causes metaclass conflicts
__slots__ = ('_orm_id', '_orm_saved') __slots__ = ("_orm_id", "_orm_saved")
__abstract__ = True
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, TypeVar[BaseField]] __model_fields__: Dict[str, TypeVar[BaseField]]
@ -63,18 +64,18 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
if "pk" in kwargs: if "pk" in kwargs:
kwargs[self.Meta.pkname] = kwargs.pop("pk") kwargs[self.Meta.pkname] = kwargs.pop("pk")
kwargs = { kwargs = {
k: self.Meta.model_fields[k].expand_relationship(v, self) k: self._convert_json(
k, self.Meta.model_fields[k].expand_relationship(v, self), "dumps"
)
for k, v in kwargs.items() for k, v in kwargs.items()
} }
values, fields_set, validation_error = pydantic.validate_model( values, fields_set, validation_error = pydantic.validate_model(self, kwargs)
self, kwargs
)
if validation_error and not pk_only: if validation_error and not pk_only:
raise validation_error raise validation_error
object.__setattr__(self, '__dict__', values) object.__setattr__(self, "__dict__", values)
object.__setattr__(self, '__fields_set__', fields_set) object.__setattr__(self, "__fields_set__", fields_set)
# super().__init__(**kwargs) # super().__init__(**kwargs)
# self.values = self.__pydantic_model__(**kwargs) # self.values = self.__pydantic_model__(**kwargs)
@ -82,58 +83,50 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def __del__(self) -> None: def __del__(self) -> None:
self.Meta._orm_relationship_manager.deregister(self) self.Meta._orm_relationship_manager.deregister(self)
def __setattr__(self, name, value): def __setattr__(self, name: str, value: Any) -> None:
relation_key = self.get_name(title=True) + "_" + name relation_key = self.get_name(title=True) + "_" + name
if name in self.__slots__: if name in self.__slots__:
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
elif name == 'pk': elif name == "pk":
object.__setattr__(self, self.Meta.pkname, value) object.__setattr__(self, self.Meta.pkname, value)
elif self.Meta._orm_relationship_manager.contains(relation_key, self): elif self.Meta._orm_relationship_manager.contains(relation_key, self):
self.Meta.model_fields[name].expand_relationship(value, self) self.Meta.model_fields[name].expand_relationship(value, self)
else: else:
value = (
self._convert_json(name, value, "dumps")
if name in self.__fields__
else value
)
super().__setattr__(name, value) super().__setattr__(name, value)
def __getattr__(self, item): def __getattribute__(self, item: str) -> Any:
if item != "__fields__" and item in self.__fields__:
related = self._extract_related_model_instead_of_field(item)
if related:
return related
value = object.__getattribute__(self, item)
value = self._convert_json(item, value, "loads")
return value
return super().__getattribute__(item)
def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]:
return self._extract_related_model_instead_of_field(item)
def _extract_related_model_instead_of_field(
self, item: str
) -> Optional[Union["Model", List["Model"]]]:
relation_key = self.get_name(title=True) + "_" + item relation_key = self.get_name(title=True) + "_" + item
if self.Meta._orm_relationship_manager.contains(relation_key, self): if self.Meta._orm_relationship_manager.contains(relation_key, self):
return self.Meta._orm_relationship_manager.get(relation_key, self) return self.Meta._orm_relationship_manager.get(relation_key, self)
# def __setattr__(self, key: str, value: Any) -> None:
# if key in ('_orm_id', '_orm_relationship_manager', '_orm_saved', 'objects', '__model_fields__'):
# return setattr(self, key, value)
# # elif key in self._extract_related_names():
# # value = self._convert_json(key, value, op="dumps")
# # value = self.Meta.model_fields[key].expand_relationship(value, self)
# # relation_key = self.get_name(title=True) + "_" + key
# # if not self.Meta._orm_relationship_manager.contains(relation_key, self):
# # setattr(self.values, key, value)
# else:
# super().__setattr__(key, value)
# def __getattribute__(self, key: str) -> Any:
# if key != 'Meta' and key in self.Meta.model_fields:
# relation_key = self.get_name(title=True) + "_" + key
# if self.Meta._orm_relationship_manager.contains(relation_key, self):
# return self.Meta._orm_relationship_manager.get(relation_key, self)
# item = getattr(self.__fields__, key, None)
# item = self._convert_json(key, item, op="loads")
# return item
# return super().__getattribute__(key)
def __same__(self, other: "Model") -> bool: def __same__(self, other: "Model") -> bool:
if self.__class__ != other.__class__: # pragma no cover if self.__class__ != other.__class__: # pragma no cover
return False return False
return (self._orm_id == other._orm_id or return (
self.__dict__ == other.__dict__ or self._orm_id == other._orm_id
(self.pk == other.pk and self.pk is not None or self.__dict__ == other.__dict__
)) or (self.pk == other.pk and self.pk is not None)
)
# def __repr__(self) -> str: # pragma no cover
# return self.values.__repr__()
# @classmethod
# def __get_validators__(cls) -> Callable: # pragma no cover
# yield cls.__pydantic_model__.validate
@classmethod @classmethod
def get_name(cls, title: bool = False, lower: bool = True) -> str: def get_name(cls, title: bool = False, lower: bool = True) -> str:
@ -148,10 +141,6 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def pk(self) -> Any: def pk(self) -> Any:
return getattr(self, self.Meta.pkname) return getattr(self, self.Meta.pkname)
@pk.setter
def pk(self, value: Any) -> None:
setattr(self, self.Meta.pkname, value)
@property @property
def pk_column(self) -> sqlalchemy.Column: def pk_column(self) -> sqlalchemy.Column:
return self.Meta.table.primary_key.columns.values()[0] return self.Meta.table.primary_key.columns.values()[0]
@ -160,25 +149,27 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def pk_type(cls) -> Any: def pk_type(cls) -> Any:
return cls.Meta.model_fields[cls.Meta.pkname].__type__ return cls.Meta.model_fields[cls.Meta.pkname].__type__
def dict( def dict( # noqa A003
self, self,
*, *,
include: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
by_alias: bool = False, by_alias: bool = False,
skip_defaults: bool = None, skip_defaults: bool = None,
exclude_unset: bool = False, exclude_unset: bool = False,
exclude_defaults: bool = False, exclude_defaults: bool = False,
exclude_none: bool = False, exclude_none: bool = False,
nested: bool = False nested: bool = False
) -> 'DictStrAny': # noqa: A003' ) -> "DictStrAny": # noqa: A003'
dict_instance = super().dict(include=include, dict_instance = super().dict(
include=include,
exclude=self._exclude_related_names_not_required(nested), exclude=self._exclude_related_names_not_required(nested),
by_alias=by_alias, by_alias=by_alias,
skip_defaults=skip_defaults, skip_defaults=skip_defaults,
exclude_unset=exclude_unset, exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults, exclude_defaults=exclude_defaults,
exclude_none=exclude_none) exclude_none=exclude_none,
)
for field in self._extract_related_names(): for field in self._extract_related_names():
nested_model = getattr(self, field) nested_model = getattr(self, field)
@ -225,9 +216,7 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def _extract_related_names(cls) -> Set: def _extract_related_names(cls) -> Set:
related_names = set() related_names = set()
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field) and issubclass( if inspect.isclass(field) and issubclass(field, ForeignKeyField):
field, ForeignKeyField
):
related_names.add(name) related_names.add(name)
return related_names return related_names
@ -237,7 +226,11 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
return cls._extract_related_names() return cls._extract_related_names()
related_names = set() related_names = set()
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field) and issubclass(field, ForeignKeyField) and field.nullable: if (
inspect.isclass(field)
and issubclass(field, ForeignKeyField)
and field.nullable
):
related_names.add(name) related_names.add(name)
return related_names return related_names

View File

@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
import databases import databases
import pydantic import pydantic
import sqlalchemy import sqlalchemy
from pydantic import BaseConfig, create_model, Extra from pydantic import BaseConfig
from pydantic.fields import ModelField, FieldInfo from pydantic.fields import FieldInfo
from ormar import ForeignKey, ModelDefinitionError # noqa I100 from ormar import ForeignKey, ModelDefinitionError # noqa I100
from ormar.fields import BaseField from ormar.fields import BaseField
@ -29,18 +29,6 @@ class ModelMeta:
_orm_relationship_manager: RelationshipManager _orm_relationship_manager: RelationshipManager
def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]:
pydantic_fields = {
field_name: (
base_field.__type__,
... if base_field.is_required else base_field.default_value,
)
for field_name, base_field in object_dict.items()
if isinstance(base_field, BaseField)
}
return pydantic_fields
def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None:
child_relation_name = ( child_relation_name = (
field.to.get_name(title=True) field.to.get_name(title=True)
@ -70,12 +58,6 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
def register_reverse_model_fields( def register_reverse_model_fields(
model: Type["Model"], child: Type["Model"], child_model_name: str model: Type["Model"], child: Type["Model"], child_model_name: str
) -> None: ) -> None:
# model.__fields__[child_model_name] = ModelField(
# name=child_model_name,
# type_=Optional[Union[List[child], child]],
# model_config=child.__config__,
# class_validators=child.__validators__,
# )
model.Meta.model_fields[child_model_name] = ForeignKey( model.Meta.model_fields[child_model_name] = ForeignKey(
child, name=child_model_name, virtual=True child, name=child_model_name, virtual=True
) )
@ -88,7 +70,7 @@ def sqlalchemy_columns_from_model_fields(
pkname = None pkname = None
model_fields = { model_fields = {
field_name: field field_name: field
for field_name, field in object_dict['__annotations__'].items() for field_name, field in object_dict["__annotations__"].items()
if issubclass(field, BaseField) if issubclass(field, BaseField)
} }
for field_name, field in model_fields.items(): for field_name, field in model_fields.items():
@ -96,7 +78,7 @@ def sqlalchemy_columns_from_model_fields(
if pkname is not None: if pkname is not None:
raise ModelDefinitionError("Only one primary key column is allowed.") raise ModelDefinitionError("Only one primary key column is allowed.")
if field.pydantic_only: if field.pydantic_only:
raise ModelDefinitionError('Primary key column cannot be pydantic only') raise ModelDefinitionError("Primary key column cannot be pydantic only")
pkname = field_name pkname = field_name
if not field.pydantic_only: if not field.pydantic_only:
columns.append(field.get_column(field_name)) columns.append(field.get_column(field_name))
@ -112,10 +94,10 @@ def populate_pydantic_default_values(attrs: Dict) -> Dict:
if type_.name is None: if type_.name is None:
type_.name = field type_.name = field
def_value = type_.default_value() def_value = type_.default_value()
curr_def_value = attrs.get(field, 'NONE') curr_def_value = attrs.get(field, "NONE")
if curr_def_value == 'NONE' and isinstance(def_value, FieldInfo): if curr_def_value == "NONE" and isinstance(def_value, FieldInfo):
attrs[field] = def_value attrs[field] = def_value
elif curr_def_value == 'NONE' and type_.nullable: elif curr_def_value == "NONE" and type_.nullable:
attrs[field] = FieldInfo(default=None) attrs[field] = FieldInfo(default=None)
return attrs return attrs
@ -132,15 +114,12 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
class ModelMetaclass(pydantic.main.ModelMetaclass): class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
attrs['Config'] = get_pydantic_base_orm_config() attrs["Config"] = get_pydantic_base_orm_config()
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
) )
if hasattr(new_model, 'Meta'): if hasattr(new_model, "Meta"):
if attrs.get("__abstract__"):
return new_model
annotations = attrs.get("__annotations__") or new_model.__annotations__ annotations = attrs.get("__annotations__") or new_model.__annotations__
attrs["__annotations__"] = annotations attrs["__annotations__"] = annotations
@ -163,7 +142,9 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
columns = new_model.Meta.table.columns columns = new_model.Meta.table.columns
if not hasattr(new_model.Meta, "table"): if not hasattr(new_model.Meta, "table"):
new_model.Meta.table = sqlalchemy.Table(new_model.Meta.tablename, new_model.Meta.metadata, *columns) new_model.Meta.table = sqlalchemy.Table(
new_model.Meta.tablename, new_model.Meta.metadata, *columns
)
new_model.Meta.columns = columns new_model.Meta.columns = columns
new_model.Meta.pkname = pkname new_model.Meta.pkname = pkname
@ -171,12 +152,6 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
if not pkname: if not pkname:
raise ModelDefinitionError("Table has to have a primary key.") raise ModelDefinitionError("Table has to have a primary key.")
# pydantic model creation
new_model.Meta.pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
new_model.Meta.pydantic_model = create_model(
name, __config__=get_pydantic_base_orm_config(), **new_model.Meta.pydantic_fields
)
new_model.Meta.model_fields = model_fields new_model.Meta.model_fields = model_fields
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs

View File

@ -49,15 +49,15 @@ class Query:
self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")] self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")]
self.select_from = self.table self.select_from = self.table
for key in self.model_cls.Meta.model_fields: # for key in self.model_cls.Meta.model_fields:
if ( # if (
not self.model_cls.Meta.model_fields[key].nullable # not self.model_cls.Meta.model_fields[key].nullable
and isinstance( # and isinstance(
self.model_cls.Meta.model_fields[key], ForeignKeyField, # self.model_cls.Meta.model_fields[key], ForeignKeyField,
) # )
and key not in self._select_related # and key not in self._select_related
): # ):
self._select_related = [key] + self._select_related # self._select_related = [key] + self._select_related
start_params = JoinParameters( start_params = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls self.model_cls, "", self.table.name, self.model_cls
@ -97,7 +97,7 @@ class Query:
@staticmethod @staticmethod
def _field_is_a_foreign_key_and_no_circular_reference( def _field_is_a_foreign_key_and_no_circular_reference(
field: BaseField, field_name: str, rel_part: str field: Type[BaseField], field_name: str, rel_part: str
) -> bool: ) -> bool:
return issubclass(field, ForeignKeyField) and field_name not in rel_part return issubclass(field, ForeignKeyField) and field_name not in rel_part
@ -138,7 +138,8 @@ class Query:
( (
v v
for k, v in model_cls.Meta.model_fields.items() for k, v in model_cls.Meta.model_fields.items()
if issubclass(v, ForeignKeyField) and v.to == join_params.prev_model if issubclass(v, ForeignKeyField)
and v.to == join_params.prev_model
), ),
None, None,
).name ).name
@ -180,8 +181,11 @@ class Query:
rel_part = field_name if not rel_part else rel_part + "__" + field_name rel_part = field_name if not rel_part else rel_part + "__" + field_name
if not field.nullable: if not field.nullable:
if rel_part not in self._select_related: if rel_part not in self._select_related:
new_related = "__".join(rel_part.split("__")[:-1]) if len( new_related = (
rel_part.split("__")) > 1 else rel_part "__".join(rel_part.split("__")[:-1])
if len(rel_part.split("__")) > 1
else rel_part
)
self.auto_related.append(new_related) self.auto_related.append(new_related)
rel_part = "" rel_part = ""
elif self._field_qualifies_to_deeper_search( elif self._field_qualifies_to_deeper_search(

View File

@ -5,7 +5,6 @@ from random import choices
from typing import List, TYPE_CHECKING, Union from typing import List, TYPE_CHECKING, Union
from weakref import proxy from weakref import proxy
from ormar import ForeignKey
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -22,7 +21,11 @@ class RelationshipManager:
self._aliases = dict() self._aliases = dict()
def add_relation_type( def add_relation_type(
self, relations_key: str, reverse_key: str, field: ForeignKeyField, table_name: str self,
relations_key: str,
reverse_key: str,
field: ForeignKeyField,
table_name: str,
) -> None: ) -> None:
if relations_key not in self._relations: if relations_key not in self._relations:
self._relations[relations_key] = {"type": "primary"} self._relations[relations_key] = {"type": "primary"}

0
scripts/clean.sh Normal file → Executable file
View File

0
scripts/publish.sh Normal file → Executable file
View File

0
scripts/test.sh Normal file → Executable file
View File

View File

@ -5,6 +5,7 @@ from pydantic import ValidationError
import ormar import ormar
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
from ormar.fields.foreign_key import ForeignKeyField
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
@ -120,6 +121,7 @@ async def test_model_crud():
track = await Track.objects.get(title="The Bird") track = await Track.objects.get(title="The Bird")
assert track.album.pk == album.pk assert track.album.pk == album.pk
assert isinstance(track.album, ormar.Model)
assert track.album.name is None assert track.album.name is None
await track.album.load() await track.album.load()
assert track.album.name == "Malibu" assert track.album.name == "Malibu"

View File

@ -75,6 +75,15 @@ def test_model_attribute_access(example):
example.test = 12 example.test = 12
assert example.test == 12 assert example.test == 12
example._orm_saved = True
assert example._orm_saved
def test_model_attribute_json_access(example):
example.test_json = dict(aa=12)
assert example.test_json == dict(aa=12)
def test_non_existing_attr(example): def test_non_existing_attr(example):
with pytest.raises(ValueError): with pytest.raises(ValueError):
example.new_attr = 12 example.new_attr = 12

View File

@ -11,6 +11,16 @@ database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
class JsonSample(ormar.Model):
class Meta:
tablename = "jsons"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
test_json: ormar.JSON(nullable=True)
class User(ormar.Model): class User(ormar.Model):
class Meta: class Meta:
tablename = "users" tablename = "users"
@ -56,6 +66,18 @@ def test_model_pk():
assert user.id == 1 assert user.id == 1
@pytest.mark.asyncio
async def test_json_column():
async with database:
await JsonSample.objects.create(test_json=dict(aa=12))
await JsonSample.objects.create(test_json='{"aa": 12}')
items = await JsonSample.objects.all()
assert len(items) == 2
assert items[0].test_json == dict(aa=12)
assert items[1].test_json == dict(aa=12)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_crud(): async def test_model_crud():
async with database: async with database: