some cleanup and tests

This commit is contained in:
collerek
2020-08-23 12:54:58 +02:00
parent 08e251efdb
commit 53384879a9
17 changed files with 370 additions and 396 deletions

BIN
.coverage Normal file

Binary file not shown.

View File

@ -1,6 +1,5 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING from typing import Any, List, Optional, TYPE_CHECKING
import pydantic
import sqlalchemy import sqlalchemy
from pydantic import Field from pydantic import Field
@ -10,13 +9,6 @@ if TYPE_CHECKING: # pragma no cover
from ormar.models import Model from ormar.models import Model
def prepare_validator(type_):
def validate_model_field(value):
return isinstance(value, type_)
return validate_model_field
class BaseField: class BaseField:
__type__ = None __type__ = None
@ -34,13 +26,7 @@ class BaseField:
server_default: Any server_default: Any
@classmethod @classmethod
def is_required(cls) -> bool: def default_value(cls) -> Optional[Field]:
return (
not cls.nullable and not cls.has_default() and not cls.is_auto_primary_key()
)
@classmethod
def default_value(cls):
if cls.is_auto_primary_key(): if cls.is_auto_primary_key():
return Field(default=None) return Field(default=None)
if cls.has_default(): if cls.has_default():
@ -52,7 +38,7 @@ class BaseField:
return None return None
@classmethod @classmethod
def has_default(cls): def has_default(cls) -> bool:
return cls.default is not None or cls.server_default is not None return cls.default is not None or cls.server_default is not None
@classmethod @classmethod

View File

@ -1,27 +0,0 @@
from typing import Any, TYPE_CHECKING, Type
from ormar import ModelDefinitionError
if TYPE_CHECKING: # pragma no cover
from ormar.fields import BaseField
class RequiredParams:
def __init__(self, *args: str) -> None:
self._required = list(args)
def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]:
old_init = model_field_class.__init__
model_field_class._old_init = old_init
def __init__(instance: "BaseField", **kwargs: Any) -> None:
super(instance.__class__, instance).__init__(**kwargs)
for arg in self._required:
if arg not in kwargs:
raise ModelDefinitionError(
f"{instance.__class__.__name__} field requires parameter: {arg}"
)
setattr(instance, arg, kwargs.pop(arg))
model_field_class.__init__ = __init__
return model_field_class

View File

@ -1,7 +1,6 @@
from typing import Any, List, Optional, TYPE_CHECKING, Type, Union, Callable from typing import Any, Callable, List, Optional, TYPE_CHECKING, Type, Union
import sqlalchemy import sqlalchemy
from pydantic import BaseModel
import ormar # noqa I101 import ormar # noqa I101
from ormar.exceptions import RelationshipInstanceError from ormar.exceptions import RelationshipInstanceError
@ -13,8 +12,7 @@ if TYPE_CHECKING: # pragma no cover
def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
init_dict = { init_dict = {
**{fk.Meta.pkname: pk or -1, **{fk.Meta.pkname: pk or -1, "__pk_only__": True},
'__pk_only__': True},
**{ **{
k: create_dummy_instance(v.to) k: create_dummy_instance(v.to)
for k, v in fk.Meta.model_fields.items() for k, v in fk.Meta.model_fields.items()
@ -24,10 +22,15 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model":
return fk(**init_dict) return fk(**init_dict)
def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = True, def ForeignKey(
related_name: str = None, to: "Model",
virtual: bool = False, *,
) -> Type[object]: name: str = None,
unique: bool = False,
nullable: bool = True,
related_name: str = None,
virtual: bool = False,
) -> Type[object]:
fk_string = to.Meta.tablename + "." + to.Meta.pkname fk_string = to.Meta.tablename + "." + to.Meta.pkname
to_field = to.__fields__[to.Meta.pkname] to_field = to.__fields__[to.Meta.pkname]
namespace = dict( namespace = dict(
@ -43,7 +46,7 @@ def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = T
index=False, index=False,
pydantic_only=False, pydantic_only=False,
default=None, default=None,
server_default=None server_default=None,
) )
return type("ForeignKey", (ForeignKeyField, BaseField), namespace) return type("ForeignKey", (ForeignKeyField, BaseField), namespace)
@ -59,21 +62,21 @@ class ForeignKeyField(BaseField):
yield cls.validate yield cls.validate
@classmethod @classmethod
def validate(cls, v: Any) -> Any: def validate(cls, value: Any) -> Any:
return v return value
@property # @property
def __type__(self) -> Type[BaseModel]: # def __type__(self) -> Type[BaseModel]:
return self.to.__pydantic_model__ # return self.to.__pydantic_model__
@classmethod # @classmethod
def get_column_type(cls) -> sqlalchemy.Column: # def get_column_type(cls) -> sqlalchemy.Column:
to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname] # to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname]
return to_column.column_type # return to_column.column_type
@classmethod @classmethod
def _extract_model_from_sequence( def _extract_model_from_sequence(
cls, value: List, child: "Model" cls, value: List, child: "Model"
) -> Union["Model", List["Model"]]: ) -> Union["Model", List["Model"]]:
return [cls.expand_relationship(val, child) for val in value] return [cls.expand_relationship(val, child) for val in value]
@ -109,7 +112,7 @@ class ForeignKeyField(BaseField):
@classmethod @classmethod
def expand_relationship( def expand_relationship(
cls, value: Any, child: "Model" cls, value: Any, child: "Model"
) -> Optional[Union["Model", List["Model"]]]: ) -> Optional[Union["Model", List["Model"]]]:
if value is None: if value is None:
return None return None

View File

@ -1,41 +1,43 @@
import datetime import datetime
import decimal import decimal
import re import re
from typing import Type, Any, Optional from typing import Any, Optional, Type
import pydantic import pydantic
import sqlalchemy import sqlalchemy
from pydantic import Json from pydantic import Json
from ormar import ModelDefinitionError from ormar import ModelDefinitionError # noqa I101
from ormar.fields.base import BaseField # noqa I101 from ormar.fields.base import BaseField # noqa I101
def is_field_nullable(nullable: Optional[bool], default: Any, server_default: Any) -> bool: def is_field_nullable(
nullable: Optional[bool], default: Any, server_default: Any
) -> bool:
if nullable is None: if nullable is None:
return default is not None or server_default is not None return default is not None or server_default is not None
return nullable return nullable
def String( def String(
*, *,
name: str = None, name: str = None,
primary_key: bool = False, primary_key: bool = False,
nullable: bool = None, nullable: bool = None,
index: bool = False, index: bool = False,
unique: bool = False, unique: bool = False,
allow_blank: bool = False, allow_blank: bool = False,
strip_whitespace: bool = False, strip_whitespace: bool = False,
min_length: int = None, min_length: int = None,
max_length: int = None, max_length: int = None,
curtail_length: int = None, curtail_length: int = None,
regex: str = None, regex: str = None,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[str]: ) -> Type[str]:
if max_length is None or max_length <= 0: if max_length is None or max_length <= 0:
raise ModelDefinitionError(f'Parameter max_length is required for field String') raise ModelDefinitionError("Parameter max_length is required for field String")
namespace = dict( namespace = dict(
__type__=str, __type__=str,
@ -54,26 +56,26 @@ def String(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("String", (pydantic.ConstrainedStr, BaseField), namespace) return type("String", (pydantic.ConstrainedStr, BaseField), namespace)
def Integer( def Integer(
*, *,
name: str = None, name: str = None,
primary_key: bool = False, primary_key: bool = False,
autoincrement: bool = None, autoincrement: bool = None,
nullable: bool = None, nullable: bool = None,
index: bool = False, index: bool = False,
unique: bool = False, unique: bool = False,
minimum: int = None, minimum: int = None,
maximum: int = None, maximum: int = None,
multiple_of: int = None, multiple_of: int = None,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[int]: ) -> Type[int]:
namespace = dict( namespace = dict(
__type__=int, __type__=int,
@ -89,23 +91,23 @@ def Integer(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=autoincrement if autoincrement is not None else primary_key autoincrement=autoincrement if autoincrement is not None else primary_key,
) )
return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace) return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace)
def Text( def Text(
*, *,
name: str = None, name: str = None,
primary_key: bool = False, primary_key: bool = False,
nullable: bool = None, nullable: bool = None,
index: bool = False, index: bool = False,
unique: bool = False, unique: bool = False,
allow_blank: bool = False, allow_blank: bool = False,
strip_whitespace: bool = False, strip_whitespace: bool = False,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[str]: ) -> Type[str]:
namespace = dict( namespace = dict(
__type__=str, __type__=str,
@ -120,25 +122,25 @@ def Text(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("Text", (pydantic.ConstrainedStr, BaseField), namespace) return type("Text", (pydantic.ConstrainedStr, BaseField), namespace)
def Float( def Float(
*, *,
name: str = None, name: str = None,
primary_key: bool = False, primary_key: bool = False,
nullable: bool = None, nullable: bool = None,
index: bool = False, index: bool = False,
unique: bool = False, unique: bool = False,
minimum: float = None, minimum: float = None,
maximum: float = None, maximum: float = None,
multiple_of: int = None, multiple_of: int = None,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[int]: ) -> Type[int]:
namespace = dict( namespace = dict(
__type__=float, __type__=float,
@ -154,21 +156,21 @@ def Float(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace) return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace)
def Boolean( def Boolean(
*, *,
name: str = None, name: str = None,
primary_key: bool = False, primary_key: bool = False,
nullable: bool = None, nullable: bool = None,
index: bool = False, index: bool = False,
unique: bool = False, unique: bool = False,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[bool]: ) -> Type[bool]:
namespace = dict( namespace = dict(
__type__=bool, __type__=bool,
@ -181,21 +183,21 @@ def Boolean(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("Boolean", (int, BaseField), namespace) return type("Boolean", (int, BaseField), namespace)
def DateTime( def DateTime(
*, *,
name: str = None, name: str = None,
primary_key: bool = False, primary_key: bool = False,
nullable: bool = None, nullable: bool = None,
index: bool = False, index: bool = False,
unique: bool = False, unique: bool = False,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[datetime.datetime]: ) -> Type[datetime.datetime]:
namespace = dict( namespace = dict(
__type__=datetime.datetime, __type__=datetime.datetime,
@ -208,21 +210,21 @@ def DateTime(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("DateTime", (datetime.datetime, BaseField), namespace) return type("DateTime", (datetime.datetime, BaseField), namespace)
def Date( def Date(
*, *,
name: str = None, name: str = None,
primary_key: bool = False, primary_key: bool = False,
nullable: bool = None, nullable: bool = None,
index: bool = False, index: bool = False,
unique: bool = False, unique: bool = False,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[datetime.date]: ) -> Type[datetime.date]:
namespace = dict( namespace = dict(
__type__=datetime.date, __type__=datetime.date,
@ -235,21 +237,21 @@ def Date(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("Date", (datetime.date, BaseField), namespace) return type("Date", (datetime.date, BaseField), namespace)
def Time( def Time(
*, *,
name: str = None, name: str = None,
primary_key: bool = False, primary_key: bool = False,
nullable: bool = None, nullable: bool = None,
index: bool = False, index: bool = False,
unique: bool = False, unique: bool = False,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[datetime.time]: ) -> Type[datetime.time]:
namespace = dict( namespace = dict(
__type__=datetime.time, __type__=datetime.time,
@ -262,21 +264,21 @@ def Time(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("Time", (datetime.time, BaseField), namespace) return type("Time", (datetime.time, BaseField), namespace)
def JSON( def JSON(
*, *,
name: str = None, name: str = None,
primary_key: bool = False, primary_key: bool = False,
nullable: bool = None, nullable: bool = None,
index: bool = False, index: bool = False,
unique: bool = False, unique: bool = False,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[Json]: ) -> Type[Json]:
namespace = dict( namespace = dict(
__type__=pydantic.Json, __type__=pydantic.Json,
@ -289,26 +291,26 @@ def JSON(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("JSON", (pydantic.Json, BaseField), namespace) return type("JSON", (pydantic.Json, BaseField), namespace)
def BigInteger( def BigInteger(
*, *,
name: str = None, name: str = None,
primary_key: bool = False, primary_key: bool = False,
autoincrement: bool = None, autoincrement: bool = None,
nullable: bool = None, nullable: bool = None,
index: bool = False, index: bool = False,
unique: bool = False, unique: bool = False,
minimum: int = None, minimum: int = None,
maximum: int = None, maximum: int = None,
multiple_of: int = None, multiple_of: int = None,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
) -> Type[int]: ) -> Type[int]:
namespace = dict( namespace = dict(
__type__=int, __type__=int,
@ -324,31 +326,33 @@ def BigInteger(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=autoincrement if autoincrement is not None else primary_key autoincrement=autoincrement if autoincrement is not None else primary_key,
) )
return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace) return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace)
def Decimal( def Decimal(
*, *,
name: str = None, name: str = None,
primary_key: bool = False, primary_key: bool = False,
nullable: bool = None, nullable: bool = None,
index: bool = False, index: bool = False,
unique: bool = False, unique: bool = False,
minimum: float = None, minimum: float = None,
maximum: float = None, maximum: float = None,
multiple_of: int = None, multiple_of: int = None,
precision: int = None, precision: int = None,
scale: int = None, scale: int = None,
max_digits: int = None, max_digits: int = None,
decimal_places: int = None, decimal_places: int = None,
pydantic_only: bool = False, pydantic_only: bool = False,
default: Any = None, default: Any = None,
server_default: Any = None server_default: Any = None,
): ) -> Type[decimal.Decimal]:
if precision is None or precision < 0 or scale is None or scale < 0: if precision is None or precision < 0 or scale is None or scale < 0:
raise ModelDefinitionError(f'Parameters scale and precision are required for field Decimal') raise ModelDefinitionError(
"Parameters scale and precision are required for field Decimal"
)
namespace = dict( namespace = dict(
__type__=decimal.Decimal, __type__=decimal.Decimal,
@ -368,6 +372,6 @@ def Decimal(
pydantic_only=pydantic_only, pydantic_only=pydantic_only,
default=default, default=default,
server_default=server_default, server_default=server_default,
autoincrement=False autoincrement=False,
) )
return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace) return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace)

View File

@ -2,16 +2,17 @@ import inspect
import json import json
import uuid import uuid
from typing import ( from typing import (
AbstractSet,
Any, Any,
Callable,
Dict, Dict,
List, List,
Mapping,
Optional, Optional,
Set, Set,
TYPE_CHECKING, TYPE_CHECKING,
Type, Type,
TypeVar, TypeVar,
Union, AbstractSet, Mapping, Union,
) )
import databases import databases
@ -20,10 +21,9 @@ import sqlalchemy
from pydantic import BaseModel from pydantic import BaseModel
import ormar # noqa I100 import ormar # noqa I100
from ormar import ForeignKey
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.models.metaclass import ModelMetaclass, ModelMeta from ormar.models.metaclass import ModelMeta, ModelMetaclass
from ormar.relations import RelationshipManager from ormar.relations import RelationshipManager
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -39,7 +39,8 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
# FakePydantic inherits from list in order to be treated as # FakePydantic inherits from list in order to be treated as
# request.Body parameter in fastapi routes, # request.Body parameter in fastapi routes,
# inheriting from pydantic.BaseModel causes metaclass conflicts # inheriting from pydantic.BaseModel causes metaclass conflicts
__slots__ = ('_orm_id', '_orm_saved') __slots__ = ("_orm_id", "_orm_saved")
__abstract__ = True
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
__model_fields__: Dict[str, TypeVar[BaseField]] __model_fields__: Dict[str, TypeVar[BaseField]]
@ -63,18 +64,18 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
if "pk" in kwargs: if "pk" in kwargs:
kwargs[self.Meta.pkname] = kwargs.pop("pk") kwargs[self.Meta.pkname] = kwargs.pop("pk")
kwargs = { kwargs = {
k: self.Meta.model_fields[k].expand_relationship(v, self) k: self._convert_json(
k, self.Meta.model_fields[k].expand_relationship(v, self), "dumps"
)
for k, v in kwargs.items() for k, v in kwargs.items()
} }
values, fields_set, validation_error = pydantic.validate_model( values, fields_set, validation_error = pydantic.validate_model(self, kwargs)
self, kwargs
)
if validation_error and not pk_only: if validation_error and not pk_only:
raise validation_error raise validation_error
object.__setattr__(self, '__dict__', values) object.__setattr__(self, "__dict__", values)
object.__setattr__(self, '__fields_set__', fields_set) object.__setattr__(self, "__fields_set__", fields_set)
# super().__init__(**kwargs) # super().__init__(**kwargs)
# self.values = self.__pydantic_model__(**kwargs) # self.values = self.__pydantic_model__(**kwargs)
@ -82,58 +83,50 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def __del__(self) -> None: def __del__(self) -> None:
self.Meta._orm_relationship_manager.deregister(self) self.Meta._orm_relationship_manager.deregister(self)
def __setattr__(self, name, value): def __setattr__(self, name: str, value: Any) -> None:
relation_key = self.get_name(title=True) + "_" + name relation_key = self.get_name(title=True) + "_" + name
if name in self.__slots__: if name in self.__slots__:
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
elif name == 'pk': elif name == "pk":
object.__setattr__(self, self.Meta.pkname, value) object.__setattr__(self, self.Meta.pkname, value)
elif self.Meta._orm_relationship_manager.contains(relation_key, self): elif self.Meta._orm_relationship_manager.contains(relation_key, self):
self.Meta.model_fields[name].expand_relationship(value, self) self.Meta.model_fields[name].expand_relationship(value, self)
else: else:
super().__setattr__(name, value) value = (
self._convert_json(name, value, "dumps")
if name in self.__fields__
else value
)
super().__setattr__(name, value)
def __getattr__(self, item): def __getattribute__(self, item: str) -> Any:
if item != "__fields__" and item in self.__fields__:
related = self._extract_related_model_instead_of_field(item)
if related:
return related
value = object.__getattribute__(self, item)
value = self._convert_json(item, value, "loads")
return value
return super().__getattribute__(item)
def __getattr__(self, item: str) -> Optional[Union["Model", List["Model"]]]:
return self._extract_related_model_instead_of_field(item)
def _extract_related_model_instead_of_field(
self, item: str
) -> Optional[Union["Model", List["Model"]]]:
relation_key = self.get_name(title=True) + "_" + item relation_key = self.get_name(title=True) + "_" + item
if self.Meta._orm_relationship_manager.contains(relation_key, self): if self.Meta._orm_relationship_manager.contains(relation_key, self):
return self.Meta._orm_relationship_manager.get(relation_key, self) return self.Meta._orm_relationship_manager.get(relation_key, self)
# def __setattr__(self, key: str, value: Any) -> None:
# if key in ('_orm_id', '_orm_relationship_manager', '_orm_saved', 'objects', '__model_fields__'):
# return setattr(self, key, value)
# # elif key in self._extract_related_names():
# # value = self._convert_json(key, value, op="dumps")
# # value = self.Meta.model_fields[key].expand_relationship(value, self)
# # relation_key = self.get_name(title=True) + "_" + key
# # if not self.Meta._orm_relationship_manager.contains(relation_key, self):
# # setattr(self.values, key, value)
# else:
# super().__setattr__(key, value)
# def __getattribute__(self, key: str) -> Any:
# if key != 'Meta' and key in self.Meta.model_fields:
# relation_key = self.get_name(title=True) + "_" + key
# if self.Meta._orm_relationship_manager.contains(relation_key, self):
# return self.Meta._orm_relationship_manager.get(relation_key, self)
# item = getattr(self.__fields__, key, None)
# item = self._convert_json(key, item, op="loads")
# return item
# return super().__getattribute__(key)
def __same__(self, other: "Model") -> bool: def __same__(self, other: "Model") -> bool:
if self.__class__ != other.__class__: # pragma no cover if self.__class__ != other.__class__: # pragma no cover
return False return False
return (self._orm_id == other._orm_id or return (
self.__dict__ == other.__dict__ or self._orm_id == other._orm_id
(self.pk == other.pk and self.pk is not None or self.__dict__ == other.__dict__
)) or (self.pk == other.pk and self.pk is not None)
)
# def __repr__(self) -> str: # pragma no cover
# return self.values.__repr__()
# @classmethod
# def __get_validators__(cls) -> Callable: # pragma no cover
# yield cls.__pydantic_model__.validate
@classmethod @classmethod
def get_name(cls, title: bool = False, lower: bool = True) -> str: def get_name(cls, title: bool = False, lower: bool = True) -> str:
@ -148,10 +141,6 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def pk(self) -> Any: def pk(self) -> Any:
return getattr(self, self.Meta.pkname) return getattr(self, self.Meta.pkname)
@pk.setter
def pk(self, value: Any) -> None:
setattr(self, self.Meta.pkname, value)
@property @property
def pk_column(self) -> sqlalchemy.Column: def pk_column(self) -> sqlalchemy.Column:
return self.Meta.table.primary_key.columns.values()[0] return self.Meta.table.primary_key.columns.values()[0]
@ -160,36 +149,38 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def pk_type(cls) -> Any: def pk_type(cls) -> Any:
return cls.Meta.model_fields[cls.Meta.pkname].__type__ return cls.Meta.model_fields[cls.Meta.pkname].__type__
def dict( def dict( # noqa A003
self, self,
*, *,
include: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
by_alias: bool = False, by_alias: bool = False,
skip_defaults: bool = None, skip_defaults: bool = None,
exclude_unset: bool = False, exclude_unset: bool = False,
exclude_defaults: bool = False, exclude_defaults: bool = False,
exclude_none: bool = False, exclude_none: bool = False,
nested: bool = False nested: bool = False
) -> 'DictStrAny': # noqa: A003' ) -> "DictStrAny": # noqa: A003'
dict_instance = super().dict(include=include, dict_instance = super().dict(
exclude=self._exclude_related_names_not_required(nested), include=include,
by_alias=by_alias, exclude=self._exclude_related_names_not_required(nested),
skip_defaults=skip_defaults, by_alias=by_alias,
exclude_unset=exclude_unset, skip_defaults=skip_defaults,
exclude_defaults=exclude_defaults, exclude_unset=exclude_unset,
exclude_none=exclude_none) exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
for field in self._extract_related_names(): for field in self._extract_related_names():
nested_model = getattr(self, field) nested_model = getattr(self, field)
if self.Meta.model_fields[field].virtual and nested: if self.Meta.model_fields[field].virtual and nested:
continue continue
if isinstance(nested_model, list) and not isinstance( if isinstance(nested_model, list) and not isinstance(
nested_model, ormar.Model nested_model, ormar.Model
): ):
dict_instance[field] = [x.dict(nested=True) for x in nested_model] dict_instance[field] = [x.dict(nested=True) for x in nested_model]
elif nested_model is not None: elif nested_model is not None:
dict_instance[field] = nested_model.dict(nested=True) dict_instance[field] = nested_model.dict(nested=True)
return dict_instance return dict_instance
def from_dict(self, value_dict: Dict) -> None: def from_dict(self, value_dict: Dict) -> None:
@ -225,19 +216,21 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def _extract_related_names(cls) -> Set: def _extract_related_names(cls) -> Set:
related_names = set() related_names = set()
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field) and issubclass( if inspect.isclass(field) and issubclass(field, ForeignKeyField):
field, ForeignKeyField
):
related_names.add(name) related_names.add(name)
return related_names return related_names
@classmethod @classmethod
def _exclude_related_names_not_required(cls, nested:bool=False) -> Set: def _exclude_related_names_not_required(cls, nested: bool = False) -> Set:
if nested: if nested:
return cls._extract_related_names() return cls._extract_related_names()
related_names = set() related_names = set()
for name, field in cls.Meta.model_fields.items(): for name, field in cls.Meta.model_fields.items():
if inspect.isclass(field) and issubclass(field, ForeignKeyField) and field.nullable: if (
inspect.isclass(field)
and issubclass(field, ForeignKeyField)
and field.nullable
):
related_names.add(name) related_names.add(name)
return related_names return related_names
@ -267,7 +260,7 @@ class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass):
def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": def merge_two_instances(cls, one: "Model", other: "Model") -> "Model":
for field in one.Meta.model_fields.keys(): for field in one.Meta.model_fields.keys():
if isinstance(getattr(one, field), list) and not isinstance( if isinstance(getattr(one, field), list) and not isinstance(
getattr(one, field), ormar.Model getattr(one, field), ormar.Model
): ):
setattr(other, field, getattr(one, field) + getattr(other, field)) setattr(other, field, getattr(one, field) + getattr(other, field))
elif isinstance(getattr(one, field), ormar.Model): elif isinstance(getattr(one, field), ormar.Model):

View File

@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
import databases import databases
import pydantic import pydantic
import sqlalchemy import sqlalchemy
from pydantic import BaseConfig, create_model, Extra from pydantic import BaseConfig
from pydantic.fields import ModelField, FieldInfo from pydantic.fields import FieldInfo
from ormar import ForeignKey, ModelDefinitionError # noqa I100 from ormar import ForeignKey, ModelDefinitionError # noqa I100
from ormar.fields import BaseField from ormar.fields import BaseField
@ -29,23 +29,11 @@ class ModelMeta:
_orm_relationship_manager: RelationshipManager _orm_relationship_manager: RelationshipManager
def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]:
pydantic_fields = {
field_name: (
base_field.__type__,
... if base_field.is_required else base_field.default_value,
)
for field_name, base_field in object_dict.items()
if isinstance(base_field, BaseField)
}
return pydantic_fields
def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None:
child_relation_name = ( child_relation_name = (
field.to.get_name(title=True) field.to.get_name(title=True)
+ "_" + "_"
+ (field.related_name or (name.lower() + "s")) + (field.related_name or (name.lower() + "s"))
) )
reverse_name = child_relation_name reverse_name = child_relation_name
relation_name = name.lower().title() + "_" + field.to.get_name() relation_name = name.lower().title() + "_" + field.to.get_name()
@ -61,34 +49,28 @@ def expand_reverse_relationships(model: Type["Model"]) -> None:
parent_model = model_field.to parent_model = model_field.to
child = model child = model
if ( if (
child_model_name not in parent_model.__fields__ child_model_name not in parent_model.__fields__
and child.get_name() not in parent_model.__fields__ and child.get_name() not in parent_model.__fields__
): ):
register_reverse_model_fields(parent_model, child, child_model_name) register_reverse_model_fields(parent_model, child, child_model_name)
def register_reverse_model_fields( def register_reverse_model_fields(
model: Type["Model"], child: Type["Model"], child_model_name: str model: Type["Model"], child: Type["Model"], child_model_name: str
) -> None: ) -> None:
# model.__fields__[child_model_name] = ModelField(
# name=child_model_name,
# type_=Optional[Union[List[child], child]],
# model_config=child.__config__,
# class_validators=child.__validators__,
# )
model.Meta.model_fields[child_model_name] = ForeignKey( model.Meta.model_fields[child_model_name] = ForeignKey(
child, name=child_model_name, virtual=True child, name=child_model_name, virtual=True
) )
def sqlalchemy_columns_from_model_fields( def sqlalchemy_columns_from_model_fields(
name: str, object_dict: Dict, table_name: str name: str, object_dict: Dict, table_name: str
) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]:
columns = [] columns = []
pkname = None pkname = None
model_fields = { model_fields = {
field_name: field field_name: field
for field_name, field in object_dict['__annotations__'].items() for field_name, field in object_dict["__annotations__"].items()
if issubclass(field, BaseField) if issubclass(field, BaseField)
} }
for field_name, field in model_fields.items(): for field_name, field in model_fields.items():
@ -96,7 +78,7 @@ def sqlalchemy_columns_from_model_fields(
if pkname is not None: if pkname is not None:
raise ModelDefinitionError("Only one primary key column is allowed.") raise ModelDefinitionError("Only one primary key column is allowed.")
if field.pydantic_only: if field.pydantic_only:
raise ModelDefinitionError('Primary key column cannot be pydantic only') raise ModelDefinitionError("Primary key column cannot be pydantic only")
pkname = field_name pkname = field_name
if not field.pydantic_only: if not field.pydantic_only:
columns.append(field.get_column(field_name)) columns.append(field.get_column(field_name))
@ -112,10 +94,10 @@ def populate_pydantic_default_values(attrs: Dict) -> Dict:
if type_.name is None: if type_.name is None:
type_.name = field type_.name = field
def_value = type_.default_value() def_value = type_.default_value()
curr_def_value = attrs.get(field, 'NONE') curr_def_value = attrs.get(field, "NONE")
if curr_def_value == 'NONE' and isinstance(def_value, FieldInfo): if curr_def_value == "NONE" and isinstance(def_value, FieldInfo):
attrs[field] = def_value attrs[field] = def_value
elif curr_def_value == 'NONE' and type_.nullable: elif curr_def_value == "NONE" and type_.nullable:
attrs[field] = FieldInfo(default=None) attrs[field] = FieldInfo(default=None)
return attrs return attrs
@ -132,18 +114,15 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]:
class ModelMetaclass(pydantic.main.ModelMetaclass): class ModelMetaclass(pydantic.main.ModelMetaclass):
def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type:
attrs['Config'] = get_pydantic_base_orm_config() attrs["Config"] = get_pydantic_base_orm_config()
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs
) )
if hasattr(new_model, 'Meta'): if hasattr(new_model, "Meta"):
if attrs.get("__abstract__"):
return new_model
annotations = attrs.get("__annotations__") or new_model.__annotations__ annotations = attrs.get("__annotations__") or new_model.__annotations__
attrs["__annotations__"]= annotations attrs["__annotations__"] = annotations
attrs = populate_pydantic_default_values(attrs) attrs = populate_pydantic_default_values(attrs)
tablename = name.lower() + "s" tablename = name.lower() + "s"
@ -152,18 +131,20 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
# sqlalchemy table creation # sqlalchemy table creation
pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(
name, attrs, new_model.Meta.tablename name, attrs, new_model.Meta.tablename
) )
if hasattr(new_model.Meta, "model_fields") and not pkname: if hasattr(new_model.Meta, "model_fields") and not pkname:
model_fields = new_model.Meta.model_fields model_fields = new_model.Meta.model_fields
for fieldname, field in new_model.Meta.model_fields.items(): for fieldname, field in new_model.Meta.model_fields.items():
if field.primary_key: if field.primary_key:
pkname=fieldname pkname = fieldname
columns = new_model.Meta.table.columns columns = new_model.Meta.table.columns
if not hasattr(new_model.Meta, "table"): if not hasattr(new_model.Meta, "table"):
new_model.Meta.table = sqlalchemy.Table(new_model.Meta.tablename, new_model.Meta.metadata, *columns) new_model.Meta.table = sqlalchemy.Table(
new_model.Meta.tablename, new_model.Meta.metadata, *columns
)
new_model.Meta.columns = columns new_model.Meta.columns = columns
new_model.Meta.pkname = pkname new_model.Meta.pkname = pkname
@ -171,12 +152,6 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
if not pkname: if not pkname:
raise ModelDefinitionError("Table has to have a primary key.") raise ModelDefinitionError("Table has to have a primary key.")
# pydantic model creation
new_model.Meta.pydantic_fields = parse_pydantic_field_from_model_fields(attrs)
new_model.Meta.pydantic_model = create_model(
name, __config__=get_pydantic_base_orm_config(), **new_model.Meta.pydantic_fields
)
new_model.Meta.model_fields = model_fields new_model.Meta.model_fields = model_fields
new_model = super().__new__( # type: ignore new_model = super().__new__( # type: ignore
mcs, name, bases, attrs mcs, name, bases, attrs

View File

@ -13,10 +13,10 @@ class Model(FakePydantic):
@classmethod @classmethod
def from_row( def from_row(
cls, cls,
row: sqlalchemy.engine.ResultProxy, row: sqlalchemy.engine.ResultProxy,
select_related: List = None, select_related: List = None,
previous_table: str = None, previous_table: str = None,
) -> "Model": ) -> "Model":
item = {} item = {}
@ -66,8 +66,8 @@ class Model(FakePydantic):
self_fields.pop(self.Meta.pkname) self_fields.pop(self.Meta.pkname)
expr = ( expr = (
self.Meta.table.update() self.Meta.table.update()
.values(**self_fields) .values(**self_fields)
.where(self.pk_column == getattr(self, self.Meta.pkname)) .where(self.pk_column == getattr(self, self.Meta.pkname))
) )
result = await self.Meta.database.execute(expr) result = await self.Meta.database.execute(expr)
return result return result

View File

@ -20,12 +20,12 @@ class JoinParameters(NamedTuple):
class Query: class Query:
def __init__( def __init__(
self, self,
model_cls: Type["Model"], model_cls: Type["Model"],
filter_clauses: List, filter_clauses: List,
select_related: List, select_related: List,
limit_count: int, limit_count: int,
offset: int, offset: int,
) -> None: ) -> None:
self.query_offset = offset self.query_offset = offset
@ -49,15 +49,15 @@ class Query:
self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")] self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")]
self.select_from = self.table self.select_from = self.table
for key in self.model_cls.Meta.model_fields: # for key in self.model_cls.Meta.model_fields:
if ( # if (
not self.model_cls.Meta.model_fields[key].nullable # not self.model_cls.Meta.model_fields[key].nullable
and isinstance( # and isinstance(
self.model_cls.Meta.model_fields[key], ForeignKeyField, # self.model_cls.Meta.model_fields[key], ForeignKeyField,
) # )
and key not in self._select_related # and key not in self._select_related
): # ):
self._select_related = [key] + self._select_related # self._select_related = [key] + self._select_related
start_params = JoinParameters( start_params = JoinParameters(
self.model_cls, "", self.table.name, self.model_cls self.model_cls, "", self.table.name, self.model_cls
@ -79,7 +79,7 @@ class Query:
expr = self._apply_expression_modifiers(expr) expr = self._apply_expression_modifiers(expr)
#print(expr.compile(compile_kwargs={"literal_binds": True})) # print(expr.compile(compile_kwargs={"literal_binds": True}))
self._reset_query_parameters() self._reset_query_parameters()
return expr, self._select_related return expr, self._select_related
@ -97,12 +97,12 @@ class Query:
@staticmethod @staticmethod
def _field_is_a_foreign_key_and_no_circular_reference( def _field_is_a_foreign_key_and_no_circular_reference(
field: BaseField, field_name: str, rel_part: str field: Type[BaseField], field_name: str, rel_part: str
) -> bool: ) -> bool:
return issubclass(field, ForeignKeyField) and field_name not in rel_part return issubclass(field, ForeignKeyField) and field_name not in rel_part
def _field_qualifies_to_deeper_search( def _field_qualifies_to_deeper_search(
self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str
) -> bool: ) -> bool:
prev_part_of_related = "__".join(rel_part.split("__")[:-1]) prev_part_of_related = "__".join(rel_part.split("__")[:-1])
partial_match = any( partial_match = any(
@ -112,19 +112,19 @@ class Query:
[x.startswith(rel_part) for x in (self.auto_related + self.already_checked)] [x.startswith(rel_part) for x in (self.auto_related + self.already_checked)]
) )
return ( return (
(field.virtual and parent_virtual) (field.virtual and parent_virtual)
or (partial_match and not already_checked) or (partial_match and not already_checked)
) or not nested ) or not nested
def on_clause( def on_clause(
self, previous_alias: str, alias: str, from_clause: str, to_clause: str, self, previous_alias: str, alias: str, from_clause: str, to_clause: str,
) -> text: ) -> text:
left_part = f"{alias}_{to_clause}" left_part = f"{alias}_{to_clause}"
right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}"
return text(f"{left_part}={right_part}") return text(f"{left_part}={right_part}")
def _build_join_parameters( def _build_join_parameters(
self, part: str, join_params: JoinParameters self, part: str, join_params: JoinParameters
) -> JoinParameters: ) -> JoinParameters:
model_cls = join_params.model_cls.Meta.model_fields[part].to model_cls = join_params.model_cls.Meta.model_fields[part].to
to_table = model_cls.Meta.table.name to_table = model_cls.Meta.table.name
@ -138,7 +138,8 @@ class Query:
( (
v v
for k, v in model_cls.Meta.model_fields.items() for k, v in model_cls.Meta.model_fields.items()
if issubclass(v, ForeignKeyField) and v.to == join_params.prev_model if issubclass(v, ForeignKeyField)
and v.to == join_params.prev_model
), ),
None, None,
).name ).name
@ -167,25 +168,28 @@ class Query:
return JoinParameters(prev_model, previous_alias, from_table, model_cls) return JoinParameters(prev_model, previous_alias, from_table, model_cls)
def _extract_auto_required_relations( def _extract_auto_required_relations(
self, self,
prev_model: Type["Model"], prev_model: Type["Model"],
rel_part: str = "", rel_part: str = "",
nested: bool = False, nested: bool = False,
parent_virtual: bool = False, parent_virtual: bool = False,
) -> None: ) -> None:
for field_name, field in prev_model.Meta.model_fields.items(): for field_name, field in prev_model.Meta.model_fields.items():
if self._field_is_a_foreign_key_and_no_circular_reference( if self._field_is_a_foreign_key_and_no_circular_reference(
field, field_name, rel_part field, field_name, rel_part
): ):
rel_part = field_name if not rel_part else rel_part + "__" + field_name rel_part = field_name if not rel_part else rel_part + "__" + field_name
if not field.nullable: if not field.nullable:
if rel_part not in self._select_related: if rel_part not in self._select_related:
new_related = "__".join(rel_part.split("__")[:-1]) if len( new_related = (
rel_part.split("__")) > 1 else rel_part "__".join(rel_part.split("__")[:-1])
if len(rel_part.split("__")) > 1
else rel_part
)
self.auto_related.append(new_related) self.auto_related.append(new_related)
rel_part = "" rel_part = ""
elif self._field_qualifies_to_deeper_search( elif self._field_qualifies_to_deeper_search(
field, parent_virtual, nested, rel_part field, parent_virtual, nested, rel_part
): ):
self._extract_auto_required_relations( self._extract_auto_required_relations(
@ -207,7 +211,7 @@ class Query:
self._select_related = new_joins + self.auto_related self._select_related = new_joins + self.auto_related
def _apply_expression_modifiers( def _apply_expression_modifiers(
self, expr: sqlalchemy.sql.select self, expr: sqlalchemy.sql.select
) -> sqlalchemy.sql.select: ) -> sqlalchemy.sql.select:
if self.filter_clauses: if self.filter_clauses:
if len(self.filter_clauses) == 1: if len(self.filter_clauses) == 1:

View File

@ -14,12 +14,12 @@ if TYPE_CHECKING: # pragma no cover
class QuerySet: class QuerySet:
def __init__( def __init__(
self, self,
model_cls: Type["Model"] = None, model_cls: Type["Model"] = None,
filter_clauses: List = None, filter_clauses: List = None,
select_related: List = None, select_related: List = None,
limit_count: int = None, limit_count: int = None,
offset: int = None, offset: int = None,
) -> None: ) -> None:
self.model_cls = model_cls self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses self.filter_clauses = [] if filter_clauses is None else filter_clauses
@ -151,9 +151,9 @@ class QuerySet:
pkname = self.model_cls.Meta.pkname pkname = self.model_cls.Meta.pkname
pk = self.model_cls.Meta.model_fields[pkname] pk = self.model_cls.Meta.model_fields[pkname]
if ( if (
pkname in new_kwargs pkname in new_kwargs
and new_kwargs.get(pkname) is None and new_kwargs.get(pkname) is None
and (pk.nullable or pk.autoincrement) and (pk.nullable or pk.autoincrement)
): ):
del new_kwargs[pkname] del new_kwargs[pkname]

View File

@ -5,7 +5,6 @@ from random import choices
from typing import List, TYPE_CHECKING, Union from typing import List, TYPE_CHECKING, Union
from weakref import proxy from weakref import proxy
from ormar import ForeignKey
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
@ -22,7 +21,11 @@ class RelationshipManager:
self._aliases = dict() self._aliases = dict()
def add_relation_type( def add_relation_type(
self, relations_key: str, reverse_key: str, field: ForeignKeyField, table_name: str self,
relations_key: str,
reverse_key: str,
field: ForeignKeyField,
table_name: str,
) -> None: ) -> None:
if relations_key not in self._relations: if relations_key not in self._relations:
self._relations[relations_key] = {"type": "primary"} self._relations[relations_key] = {"type": "primary"}

0
scripts/clean.sh Normal file → Executable file
View File

0
scripts/publish.sh Normal file → Executable file
View File

0
scripts/test.sh Normal file → Executable file
View File

View File

@ -5,6 +5,7 @@ from pydantic import ValidationError
import ormar import ormar
from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
from ormar.fields.foreign_key import ForeignKeyField
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
@ -120,6 +121,7 @@ async def test_model_crud():
track = await Track.objects.get(title="The Bird") track = await Track.objects.get(title="The Bird")
assert track.album.pk == album.pk assert track.album.pk == album.pk
assert isinstance(track.album, ormar.Model)
assert track.album.name is None assert track.album.name is None
await track.album.load() await track.album.load()
assert track.album.name == "Malibu" assert track.album.name == "Malibu"

View File

@ -75,9 +75,18 @@ def test_model_attribute_access(example):
example.test = 12 example.test = 12
assert example.test == 12 assert example.test == 12
example._orm_saved = True
assert example._orm_saved
def test_model_attribute_json_access(example):
example.test_json = dict(aa=12)
assert example.test_json == dict(aa=12)
def test_non_existing_attr(example): def test_non_existing_attr(example):
with pytest.raises(ValueError): with pytest.raises(ValueError):
example.new_attr=12 example.new_attr = 12
def test_primary_key_access_and_setting(example): def test_primary_key_access_and_setting(example):

View File

@ -11,6 +11,16 @@ database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
class JsonSample(ormar.Model):
class Meta:
tablename = "jsons"
metadata = metadata
database = database
id: ormar.Integer(primary_key=True)
test_json: ormar.JSON(nullable=True)
class User(ormar.Model): class User(ormar.Model):
class Meta: class Meta:
tablename = "users" tablename = "users"
@ -56,6 +66,18 @@ def test_model_pk():
assert user.id == 1 assert user.id == 1
@pytest.mark.asyncio
async def test_json_column():
async with database:
await JsonSample.objects.create(test_json=dict(aa=12))
await JsonSample.objects.create(test_json='{"aa": 12}')
items = await JsonSample.objects.all()
assert len(items) == 2
assert items[0].test_json == dict(aa=12)
assert items[1].test_json == dict(aa=12)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_crud(): async def test_model_crud():
async with database: async with database: