diff --git a/.coverage b/.coverage index 2485861..2d8a6bb 100644 Binary files a/.coverage and b/.coverage differ diff --git a/ormar/__init__.py b/ormar/__init__.py index 079c5b8..c3d1ed7 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -13,6 +13,7 @@ from ormar.fields import ( String, Text, Time, + UUID, ) from ormar.models import Model from ormar.queryset import QuerySet @@ -49,4 +50,5 @@ __all__ = [ "QuerySet", "RelationType", "Undefined", + "UUID", ] diff --git a/ormar/fields/__init__.py b/ormar/fields/__init__.py index 493ebab..0035a4f 100644 --- a/ormar/fields/__init__.py +++ b/ormar/fields/__init__.py @@ -13,6 +13,7 @@ from ormar.fields.model_fields import ( String, Text, Time, + UUID, ) __all__ = [ @@ -27,6 +28,7 @@ __all__ = [ "Text", "Float", "Time", + "UUID", "ForeignKey", "ManyToMany", "ManyToManyField", diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index c6627a7..57f24b7 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -1,11 +1,13 @@ import datetime import decimal +import uuid from typing import Any, Optional, Type import pydantic import sqlalchemy from ormar import ModelDefinitionError # noqa I101 +from ormar.fields import sqlalchemy_uuid from ormar.fields.base import BaseField # noqa I101 @@ -285,3 +287,12 @@ class Decimal(ModelFieldFactory): raise ModelDefinitionError( "Parameters scale and precision are required for field Decimal" ) + + +class UUID(ModelFieldFactory): + _bases = (uuid.UUID, BaseField) + _type = uuid.UUID + + @classmethod + def get_column_type(cls, **kwargs: Any) -> Any: + return sqlalchemy_uuid.UUID() diff --git a/ormar/fields/sqlalchemy_uuid.py b/ormar/fields/sqlalchemy_uuid.py new file mode 100644 index 0000000..fbf1388 --- /dev/null +++ b/ormar/fields/sqlalchemy_uuid.py @@ -0,0 +1,59 @@ +import uuid +from typing import Any, Optional, Union + +from sqlalchemy.dialects.postgresql import UUID as psqlUUID +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.types import CHAR, TypeDecorator + + +class UUID(TypeDecorator): # pragma nocover + """Platform-independent GUID type. + + Uses Postgresql's UUID type, otherwise uses + CHAR(32), to store UUID. + + """ + + impl = CHAR + + def _cast_to_uuid(self, value: Union[str, int, bytes]) -> uuid.UUID: + if not isinstance(value, uuid.UUID): + if isinstance(value, bytes): + ret_value = uuid.UUID(bytes=value) + elif isinstance(value, int): + ret_value = uuid.UUID(int=value) + elif isinstance(value, str): + ret_value = uuid.UUID(value) + else: + ret_value = value + return ret_value + + def load_dialect_impl(self, dialect: DefaultDialect) -> Any: + if dialect.name == "postgresql": + return dialect.type_descriptor(psqlUUID()) + else: + return dialect.type_descriptor(CHAR(32)) + + def process_bind_param( + self, value: Union[str, int, bytes, uuid.UUID, None], dialect: DefaultDialect + ) -> Optional[str]: + if value is None: + return value + elif not isinstance(value, uuid.UUID): + value = self._cast_to_uuid(value) + if dialect.name == "postgresql": + return str(value) + else: + return "%.32x" % value.int + + def process_result_value( + self, value: Optional[str], dialect: DefaultDialect + ) -> Optional[uuid.UUID]: + if value is None: + return value + if dialect.name == "postgresql": + return uuid.UUID(value) + else: + if not isinstance(value, uuid.UUID): + return uuid.UUID(value) + return value diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 1430b10..427e95a 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -261,14 +261,15 @@ class QuerySet: expr = self.table.insert() expr = expr.values(**new_kwargs) - # Execute the insert, and return a new model instance. instance = self.model(**kwargs) pk = await self.database.execute(expr) + pk_name = self.model_meta.pkname if pk_name not in kwargs and pk_name in new_kwargs: instance.pk = new_kwargs[self.model_meta.pkname] if pk and isinstance(pk, self.model.pk_type()): setattr(instance, self.model_meta.pkname, pk) + return instance async def bulk_create(self, objects: List["Model"]) -> None: diff --git a/tests/test_models.py b/tests/test_models.py index 79ca0ed..f33d855 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -6,6 +6,7 @@ import databases import pydantic import pytest import sqlalchemy +import uuid import ormar from ormar.exceptions import QueryDefinitionError, NoMatch @@ -25,6 +26,16 @@ class JsonSample(ormar.Model): test_json: ormar.JSON(nullable=True) +class UUIDSample(ormar.Model): + class Meta: + tablename = "uuids" + metadata = metadata + database = database + + id: ormar.UUID(primary_key=True, default=uuid.uuid4) + test_text: ormar.Text() + + class User(ormar.Model): class Meta: tablename = "users" @@ -113,6 +124,28 @@ async def test_json_column(): assert items[1].test_json == dict(aa=12) +@pytest.mark.asyncio +async def test_uuid_column(): + async with database: + async with database.transaction(force_rollback=True): + u1 = await UUIDSample.objects.create(test_text="aa") + u2 = await UUIDSample.objects.create(test_text="bb") + + items = await UUIDSample.objects.all() + assert len(items) == 2 + + assert isinstance(items[0].id, uuid.UUID) + assert isinstance(items[1].id, uuid.UUID) + + assert items[0].id in (u1.id, u2.id) + assert items[1].id in (u1.id, u2.id) + + assert items[0].id != items[1].id + + item = await UUIDSample.objects.filter(id=u1.id).get() + assert item.id == u1.id + + @pytest.mark.asyncio async def test_model_crud(): async with database: