import itertools
from typing import Optional, List

import databases
import pydantic
import pytest
import sqlalchemy

import ormar
from tests.settings import DATABASE_URL

database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()


# Nickname rows stored in the "nicks" table; linked to HQ through NicksHq.
class NickNames(ormar.Model):
    class Meta:
        tablename = "nicks"
        metadata = metadata
        database = database

    id: int = ormar.Integer(primary_key=True)
    # The attribute is `name`, but the column is declared with name="hq_name".
    name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
    is_lame: bool = ormar.Boolean(nullable=True)


# Through model for the HQ <-> NickNames many-to-many relation.
class NicksHq(ormar.Model):
    class Meta:
        tablename = "nicks_x_hq"
        metadata = metadata
        database = database


# Headquarters rows ("hqs" table) carrying a many-to-many link to NickNames.
class HQ(ormar.Model):
    class Meta:
        tablename = "hqs"
        metadata = metadata
        database = database

    id: int = ormar.Integer(primary_key=True)
    name: str = ormar.String(max_length=100, nullable=False, name="hq_name")
    nicks: List[NickNames] = ormar.ManyToMany(NickNames, through=NicksHq)


# Company rows ("companies" table); each company references one HQ.
class Company(ormar.Model):
    class Meta:
        tablename = "companies"
        metadata = metadata
        database = database

    id: int = ormar.Integer(primary_key=True)
    # Declared nullable=False - the last check in the test relies on this.
    name: str = ormar.String(max_length=100, nullable=False, name="company_name")
    founded: int = ormar.Integer(nullable=True)
    hq: HQ = ormar.ForeignKey(HQ)


# Car rows ("cars" table); every descriptive column except `name` is nullable.
class Car(ormar.Model):
    class Meta:
        tablename = "cars"
        metadata = metadata
        database = database

    id: int = ormar.Integer(primary_key=True)
    manufacturer: Optional[Company] = ormar.ForeignKey(Company)
    name: str = ormar.String(max_length=100)
    year: int = ormar.Integer(nullable=True)
    gearbox_type: str = ormar.String(max_length=20, nullable=True)
    gears: int = ormar.Integer(nullable=True)
    aircon_type: str = ormar.String(max_length=20, nullable=True)


@pytest.fixture(autouse=True, scope="module")
def create_test_database():
    """Recreate all tables before the module's tests and drop them afterwards."""
    db_engine = sqlalchemy.create_engine(DATABASE_URL)
    metadata.drop_all(db_engine)
    metadata.create_all(db_engine)
    yield
    metadata.drop_all(db_engine)


@pytest.mark.asyncio
async def test_selecting_subset():
    """Check that ``fields()`` restricts loaded columns on a model and its relations.

    The same selection is expressed three ways (flat ``__`` paths, nested dict
    with ``...`` leaves, nested dict with a set leaf); columns left out of the
    selection are asserted to come back as ``None``.
    """
    # Car columns that none of the restricted queries below ask for.
    unrequested_columns = ("year", "gearbox_type", "gears", "aircon_type")

    async with database:
        async with database.transaction(force_rollback=True):
            # --- seed data: two nicks -> one HQ -> one company -> three cars ---
            first_nick = await NickNames.objects.create(name="Nippon", is_lame=False)
            second_nick = await NickNames.objects.create(
                name="EroCherry", is_lame=True
            )
            japan_hq = await HQ.objects.create(name="Japan")
            await japan_hq.nicks.add(first_nick)
            await japan_hq.nicks.add(second_nick)

            maker = await Company.objects.create(
                name="Toyota", founded=1937, hq=japan_hq
            )

            car_specs = (
                {
                    "name": "Corolla",
                    "year": 2020,
                    "gearbox_type": "Manual",
                    "gears": 5,
                    "aircon_type": "Manual",
                },
                {
                    "name": "Yaris",
                    "year": 2019,
                    "gearbox_type": "Manual",
                    "gears": 5,
                    "aircon_type": "Manual",
                },
                {
                    "name": "Supreme",
                    "year": 2020,
                    "gearbox_type": "Auto",
                    "gears": 6,
                    "aircon_type": "Auto",
                },
            )
            for spec in car_specs:
                await Car.objects.create(manufacturer=maker, **spec)

            # --- variant 1: selection given as a flat list of "__" paths ---
            path_query = Car.objects.select_related(
                ["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
            )
            path_query = path_query.fields(
                [
                    "id",
                    "name",
                    "manufacturer__name",
                    "manufacturer__hq__name",
                    "manufacturer__hq__nicks__name",
                ]
            )
            cars_by_path = await path_query.all()

            # --- variant 2: selection given as a nested dict with Ellipsis leaves ---
            dict_query = Car.objects.select_related(
                ["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
            )
            dict_query = dict_query.fields(
                {
                    "id": ...,
                    "name": ...,
                    "manufacturer": {
                        "name": ...,
                        "hq": {"name": ..., "nicks": {"name": ...}},
                    },
                }
            )
            cars_by_dict = await dict_query.all()

            # --- variant 3: nested dict whose innermost level is a set ---
            set_query = Car.objects.select_related(
                ["manufacturer", "manufacturer__hq", "manufacturer__hq__nicks"]
            )
            set_query = set_query.fields(
                {
                    "id": ...,
                    "name": ...,
                    "manufacturer": {
                        "name": ...,
                        "hq": {"name": ..., "nicks": {"name"}},
                    },
                }
            )
            cars_by_set = await set_query.all()

            assert cars_by_set == cars_by_path

            for car in itertools.chain(cars_by_path, cars_by_dict):
                # Columns outside the selection are expected to be None.
                assert all(
                    getattr(car, column) is None for column in unrequested_columns
                )
                company = car.manufacturer
                assert company.name == "Toyota"
                assert company.founded is None
                assert company.hq.name == "Japan"
                assert len(company.hq.nicks) == 2
                assert company.hq.nicks[0].is_lame is None

            # Chained fields() calls; only Car columns are restricted here, so the
            # manufacturer is asserted to keep all of its own columns.
            id_and_name_query = Car.objects.select_related("manufacturer")
            id_and_name_query = id_and_name_query.fields("id").fields(["name"])
            cars_id_and_name = await id_and_name_query.all()
            for car in cars_id_and_name:
                assert all(
                    getattr(car, column) is None for column in unrequested_columns
                )
                company = car.manufacturer
                assert company.name == "Toyota"
                assert company.founded == 1937
                assert company.hq.name is None

            # Control query without fields(): every column must be populated.
            unrestricted_cars = await Car.objects.select_related("manufacturer").all()
            for car in unrestricted_cars:
                assert all(
                    getattr(car, column) is not None for column in unrequested_columns
                )
                company = car.manufacturer
                assert company.name == "Toyota"
                assert company.founded == 1937

            with pytest.raises(pydantic.error_wrappers.ValidationError):
                # cannot exclude mandatory model columns - company__name in this example
                invalid_query = Car.objects.select_related("manufacturer").fields(
                    ["id", "name", "manufacturer__founded"]
                )
                await invalid_query.all()