refactor fields into a package

This commit is contained in:
collerek
2020-08-11 17:34:19 +02:00
parent 704e83fed0
commit 867fc691f7
13 changed files with 335 additions and 290 deletions

BIN
.coverage

Binary file not shown.

View File

@ -6,13 +6,13 @@ from orm.fields import (
DateTime, DateTime,
Decimal, Decimal,
Float, Float,
ForeignKey,
Integer, Integer,
JSON, JSON,
String, String,
Text, Text,
Time, Time,
) )
from orm.fields.foreign_key import ForeignKey
from orm.models import Model from orm.models import Model
__version__ = "0.0.1" __version__ = "0.0.1"
@ -28,7 +28,6 @@ __all__ = [
"Date", "Date",
"Decimal", "Decimal",
"Float", "Float",
"ForeignKey",
"Model", "Model",
"ModelDefinitionError", "ModelDefinitionError",
"ModelNotSet", "ModelNotSet",

View File

@ -1,275 +0,0 @@
import datetime
import decimal
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, Union
import orm
from orm.exceptions import ModelDefinitionError, RelationshipInstanceError
from pydantic import BaseModel, Json
import sqlalchemy
if TYPE_CHECKING: # pragma no cover
from orm.models import Model
class RequiredParams:
def __init__(self, *args: str) -> None:
self._required = list(args)
def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]:
old_init = model_field_class.__init__
model_field_class._old_init = old_init
def __init__(instance: "BaseField", *args: Any, **kwargs: Any) -> None:
super(instance.__class__, instance).__init__(*args, **kwargs)
for arg in self._required:
if arg not in kwargs:
raise ModelDefinitionError(
f"{instance.__class__.__name__} field requires parameter: {arg}"
)
setattr(instance, arg, kwargs.pop(arg))
model_field_class.__init__ = __init__
return model_field_class
class BaseField:
__type__ = None
def __init__(self, *args: Any, **kwargs: Any) -> None:
name = kwargs.pop("name", None)
args = list(args)
if args:
if isinstance(args[0], str):
if name is not None:
raise ModelDefinitionError(
"Column name cannot be passed positionally and as a keyword."
)
name = args.pop(0)
self.name = name
self._populate_from_kwargs(kwargs)
def _populate_from_kwargs(self, kwargs: Dict) -> None:
self.primary_key = kwargs.pop("primary_key", False)
self.autoincrement = kwargs.pop(
"autoincrement", self.primary_key and self.__type__ == int
)
self.nullable = kwargs.pop("nullable", not self.primary_key)
self.default = kwargs.pop("default", None)
self.server_default = kwargs.pop("server_default", None)
self.index = kwargs.pop("index", None)
self.unique = kwargs.pop("unique", None)
self.pydantic_only = kwargs.pop("pydantic_only", False)
if self.pydantic_only and self.primary_key:
raise ModelDefinitionError("Primary key column cannot be pydantic only.")
@property
def is_required(self) -> bool:
return (
not self.nullable and not self.has_default and not self.is_auto_primary_key
)
@property
def default_value(self) -> Any:
default = self.default if self.default is not None else self.server_default
return default() if callable(default) else default
@property
def has_default(self) -> bool:
return self.default is not None or self.server_default is not None
@property
def is_auto_primary_key(self) -> bool:
if self.primary_key:
return self.autoincrement
return False
def get_column(self, name: str = None) -> sqlalchemy.Column:
self.name = self.name or name
constraints = self.get_constraints()
return sqlalchemy.Column(
self.name,
self.get_column_type(),
*constraints,
primary_key=self.primary_key,
autoincrement=self.autoincrement,
nullable=self.nullable,
index=self.index,
unique=self.unique,
default=self.default,
server_default=self.server_default,
)
def get_column_type(self) -> sqlalchemy.types.TypeEngine:
raise NotImplementedError() # pragma: no cover
def get_constraints(self) -> Optional[List]:
return []
def expand_relationship(self, value: Any, child: "Model") -> Any:
return value
@RequiredParams("length")
class String(BaseField):
__type__ = str
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.String(self.length)
class Integer(BaseField):
__type__ = int
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.Integer()
class Text(BaseField):
__type__ = str
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.Text()
class Float(BaseField):
__type__ = float
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.Float()
class Boolean(BaseField):
__type__ = bool
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.Boolean()
class DateTime(BaseField):
__type__ = datetime.datetime
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.DateTime()
class Date(BaseField):
__type__ = datetime.date
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.Date()
class Time(BaseField):
__type__ = datetime.time
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.Time()
class JSON(BaseField):
__type__ = Json
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.JSON()
class BigInteger(BaseField):
__type__ = int
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.BigInteger()
@RequiredParams("length", "precision")
class Decimal(BaseField):
__type__ = decimal.Decimal
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.DECIMAL(self.length, self.precision)
def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model":
init_dict = {fk.__pkname__: pk or -1}
init_dict = {
**init_dict,
**{
k: create_dummy_instance(v.to)
for k, v in fk.__model_fields__.items()
if isinstance(v, ForeignKey) and not v.nullable and not v.virtual
},
}
return fk(**init_dict)
class ForeignKey(BaseField):
def __init__(
self,
to: Type["Model"],
name: str = None,
related_name: str = None,
nullable: bool = True,
virtual: bool = False,
) -> None:
super().__init__(nullable=nullable, name=name)
self.virtual = virtual
self.related_name = related_name
self.to = to
@property
def __type__(self) -> Type[BaseModel]:
return self.to.__pydantic_model__
def get_constraints(self) -> List[sqlalchemy.schema.ForeignKey]:
fk_string = self.to.__tablename__ + "." + self.to.__pkname__
return [sqlalchemy.schema.ForeignKey(fk_string)]
def get_column_type(self) -> sqlalchemy.Column:
to_column = self.to.__model_fields__[self.to.__pkname__]
return to_column.get_column_type()
def expand_relationship(
self, value: Any, child: "Model"
) -> Union["Model", List["Model"]]:
if isinstance(value, orm.models.Model) and not isinstance(value, self.to):
raise RelationshipInstanceError(
f"Relationship error - expecting: {self.to.__name__}, "
f"but {value.__class__.__name__} encountered."
)
if isinstance(value, list) and not isinstance(value, self.to):
model = [self.expand_relationship(val, child) for val in value]
return model
if isinstance(value, self.to):
model = value
elif isinstance(value, dict):
model = self.to(**value)
else:
if not isinstance(value, self.to.pk_type()):
raise RelationshipInstanceError(
f"Relationship error - ForeignKey {self.to.__name__} "
f"is of type {self.to.pk_type()} "
f"of type {self.__type__} "
f"while {type(value)} passed as a parameter."
)
model = create_dummy_instance(fk=self.to, pk=value)
self.add_to_relationship_registry(model, child)
return model
def add_to_relationship_registry(self, model: "Model", child: "Model") -> None:
model._orm_relationship_manager.add_relation(
model.__class__.__name__.lower(),
child.__class__.__name__.lower(),
model,
child,
virtual=self.virtual,
)

31
orm/fields/__init__.py Normal file
View File

@ -0,0 +1,31 @@
from orm.fields.model_fields import (
BigInteger,
Boolean,
Date,
DateTime,
Decimal,
String,
Integer,
Text,
Float,
Time,
JSON,
)
from orm.fields.foreign_key import ForeignKey
from orm.fields.base import BaseField
__all__ = [
"Decimal",
"BigInteger",
"Boolean",
"Date",
"DateTime",
"String",
"JSON",
"Integer",
"Text",
"Float",
"Time",
"ForeignKey",
"BaseField",
]

107
orm/fields/base.py Normal file
View File

@ -0,0 +1,107 @@
from typing import Type, Any, Dict, Optional, List
import sqlalchemy
from orm import ModelDefinitionError
class RequiredParams:
def __init__(self, *args: str) -> None:
self._required = list(args)
def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]:
old_init = model_field_class.__init__
model_field_class._old_init = old_init
def __init__(instance: "BaseField", *args: Any, **kwargs: Any) -> None:
super(instance.__class__, instance).__init__(*args, **kwargs)
for arg in self._required:
if arg not in kwargs:
raise ModelDefinitionError(
f"{instance.__class__.__name__} field requires parameter: {arg}"
)
setattr(instance, arg, kwargs.pop(arg))
model_field_class.__init__ = __init__
return model_field_class
class BaseField:
__type__ = None
def __init__(self, *args: Any, **kwargs: Any) -> None:
name = kwargs.pop("name", None)
args = list(args)
if args:
if isinstance(args[0], str):
if name is not None:
raise ModelDefinitionError(
"Column name cannot be passed positionally and as a keyword."
)
name = args.pop(0)
self.name = name
self._populate_from_kwargs(kwargs)
def _populate_from_kwargs(self, kwargs: Dict) -> None:
self.primary_key = kwargs.pop("primary_key", False)
self.autoincrement = kwargs.pop(
"autoincrement", self.primary_key and self.__type__ == int
)
self.nullable = kwargs.pop("nullable", not self.primary_key)
self.default = kwargs.pop("default", None)
self.server_default = kwargs.pop("server_default", None)
self.index = kwargs.pop("index", None)
self.unique = kwargs.pop("unique", None)
self.pydantic_only = kwargs.pop("pydantic_only", False)
if self.pydantic_only and self.primary_key:
raise ModelDefinitionError("Primary key column cannot be pydantic only.")
@property
def is_required(self) -> bool:
return (
not self.nullable and not self.has_default and not self.is_auto_primary_key
)
@property
def default_value(self) -> Any:
default = self.default if self.default is not None else self.server_default
return default() if callable(default) else default
@property
def has_default(self) -> bool:
return self.default is not None or self.server_default is not None
@property
def is_auto_primary_key(self) -> bool:
if self.primary_key:
return self.autoincrement
return False
def get_column(self, name: str = None) -> sqlalchemy.Column:
self.name = self.name or name
constraints = self.get_constraints()
return sqlalchemy.Column(
self.name,
self.get_column_type(),
*constraints,
primary_key=self.primary_key,
autoincrement=self.autoincrement,
nullable=self.nullable,
index=self.index,
unique=self.unique,
default=self.default,
server_default=self.server_default,
)
def get_column_type(self) -> sqlalchemy.types.TypeEngine:
raise NotImplementedError() # pragma: no cover
def get_constraints(self) -> Optional[List]:
return []
def expand_relationship(self, value: Any, child: "Model") -> Any:
return value

91
orm/fields/foreign_key.py Normal file
View File

@ -0,0 +1,91 @@
from typing import Type, List, Any, Union, TYPE_CHECKING
import sqlalchemy
from pydantic import BaseModel
import orm
from orm.exceptions import RelationshipInstanceError
from orm.fields.base import BaseField
if TYPE_CHECKING: # pragma no cover
from orm.models import Model
def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model":
init_dict = {fk.__pkname__: pk or -1}
init_dict = {
**init_dict,
**{
k: create_dummy_instance(v.to)
for k, v in fk.__model_fields__.items()
if isinstance(v, ForeignKey) and not v.nullable and not v.virtual
},
}
return fk(**init_dict)
class ForeignKey(BaseField):
def __init__(
self,
to: Type["Model"],
name: str = None,
related_name: str = None,
nullable: bool = True,
virtual: bool = False,
) -> None:
super().__init__(nullable=nullable, name=name)
self.virtual = virtual
self.related_name = related_name
self.to = to
@property
def __type__(self) -> Type[BaseModel]:
return self.to.__pydantic_model__
def get_constraints(self) -> List[sqlalchemy.schema.ForeignKey]:
fk_string = self.to.__tablename__ + "." + self.to.__pkname__
return [sqlalchemy.schema.ForeignKey(fk_string)]
def get_column_type(self) -> sqlalchemy.Column:
to_column = self.to.__model_fields__[self.to.__pkname__]
return to_column.get_column_type()
def expand_relationship(
self, value: Any, child: "Model"
) -> Union["Model", List["Model"]]:
if isinstance(value, orm.models.Model) and not isinstance(value, self.to):
raise RelationshipInstanceError(
f"Relationship error - expecting: {self.to.__name__}, "
f"but {value.__class__.__name__} encountered."
)
if isinstance(value, list) and not isinstance(value, self.to):
model = [self.expand_relationship(val, child) for val in value]
return model
if isinstance(value, self.to):
model = value
elif isinstance(value, dict):
model = self.to(**value)
else:
if not isinstance(value, self.to.pk_type()):
raise RelationshipInstanceError(
f"Relationship error - ForeignKey {self.to.__name__} "
f"is of type {self.to.pk_type()} "
f"while {type(value)} passed as a parameter."
)
model = create_dummy_instance(fk=self.to, pk=value)
self.add_to_relationship_registry(model, child)
return model
def add_to_relationship_registry(self, model: "Model", child: "Model") -> None:
model._orm_relationship_manager.add_relation(
model.__class__.__name__.lower(),
child.__class__.__name__.lower(),
model,
child,
virtual=self.virtual,
)

View File

@ -0,0 +1,86 @@
import datetime
import decimal
import sqlalchemy
from pydantic import Json
from orm.fields.base import BaseField, RequiredParams
@RequiredParams("length")
class String(BaseField):
__type__ = str
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.String(self.length)
class Integer(BaseField):
__type__ = int
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.Integer()
class Text(BaseField):
__type__ = str
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.Text()
class Float(BaseField):
__type__ = float
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.Float()
class Boolean(BaseField):
__type__ = bool
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.Boolean()
class DateTime(BaseField):
__type__ = datetime.datetime
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.DateTime()
class Date(BaseField):
__type__ = datetime.date
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.Date()
class Time(BaseField):
__type__ = datetime.time
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.Time()
class JSON(BaseField):
__type__ = Json
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.JSON()
class BigInteger(BaseField):
__type__ = int
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.BigInteger()
@RequiredParams("length", "precision")
class Decimal(BaseField):
__type__ = decimal.Decimal
def get_column_type(self) -> sqlalchemy.Column:
return sqlalchemy.DECIMAL(self.length, self.precision)

View File

@ -9,7 +9,8 @@ import databases
import orm.queryset as qry import orm.queryset as qry
from orm.exceptions import ModelDefinitionError from orm.exceptions import ModelDefinitionError
from orm.fields import BaseField, ForeignKey from orm import ForeignKey
from orm.fields.base import BaseField
from orm.relations import RelationshipManager from orm.relations import RelationshipManager
import pydantic import pydantic

View File

@ -13,9 +13,10 @@ from typing import (
import databases import databases
import orm import orm
import orm.fields.foreign_key
from orm import ForeignKey from orm import ForeignKey
from orm.exceptions import MultipleMatches, NoMatch, QueryDefinitionError from orm.exceptions import MultipleMatches, NoMatch, QueryDefinitionError
from orm.fields import BaseField from orm.fields.base import BaseField
import sqlalchemy import sqlalchemy
from sqlalchemy import text from sqlalchemy import text
@ -79,7 +80,8 @@ class Query:
if ( if (
not self.model_cls.__model_fields__[key].nullable not self.model_cls.__model_fields__[key].nullable
and isinstance( and isinstance(
self.model_cls.__model_fields__[key], orm.fields.ForeignKey self.model_cls.__model_fields__[key],
orm.fields.foreign_key.ForeignKey,
) )
and key not in self._select_related and key not in self._select_related
): ):

View File

@ -5,7 +5,7 @@ from random import choices
from typing import Dict, List, TYPE_CHECKING, Union from typing import Dict, List, TYPE_CHECKING, Union
from weakref import proxy from weakref import proxy
from orm.fields import ForeignKey from orm import ForeignKey
if TYPE_CHECKING: # pragma no cover if TYPE_CHECKING: # pragma no cover
from orm.models import FakePydantic, Model from orm.models import FakePydantic, Model

View File

@ -4,6 +4,7 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
import orm import orm
import orm.fields.foreign_key
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
app = FastAPI() app = FastAPI()
@ -28,7 +29,7 @@ class Item(orm.Model):
id = orm.Integer(primary_key=True) id = orm.Integer(primary_key=True)
name = orm.String(length=100) name = orm.String(length=100)
category = orm.ForeignKey(Category, nullable=True) category = orm.fields.foreign_key.ForeignKey(Category, nullable=True)
@app.post("/items/", response_model=Item) @app.post("/items/", response_model=Item)

View File

@ -3,6 +3,7 @@ import pytest
import sqlalchemy import sqlalchemy
import orm import orm
import orm.fields.foreign_key
from orm.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError from orm.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
@ -25,7 +26,7 @@ class Track(orm.Model):
__database__ = database __database__ = database
id = orm.Integer(primary_key=True) id = orm.Integer(primary_key=True)
album = orm.ForeignKey(Album) album = orm.fields.foreign_key.ForeignKey(Album)
title = orm.String(length=100) title = orm.String(length=100)
position = orm.Integer() position = orm.Integer()
@ -45,7 +46,7 @@ class Team(orm.Model):
__database__ = database __database__ = database
id = orm.Integer(primary_key=True) id = orm.Integer(primary_key=True)
org = orm.ForeignKey(Organisation) org = orm.fields.foreign_key.ForeignKey(Organisation)
name = orm.String(length=100) name = orm.String(length=100)
@ -55,7 +56,7 @@ class Member(orm.Model):
__database__ = database __database__ = database
id = orm.Integer(primary_key=True) id = orm.Integer(primary_key=True)
team = orm.ForeignKey(Team) team = orm.fields.foreign_key.ForeignKey(Team)
email = orm.String(length=100) email = orm.String(length=100)

View File

@ -5,6 +5,7 @@ import pytest
import sqlalchemy import sqlalchemy
import orm import orm
import orm.fields.foreign_key
from tests.settings import DATABASE_URL from tests.settings import DATABASE_URL
database = databases.Database(DATABASE_URL, force_rollback=True) database = databases.Database(DATABASE_URL, force_rollback=True)
@ -27,7 +28,7 @@ class SchoolClass(orm.Model):
id = orm.Integer(primary_key=True) id = orm.Integer(primary_key=True)
name = orm.String(length=100) name = orm.String(length=100)
department = orm.ForeignKey(Department, nullable=False) department = orm.fields.foreign_key.ForeignKey(Department, nullable=False)
class Category(orm.Model): class Category(orm.Model):
@ -46,8 +47,8 @@ class Student(orm.Model):
id = orm.Integer(primary_key=True) id = orm.Integer(primary_key=True)
name = orm.String(length=100) name = orm.String(length=100)
schoolclass = orm.ForeignKey(SchoolClass) schoolclass = orm.fields.foreign_key.ForeignKey(SchoolClass)
category = orm.ForeignKey(Category, nullable=True) category = orm.fields.foreign_key.ForeignKey(Category, nullable=True)
class Teacher(orm.Model): class Teacher(orm.Model):
@ -57,8 +58,8 @@ class Teacher(orm.Model):
id = orm.Integer(primary_key=True) id = orm.Integer(primary_key=True)
name = orm.String(length=100) name = orm.String(length=100)
schoolclass = orm.ForeignKey(SchoolClass) schoolclass = orm.fields.foreign_key.ForeignKey(SchoolClass)
category = orm.ForeignKey(Category, nullable=True) category = orm.fields.foreign_key.ForeignKey(Category, nullable=True)
@pytest.fixture(scope='module') @pytest.fixture(scope='module')