WIP super dirty - change to descriptors and different tries

This commit is contained in:
collerek
2021-05-16 20:42:07 +02:00
parent 61a5199986
commit 4c79ce5a5e
13 changed files with 425 additions and 83 deletions

View File

@ -95,6 +95,9 @@ class BaseField(FieldInfo):
self.ormar_default: Any = kwargs.pop("default", None) self.ormar_default: Any = kwargs.pop("default", None)
self.server_default: Any = kwargs.pop("server_default", None) self.server_default: Any = kwargs.pop("server_default", None)
self.represent_as_base64_str: bool = kwargs.pop("represent_as_base64_str", False)
self.use_base64: bool = kwargs.pop("use_base64", False)
for name, value in kwargs.items(): for name, value in kwargs.items():
setattr(self, name, value) setattr(self, name, value)

View File

@ -0,0 +1,32 @@
from pydantic import BaseModel
class OrmarBytes(bytes):
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, v):
if not isinstance(v, str):
pass
return v
def __get__(self, obj, class_=None):
return 'test'
def __set__(self, obj, value):
obj.__dict__['test'] = value
class ModelA(BaseModel):
test: OrmarBytes = OrmarBytes()
ModelA.test = OrmarBytes()
aa = ModelA(test=b"aa")
print(aa.__dict__)
print(aa.test)
aa.test = 'aas'
print(aa.test)

View File

@ -435,7 +435,12 @@ class LargeBinary(ModelFieldFactory, bytes):
_sample = "bytes" _sample = "bytes"
def __new__( # type: ignore # noqa CFQ002 def __new__( # type: ignore # noqa CFQ002
cls, *, max_length: int, **kwargs: Any cls,
*,
max_length: int,
use_base64: bool = False,
represent_as_base64_str: bool = False,
**kwargs: Any
) -> BaseField: # type: ignore ) -> BaseField: # type: ignore
kwargs = { kwargs = {
**kwargs, **kwargs,

View File

@ -0,0 +1,4 @@
from ormar.models.descriptors.descriptors import PkDescriptor, PropertyDescriptor, \
PydanticDescriptor, \
RelationDescriptor
__all__ = ["PydanticDescriptor", "RelationDescriptor", "PropertyDescriptor", "PkDescriptor"]

View File

@ -0,0 +1,97 @@
import pydantic
from ormar.models.helpers.validation import validate_choices
class PydanticDescriptor:
def __init__(self, name):
self.name = name
def __get__(self, instance, owner):
value = object.__getattribute__(instance, "__dict__").get(self.name, None)
value = object.__getattribute__(instance, "_convert_json")(self.name, value,
"loads")
value = object.__getattribute__(instance, "_convert_bytes")(self.name, value,
"read")
return value
def __set__(self, instance, value):
if self.name in object.__getattribute__(instance, "_choices_fields"):
validate_choices(field=instance.Meta.model_fields[self.name], value=value)
value = object.__getattribute__(instance, '_convert_bytes')(self.name, value,
op="write")
value = object.__getattribute__(instance, '_convert_json')(self.name, value,
op="dumps")
super(instance.__class__, instance).__setattr__(self.name, value)
object.__getattribute__(instance, "set_save_status")(False)
class PkDescriptor:
def __init__(self, name):
self.name = name
def __get__(self, instance, owner):
value = object.__getattribute__(instance, "__dict__").get(self.name, None)
value = object.__getattribute__(instance, "_convert_json")(self.name, value,
"loads")
value = object.__getattribute__(instance, "_convert_bytes")(self.name, value,
"read")
return value
def __set__(self, instance, value):
if self.name in object.__getattribute__(instance, "_choices_fields"):
validate_choices(field=instance.Meta.model_fields[self.name], value=value)
value = object.__getattribute__(instance, '_convert_bytes')(self.name, value,
op="write")
value = object.__getattribute__(instance, '_convert_json')(self.name, value,
op="dumps")
super(instance.__class__, instance).__setattr__(self.name, value)
object.__getattribute__(instance, "set_save_status")(False)
class RelationDescriptor:
def __init__(self, name):
self.name = name
def __get__(self, instance, owner):
if self.name in object.__getattribute__(instance, '_orm'):
return object.__getattribute__(instance, '_orm').get(
self.name) # type: ignore
return None # pragma no cover
def __set__(self, instance, value):
model = (
object.__getattribute__(instance, "Meta")
.model_fields[self.name]
.expand_relationship(value=value, child=instance)
)
if isinstance(object.__getattribute__(instance, "__dict__").get(self.name),
list):
# virtual foreign key or many to many
# TODO: Fix double items in dict, no effect on real action ugly repr
# if model.pk not in [x.pk for x in related_list]:
object.__getattribute__(instance, "__dict__")[self.name].append(model)
else:
# foreign key relation
object.__getattribute__(instance, "__dict__")[self.name] = model
object.__getattribute__(instance, "set_save_status")(False)
class PropertyDescriptor:
def __init__(self, name, function):
self.name = name
self.function = function
def __get__(self, instance, owner):
if instance is None:
return self
if instance is not None and self.function is not None:
bound = self.function.__get__(instance, instance.__class__)
return bound() if callable(bound) else bound
def __set__(self, instance, value):
pass

View File

@ -67,6 +67,11 @@ def populate_default_options_values(
for name, field in new_model.Meta.model_fields.items() for name, field in new_model.Meta.model_fields.items()
if field.__type__ == pydantic.Json if field.__type__ == pydantic.Json
} }
new_model._bytes_fields = {
name
for name, field in new_model.Meta.model_fields.items()
if field.__type__ == bytes
}
class Connection(sqlite3.Connection): class Connection(sqlite3.Connection):

View File

@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Type, cast
import ormar import ormar
from ormar import ForeignKey, ManyToMany from ormar import ForeignKey, ManyToMany
from ormar.fields import Through from ormar.fields import Through
from ormar.models.descriptors import RelationDescriptor
from ormar.models.helpers.sqlalchemy import adjust_through_many_to_many_model from ormar.models.helpers.sqlalchemy import adjust_through_many_to_many_model
from ormar.relations import AliasManager from ormar.relations import AliasManager
@ -130,6 +131,8 @@ def register_reverse_model_fields(model_field: "ForeignKeyField") -> None:
orders_by=model_field.related_orders_by, orders_by=model_field.related_orders_by,
skip_field=model_field.skip_reverse, skip_field=model_field.skip_reverse,
) )
if not model_field.skip_reverse:
setattr(model_field.to, related_name, RelationDescriptor(name=related_name))
def register_through_shortcut_fields(model_field: "ManyToManyField") -> None: def register_through_shortcut_fields(model_field: "ManyToManyField") -> None:
@ -160,6 +163,8 @@ def register_through_shortcut_fields(model_field: "ManyToManyField") -> None:
owner=model_field.to, owner=model_field.to,
nullable=True, nullable=True,
) )
setattr(model_field.owner, through_name, RelationDescriptor(name=through_name))
setattr(model_field.to, through_name, RelationDescriptor(name=through_name))
def register_relation_in_alias_manager(field: "ForeignKeyField") -> None: def register_relation_in_alias_manager(field: "ForeignKeyField") -> None:

View File

@ -4,6 +4,7 @@ from typing import Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union
import sqlalchemy import sqlalchemy
import ormar # noqa: I100, I202 import ormar # noqa: I100, I202
from ormar.models.descriptors import RelationDescriptor
from ormar.models.helpers.pydantic import create_pydantic_field from ormar.models.helpers.pydantic import create_pydantic_field
from ormar.models.helpers.related_names_validation import ( from ormar.models.helpers.related_names_validation import (
validate_related_names_in_relations, validate_related_names_in_relations,
@ -33,6 +34,7 @@ def adjust_through_many_to_many_model(model_field: "ManyToManyField") -> None:
ondelete="CASCADE", ondelete="CASCADE",
owner=model_field.through, owner=model_field.through,
) )
model_fields[child_name] = ormar.ForeignKey( # type: ignore model_fields[child_name] = ormar.ForeignKey( # type: ignore
model_field.owner, model_field.owner,
real_name=child_name, real_name=child_name,
@ -50,6 +52,9 @@ def adjust_through_many_to_many_model(model_field: "ManyToManyField") -> None:
create_pydantic_field(parent_name, model_field.to, model_field) create_pydantic_field(parent_name, model_field.to, model_field)
create_pydantic_field(child_name, model_field.owner, model_field) create_pydantic_field(child_name, model_field.owner, model_field)
setattr(model_field.through, parent_name, RelationDescriptor(name=parent_name))
setattr(model_field.through, child_name, RelationDescriptor(name=child_name))
def create_and_append_m2m_fk( def create_and_append_m2m_fk(
model: Type["Model"], model_field: "ManyToManyField", field_name: str model: Type["Model"], model_field: "ManyToManyField", field_name: str

View File

@ -1,3 +1,4 @@
import base64
import datetime import datetime
import decimal import decimal
import numbers import numbers
@ -77,6 +78,9 @@ def convert_choices_if_needed( # noqa: CCR001
) )
choices = [round(float(o), precision) for o in choices] choices = [round(float(o), precision) for o in choices]
elif field.__type__ == bytes: elif field.__type__ == bytes:
if field.represent_as_base64_str:
value = value if isinstance(value, bytes) else base64.b64decode(value)
else:
value = value if isinstance(value, bytes) else value.encode("utf-8") value = value if isinstance(value, bytes) else value.encode("utf-8")
return value, choices return value, choices

View File

@ -22,6 +22,9 @@ from ormar.exceptions import ModelError
from ormar.fields import BaseField from ormar.fields import BaseField
from ormar.fields.foreign_key import ForeignKeyField from ormar.fields.foreign_key import ForeignKeyField
from ormar.fields.many_to_many import ManyToManyField from ormar.fields.many_to_many import ManyToManyField
from ormar.models.descriptors import PkDescriptor, PropertyDescriptor, \
PydanticDescriptor, \
RelationDescriptor
from ormar.models.helpers import ( from ormar.models.helpers import (
alias_manager, alias_manager,
check_required_meta_parameters, check_required_meta_parameters,
@ -95,6 +98,7 @@ def add_cached_properties(new_model: Type["Model"]) -> None:
new_model._pydantic_fields = {name for name in new_model.__fields__} new_model._pydantic_fields = {name for name in new_model.__fields__}
new_model._choices_fields = set() new_model._choices_fields = set()
new_model._json_fields = set() new_model._json_fields = set()
new_model._bytes_fields = set()
def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: CCR001 def add_property_fields(new_model: Type["Model"], attrs: Dict) -> None: # noqa: CCR001
@ -539,8 +543,12 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
populate_meta_sqlalchemy_table_if_required(new_model.Meta) populate_meta_sqlalchemy_table_if_required(new_model.Meta)
expand_reverse_relationships(new_model) expand_reverse_relationships(new_model)
# TODO: iterate only related fields # TODO: iterate only related fields
for field in new_model.Meta.model_fields.values(): for name, field in new_model.Meta.model_fields.items():
register_relation_in_alias_manager(field=field) register_relation_in_alias_manager(field=field)
if field.is_relation:
setattr(new_model, name, RelationDescriptor(name=name))
else:
setattr(new_model, name, PydanticDescriptor(name=name))
if new_model.Meta.pkname not in attrs["__annotations__"]: if new_model.Meta.pkname not in attrs["__annotations__"]:
field_name = new_model.Meta.pkname field_name = new_model.Meta.pkname
@ -551,6 +559,13 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
) )
new_model.Meta.alias_manager = alias_manager new_model.Meta.alias_manager = alias_manager
for item in new_model.Meta.property_fields:
function = getattr(new_model, item)
setattr(new_model, item, PropertyDescriptor(name=item,
function=function))
setattr(new_model, 'pk', PkDescriptor(name=new_model.Meta.pkname))
return new_model return new_model
@property @property
@ -564,6 +579,17 @@ class ModelMetaclass(pydantic.main.ModelMetaclass):
return QuerySet(model_cls=cls) return QuerySet(model_cls=cls)
def __getattr__(self, item: str) -> Any: def __getattr__(self, item: str) -> Any:
"""
Returns FieldAccessors on access to model fields from a class,
that way it can be used in python style filters and order_by.
:param item: name of the field
:type item: str
:return: FieldAccessor for given field
:rtype: FieldAccessor
"""
if item == "pk":
item = self.Meta.pkname
if item in object.__getattribute__(self, "Meta").model_fields: if item in object.__getattribute__(self, "Meta").model_fields:
field = self.Meta.model_fields.get(item) field = self.Meta.model_fields.get(item)
if field.is_relation: if field.is_relation:

View File

@ -1,3 +1,4 @@
import base64
import sys import sys
import warnings import warnings
from typing import ( from typing import (
@ -185,7 +186,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
elif name == "pk": elif name == "pk":
object.__setattr__(self, self.Meta.pkname, value) object.__setattr__(self, self.Meta.pkname, value)
self.set_save_status(False) object.__getattribute__(self, "set_save_status")(False)
elif name in object.__getattribute__(self, "_orm"): elif name in object.__getattribute__(self, "_orm"):
model = ( model = (
object.__getattribute__(self, "Meta") object.__getattribute__(self, "Meta")
@ -200,65 +201,68 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
else: else:
# foreign key relation # foreign key relation
object.__getattribute__(self, "__dict__")[name] = model object.__getattribute__(self, "__dict__")[name] = model
self.set_save_status(False) object.__getattribute__(self, "set_save_status")(False)
else: else:
if name in object.__getattribute__(self, "_choices_fields"): if name in object.__getattribute__(self, "_choices_fields"):
validate_choices(field=self.Meta.model_fields[name], value=value) validate_choices(field=self.Meta.model_fields[name], value=value)
super().__setattr__(name, self._convert_json(name, value, op="dumps")) value = object.__getattribute__(self, '_convert_bytes')(name, value, op="write")
self.set_save_status(False) value = object.__getattribute__(self, '_convert_json')(name, value, op="dumps")
super().__setattr__(name, value)
object.__getattribute__(self, "set_save_status")(False)
def __getattribute__(self, item: str) -> Any: # noqa: CCR001 # def __getattribute__(self, item: str) -> Any: # noqa: CCR001
""" # """
Because we need to overwrite getting the attribute by ormar instead of pydantic # Because we need to overwrite getting the attribute by ormar instead of pydantic
as well as returning related models and not the value stored on the model the # as well as returning related models and not the value stored on the model the
__getattribute__ needs to be used not __getattr__. # __getattribute__ needs to be used not __getattr__.
#
It's used to access all attributes so it can be a big overhead that's why a # It's used to access all attributes so it can be a big overhead that's why a
number of short circuits is used. # number of short circuits is used.
#
To short circuit all checks and expansions the set of attribute names present # To short circuit all checks and expansions the set of attribute names present
on each model is gathered into _quick_access_fields that is looked first and # on each model is gathered into _quick_access_fields that is looked first and
if field is in this set the object setattr is called directly. # if field is in this set the object setattr is called directly.
#
To avoid recursion object's getattribute is used to actually get the attribute # To avoid recursion object's getattribute is used to actually get the attribute
value from the model after the checks. # value from the model after the checks.
#
Even the function calls are constructed with objects functions. # Even the function calls are constructed with objects functions.
#
Parameter "pk" is translated into actual primary key field name. # Parameter "pk" is translated into actual primary key field name.
#
Relations are returned so the actual related model is returned and not current # Relations are returned so the actual related model is returned and not current
model's field. The related models are handled by RelationshipManager exposed # model's field. The related models are handled by RelationshipManager exposed
at _orm param. # at _orm param.
#
Json fields are converted if needed. # Json fields are converted if needed.
#
:param item: name of the attribute to retrieve # :param item: name of the attribute to retrieve
:type item: str # :type item: str
:return: value of the attribute # :return: value of the attribute
:rtype: Any # :rtype: Any
""" # """
if item in object.__getattribute__(self, "_quick_access_fields"): # if item in object.__getattribute__(self, "_quick_access_fields"):
return object.__getattribute__(self, item) # return object.__getattribute__(self, item)
if item == "pk": # # if item == "pk":
return object.__getattribute__(self, "__dict__").get(self.Meta.pkname, None) # # return object.__getattribute__(self, "__dict__").get(self.Meta.pkname, None)
if item in object.__getattribute__(self, "extract_related_names")(): # # if item in object.__getattribute__(self, "extract_related_names")():
return object.__getattribute__( # # return object.__getattribute__(
self, "_extract_related_model_instead_of_field" # # self, "_extract_related_model_instead_of_field"
)(item) # # )(item)
if item in object.__getattribute__(self, "extract_through_names")(): # # if item in object.__getattribute__(self, "extract_through_names")():
return object.__getattribute__( # # return object.__getattribute__(
self, "_extract_related_model_instead_of_field" # # self, "_extract_related_model_instead_of_field"
)(item) # # )(item)
if item in object.__getattribute__(self, "Meta").property_fields: # # if item in object.__getattribute__(self, "Meta").property_fields:
value = object.__getattribute__(self, item) # # value = object.__getattribute__(self, item)
return value() if callable(value) else value # # return value() if callable(value) else value
if item in object.__getattribute__(self, "_pydantic_fields"): # # if item in object.__getattribute__(self, "_pydantic_fields"):
value = object.__getattribute__(self, "__dict__").get(item, None) # # value = object.__getattribute__(self, "__dict__").get(item, None)
value = object.__getattribute__(self, "_convert_json")(item, value, "loads") # # value = object.__getattribute__(self, "_convert_json")(item, value, "loads")
return value # # value = object.__getattribute__(self, "_convert_bytes")(item, value, "read")
# # return value
return object.__getattribute__(self, item) # pragma: no cover #
# return object.__getattribute__(self, item) # pragma: no cover
def _verify_model_can_be_initialized(self) -> None: def _verify_model_can_be_initialized(self) -> None:
""" """
@ -297,6 +301,7 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
property_fields = meta.property_fields property_fields = meta.property_fields
model_fields = meta.model_fields model_fields = meta.model_fields
pydantic_fields = object.__getattribute__(self, "__fields__") pydantic_fields = object.__getattribute__(self, "__fields__")
bytes_fields = object.__getattribute__(self, '_bytes_fields')
# remove property fields # remove property fields
for prop_filed in property_fields: for prop_filed in property_fields:
@ -832,6 +837,39 @@ class NewBaseModel(pydantic.BaseModel, ModelTableProxy, metaclass=ModelMetaclass
setattr(self, key, value) setattr(self, key, value)
return self return self
def _convert_bytes(self, column_name: str, value: Any, op: str) -> Union[str, Dict]:
"""
Converts value to/from json if needed (for Json columns).
:param column_name: name of the field
:type column_name: str
:param value: value fo the field
:type value: Any
:param op: operator on json
:type op: str
:return: converted value if needed, else original value
:rtype: Any
"""
if column_name not in object.__getattribute__(self, "_bytes_fields"):
return value
field = self.Meta.model_fields[column_name]
condition = (
isinstance(value, bytes) if op == "read" else not isinstance(value, bytes)
)
if op == "read" and condition:
if field.use_base64:
value = base64.b64encode(value)
elif field.represent_as_base64_str:
value = base64.b64encode(value).decode()
else:
value = value.decode("utf-8")
elif condition:
if field.use_base64 or field.represent_as_base64_str:
value = base64.b64decode(value)
else:
value = value.encode("utf-8")
return value
def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, Dict]: def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, Dict]:
""" """
Converts value to/from json if needed (for Json columns). Converts value to/from json if needed (for Json columns).

View File

@ -0,0 +1,88 @@
import base64
import json
import os
import uuid
from typing import List
import databases
import pydantic
import pytest
import sqlalchemy
from fastapi import FastAPI
from starlette.testclient import TestClient
import ormar
from tests.settings import DATABASE_URL
app = FastAPI()
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
app.state.database = database
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
blob3 = b"\xc3\x28"
blob4 = b"\xf0\x28\x8c\x28"
blob5 = b"\xee"
blob6 = b"\xff"
class BaseMeta(ormar.ModelMeta):
metadata = metadata
database = database
class BinaryThing(ormar.Model):
class Meta(BaseMeta):
tablename = "things"
id: uuid.UUID = ormar.UUID(primary_key=True, default=uuid.uuid4)
name: str = ormar.Text(default="")
bt: bytes = ormar.LargeBinary(
max_length=1000,
choices=[blob3, blob4, blob5, blob6],
represent_as_base64_str=True
)
@app.get("/things", response_model=List[BinaryThing])
async def read_things():
return await BinaryThing.objects.order_by("name").all()
@app.post("/things", response_model=BinaryThing)
async def create_things(thing: BinaryThing):
thing = await thing.save()
return thing
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
def test_read_main():
client = TestClient(app)
with client as client:
response = client.post(
"/things", data=json.dumps({"bt": base64.b64encode(blob3).decode()})
)
print(response.content)
assert response.status_code == 200

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import datetime import datetime
import os
import uuid import uuid
from typing import List from typing import List
@ -37,7 +38,23 @@ class LargeBinarySample(ormar.Model):
database = database database = database
id: int = ormar.Integer(primary_key=True) id: int = ormar.Integer(primary_key=True)
test_binary = ormar.LargeBinary(max_length=100000, choices=[blob, blob2]) test_binary: bytes = ormar.LargeBinary(max_length=100000, choices=[blob, blob2])
blob3 = os.urandom(64)
blob4 = os.urandom(100)
class LargeBinaryStr(ormar.Model):
class Meta:
tablename = "my_str_blobs"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
test_binary: str = ormar.LargeBinary(
max_length=100000, choices=[blob3, blob4], represent_as_base64=True
)
class UUIDSample(ormar.Model): class UUIDSample(ormar.Model):
@ -171,6 +188,19 @@ async def test_binary_column():
assert items[1].test_binary == blob2 assert items[1].test_binary == blob2
@pytest.mark.asyncio
async def test_binary_str_column():
async with database:
async with database.transaction(force_rollback=True):
await LargeBinaryStr.objects.create(test_binary=blob3)
await LargeBinaryStr.objects.create(test_binary=blob4)
items = await LargeBinaryStr.objects.all()
assert len(items) == 2
assert items[0].test_binary == blob3
assert items[1].test_binary == blob4
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_uuid_column(): async def test_uuid_column():
async with database: async with database: