Merge pull request #251 from collerek/for_ref_m2m

Bug fixes
This commit is contained in:
collerek
2021-06-22 13:32:49 +02:00
committed by GitHub
9 changed files with 191 additions and 34 deletions

View File

@ -1,3 +1,11 @@
# 0.10.12
## 🐛 Fixes
* Fix `QuerySet.create` method not using init (if custom provided) [#245](https://github.com/collerek/ormar/issues/245)
* Fix `ForwardRef` `ManyToMany` relation setting wrong pydantic type [#250](https://github.com/collerek/ormar/issues/250)
# 0.10.11 # 0.10.11
## ✨ Features ## ✨ Features

View File

@ -76,7 +76,7 @@ class UndefinedType: # pragma no cover
Undefined = UndefinedType() Undefined = UndefinedType()
__version__ = "0.10.11" __version__ = "0.10.12"
__all__ = [ __all__ = [
"Integer", "Integer",
"BigInteger", "BigInteger",

View File

@ -195,7 +195,7 @@ class BaseField(FieldInfo):
if self.ormar_default is not None if self.ormar_default is not None
else (self.server_default if use_server else None) else (self.server_default if use_server else None)
) )
if callable(default): if callable(default): # pragma: no cover
default = default() default = default()
return default return default

View File

@ -131,7 +131,11 @@ def ManyToMany(
validate_not_allowed_fields(kwargs) validate_not_allowed_fields(kwargs)
if to.__class__ == ForwardRef: if to.__class__ == ForwardRef:
__type__ = to if not nullable else Optional[to] __type__ = (
Union[to, List[to]] # type: ignore
if not nullable
else Optional[Union[to, List[to]]] # type: ignore
)
column_type = None column_type = None
else: else:
__type__, column_type = populate_m2m_params_based_on_to_model( __type__, column_type = populate_m2m_params_based_on_to_model(

View File

@ -113,7 +113,7 @@ class SavePrepareMixin(RelationMixin, AliasMixin):
if field_value is not None: if field_value is not None:
target_field = cls.Meta.model_fields[field] target_field = cls.Meta.model_fields[field]
target_pkname = target_field.to.Meta.pkname target_pkname = target_field.to.Meta.pkname
if isinstance(field_value, ormar.Model): if isinstance(field_value, ormar.Model): # pragma: no cover
pk_value = getattr(field_value, target_pkname) pk_value = getattr(field_value, target_pkname)
if not pk_value: if not pk_value:
raise ModelPersistenceError( raise ModelPersistenceError(

View File

@ -602,7 +602,7 @@ class QuerySet(Generic[T]):
] ]
if _as_dict: if _as_dict:
return result return result
if _flatten and not self._excludable.include_entry_count() == 1: if _flatten and self._excludable.include_entry_count() != 1:
raise QueryDefinitionError( raise QueryDefinitionError(
"You cannot flatten values_list if more than one field is selected!" "You cannot flatten values_list if more than one field is selected!"
) )
@ -1014,35 +1014,8 @@ class QuerySet(Generic[T]):
:return: created model :return: created model
:rtype: Model :rtype: Model
""" """
new_kwargs = dict(**kwargs)
new_kwargs = self.model.prepare_model_to_save(new_kwargs)
expr = self.table.insert()
expr = expr.values(**new_kwargs)
instance = self.model(**kwargs) instance = self.model(**kwargs)
await self.model.Meta.signals.pre_save.send( instance = await instance.save()
sender=self.model, instance=instance
)
pk = await self.database.execute(expr)
pk_name = self.model.get_column_alias(self.model_meta.pkname)
if pk_name not in kwargs and pk_name in new_kwargs:
instance.pk = new_kwargs[self.model_meta.pkname]
if pk and isinstance(pk, self.model.pk_type()):
instance.pk = pk
# refresh server side defaults
if any(
field.server_default is not None
for name, field in self.model.Meta.model_fields.items()
if name not in kwargs
):
instance = await instance.load()
instance.set_save_status(True)
await self.model.Meta.signals.post_save.send(
sender=self.model, instance=instance
)
return instance return instance
async def bulk_create(self, objects: List["T"]) -> None: async def bulk_create(self, objects: List["T"]) -> None:

View File

@ -0,0 +1,109 @@
from typing import List, Optional
import databases
import pytest
import sqlalchemy
from fastapi import FastAPI
from pydantic.schema import ForwardRef
from starlette import status
from starlette.testclient import TestClient
import ormar
app = FastAPI()
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
app.state.database = database
@app.on_event("startup")
async def startup() -> None:
database_ = app.state.database
if not database_.is_connected:
await database_.connect()
@app.on_event("shutdown")
async def shutdown() -> None:
database_ = app.state.database
if database_.is_connected:
await database_.disconnect()
class BaseMeta(ormar.ModelMeta):
database = database
metadata = metadata
CityRef = ForwardRef("City")
CountryRef = ForwardRef("Country")
# models.py
class Country(ormar.Model):
class Meta(BaseMeta):
tablename = "countries"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=128, unique=True)
iso2: str = ormar.String(max_length=3)
iso3: str = ormar.String(max_length=4, unique=True)
population: int = ormar.Integer(maximum=10000000000)
demonym: str = ormar.String(max_length=128)
native_name: str = ormar.String(max_length=128)
capital: Optional[CityRef] = ormar.ForeignKey( # type: ignore
CityRef, related_name="capital_city", nullable=True
)
borders: List[Optional[CountryRef]] = ormar.ManyToMany( # type: ignore
CountryRef, nullable=True, skip_reverse=True
)
class City(ormar.Model):
class Meta(BaseMeta):
tablename = "cities"
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=128)
country: Country = ormar.ForeignKey(
Country, related_name="cities", skip_reverse=True
)
Country.update_forward_refs()
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@app.post("/", response_model=Country, status_code=status.HTTP_201_CREATED)
async def create_country(country: Country): # if this is ormar
result = await country.upsert() # it's already initialized as ormar model
return result
def test_payload():
client = TestClient(app)
with client as client:
payload = {
"name": "Thailand",
"iso2": "TH",
"iso3": "THA",
"population": 23123123,
"demonym": "Thai",
"native_name": "Thailand",
}
resp = client.post("/", json=payload, headers={"application-type": "json"})
print(resp.content)
assert resp.status_code == 201
resp_country = Country(**resp.json())
assert resp_country.name == "Thailand"

View File

@ -0,0 +1,63 @@
import uuid
from typing import ClassVar
import databases
import pytest
import sqlalchemy
from pydantic import root_validator
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class BaseMeta(ormar.ModelMeta):
database = database
metadata = metadata
class Mol(ormar.Model):
# fixed namespace to generate always unique uuid from the smiles
_UUID_NAMESPACE: ClassVar[uuid.UUID] = uuid.UUID(
"12345678-abcd-1234-abcd-123456789abc"
)
class Meta(BaseMeta):
tablename = "mols"
id: str = ormar.UUID(primary_key=True, index=True, uuid_format="hex")
smiles: str = ormar.String(nullable=False, unique=True, max_length=256)
def __init__(self, **kwargs):
# this is required to generate id from smiles in init, if id is not given
if "id" not in kwargs:
kwargs["id"] = self._UUID_NAMESPACE
super().__init__(**kwargs)
@root_validator()
def make_canonical_smiles_and_uuid(cls, values):
values["id"], values["smiles"] = cls.uuid(values["smiles"])
return values
@classmethod
def uuid(cls, smiles):
id_ = uuid.uuid5(cls._UUID_NAMESPACE, smiles)
return id_, smiles
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_json_column():
async with database:
await Mol.objects.create(smiles="Cc1ccccc1")
count = await Mol.objects.count()
assert count == 1

View File

@ -199,7 +199,7 @@ async def test_binary_column():
async def test_binary_str_column(): async def test_binary_str_column():
async with database: async with database:
async with database.transaction(force_rollback=True): async with database.transaction(force_rollback=True):
await LargeBinaryStr.objects.create(test_binary=blob3) await LargeBinaryStr(test_binary=blob3).save()
await LargeBinaryStr.objects.create(test_binary=blob4) await LargeBinaryStr.objects.create(test_binary=blob4)
items = await LargeBinaryStr.objects.all() items = await LargeBinaryStr.objects.all()