modify save_related to be able to save whole tree from dict - including reverse fk and m2m relations - with correct order of saving

This commit is contained in:
collerek
2021-04-12 17:39:42 +02:00
parent 6780c9de8a
commit 854b27947a
7 changed files with 474 additions and 58 deletions

View File

@ -119,7 +119,7 @@ async def test_saving_many_to_many():
assert count == 0
count = await hq.save_related(save_all=True)
assert count == 2
assert count == 3
hq.nicks[0].name = "Kabucha"
hq.nicks[1].name = "Kabucha2"

View File

@ -0,0 +1,256 @@
from typing import List
import databases
import pytest
import sqlalchemy
import ormar
from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()
class CringeLevel(ormar.Model):
class Meta:
tablename = "levels"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
class NickName(ormar.Model):
class Meta:
tablename = "nicks"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
is_lame: bool = ormar.Boolean(nullable=True)
level: CringeLevel = ormar.ForeignKey(CringeLevel)
class NicksHq(ormar.Model):
class Meta:
tablename = "nicks_x_hq"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
new_field: str = ormar.String(max_length=200, nullable=True)
class HQ(ormar.Model):
class Meta:
tablename = "hqs"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
nicks: List[NickName] = ormar.ManyToMany(NickName, through=NicksHq)
class Company(ormar.Model):
class Meta:
tablename = "companies"
metadata = metadata
database = database
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100, nullable=False, name="company_name")
founded: int = ormar.Integer(nullable=True)
hq: HQ = ormar.ForeignKey(HQ, related_name="companies")
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.drop_all(engine)
metadata.create_all(engine)
yield
metadata.drop_all(engine)
@pytest.mark.asyncio
async def test_saving_related_reverse_fk():
async with database:
async with database.transaction(force_rollback=True):
payload = {"companies": [{"name": "Banzai"}], "name": "Main"}
hq = HQ(**payload)
count = await hq.save_related(follow=True, save_all=True)
assert count == 2
hq_check = await HQ.objects.select_related("companies").get()
assert hq_check.pk is not None
assert hq_check.name == "Main"
assert len(hq_check.companies) == 1
assert hq_check.companies[0].name == "Banzai"
assert hq_check.companies[0].pk is not None
@pytest.mark.asyncio
async def test_saving_related_reverse_fk_multiple():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"companies": [{"name": "Banzai"}, {"name": "Yamate"}],
"name": "Main",
}
hq = HQ(**payload)
count = await hq.save_related(follow=True, save_all=True)
assert count == 3
hq_check = await HQ.objects.select_related("companies").get()
assert hq_check.pk is not None
assert hq_check.name == "Main"
assert len(hq_check.companies) == 2
assert hq_check.companies[0].name == "Banzai"
assert hq_check.companies[0].pk is not None
assert hq_check.companies[1].name == "Yamate"
assert hq_check.companies[1].pk is not None
@pytest.mark.asyncio
async def test_saving_related_fk():
async with database:
async with database.transaction(force_rollback=True):
payload = {"hq": {"name": "Main"}, "name": "Banzai"}
comp = Company(**payload)
count = await comp.save_related(follow=True, save_all=True)
assert count == 2
comp_check = await Company.objects.select_related("hq").get()
assert comp_check.pk is not None
assert comp_check.name == "Banzai"
assert comp_check.hq.name == "Main"
assert comp_check.hq.pk is not None
@pytest.mark.asyncio
async def test_saving_many_to_many_wo_through():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"name": "Main",
"nicks": [
{"name": "Bazinga0", "is_lame": False},
{"name": "Bazinga20", "is_lame": True},
],
}
hq = HQ(**payload)
count = await hq.save_related()
assert count == 3
hq_check = await HQ.objects.select_related("nicks").get()
assert hq_check.pk is not None
assert len(hq_check.nicks) == 2
assert hq_check.nicks[0].name == "Bazinga0"
assert hq_check.nicks[1].name == "Bazinga20"
@pytest.mark.asyncio
async def test_saving_many_to_many_with_through():
async with database:
async with database.transaction(force_rollback=True):
async with database.transaction(force_rollback=True):
payload = {
"name": "Main",
"nicks": [
{
"name": "Bazinga0",
"is_lame": False,
"nickshq": {"new_field": "test"},
},
{
"name": "Bazinga20",
"is_lame": True,
"nickshq": {"new_field": "test2"},
},
],
}
hq = HQ(**payload)
count = await hq.save_related()
assert count == 3
hq_check = await HQ.objects.select_related("nicks").get()
assert hq_check.pk is not None
assert len(hq_check.nicks) == 2
assert hq_check.nicks[0].name == "Bazinga0"
assert hq_check.nicks[0].nickshq.new_field == "test"
assert hq_check.nicks[1].name == "Bazinga20"
assert hq_check.nicks[1].nickshq.new_field == "test2"
@pytest.mark.asyncio
async def test_saving_nested_with_m2m_and_rev_fk():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"name": "Main",
"nicks": [
{"name": "Bazinga0", "is_lame": False, "level": {"name": "High"}},
{"name": "Bazinga20", "is_lame": True, "level": {"name": "Low"}},
],
}
hq = HQ(**payload)
count = await hq.save_related(follow=True, save_all=True)
assert count == 5
hq_check = await HQ.objects.select_related("nicks__level").get()
assert hq_check.pk is not None
assert len(hq_check.nicks) == 2
assert hq_check.nicks[0].name == "Bazinga0"
assert hq_check.nicks[0].level.name == "High"
assert hq_check.nicks[1].name == "Bazinga20"
assert hq_check.nicks[1].level.name == "Low"
@pytest.mark.asyncio
async def test_saving_nested_with_m2m_and_rev_fk_and_through():
async with database:
async with database.transaction(force_rollback=True):
payload = {
"hq": {
"name": "Yoko",
"nicks": [
{
"name": "Bazinga0",
"is_lame": False,
"nickshq": {"new_field": "test"},
"level": {"name": "High"},
},
{
"name": "Bazinga20",
"is_lame": True,
"nickshq": {"new_field": "test2"},
"level": {"name": "Low"},
},
],
},
"name": "Main",
}
company = Company(**payload)
count = await company.save_related(follow=True, save_all=True)
assert count == 6
company_check = await Company.objects.select_related(
"hq__nicks__level"
).get()
assert company_check.pk is not None
assert company_check.name == "Main"
assert company_check.hq.name == "Yoko"
assert len(company_check.hq.nicks) == 2
assert company_check.hq.nicks[0].name == "Bazinga0"
assert company_check.hq.nicks[0].nickshq.new_field == "test"
assert company_check.hq.nicks[0].level.name == "High"
assert company_check.hq.nicks[1].name == "Bazinga20"
assert company_check.hq.nicks[1].level.name == "Low"
assert company_check.hq.nicks[1].nickshq.new_field == "test2"