diff --git a/.coverage b/.coverage index 8bc7b2b..1ab8b16 100644 Binary files a/.coverage and b/.coverage differ diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 761833e..2e13022 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -88,7 +88,7 @@ def sqlalchemy_columns_from_model_fields( return pkname, columns -def populate_default_field_value( +def populate_default_pydantic_field_value( type_: Type[BaseField], field: str, attrs: dict ) -> dict: def_value = type_.default_value() @@ -105,10 +105,66 @@ def populate_pydantic_default_values(attrs: Dict) -> Dict: if issubclass(type_, BaseField): if type_.name is None: type_.name = field - attrs = populate_default_field_value(type_, field, attrs) + attrs = populate_default_pydantic_field_value(type_, field, attrs) return attrs +def extract_annotations_and_module( + attrs: dict, new_model: "ModelMetaclass", bases: Tuple +) -> dict: + annotations = attrs.get("__annotations__") or new_model.__annotations__ + attrs["__annotations__"] = annotations + attrs = populate_pydantic_default_values(attrs) + + attrs["__module__"] = attrs["__module__"] or bases[0].__module__ + attrs["__annotations__"] = attrs["__annotations__"] or bases[0].__annotations__ + return attrs + + +def populate_meta_orm_model_fields( + attrs: dict, new_model: Type["Model"] +) -> Type["Model"]: + model_fields = { + field_name: field + for field_name, field in attrs["__annotations__"].items() + if issubclass(field, BaseField) + } + new_model.Meta.model_fields = model_fields + return new_model + + +def populate_meta_tablename_columns_and_pk( + name: str, new_model: Type["Model"] +) -> Type["Model"]: + tablename = name.lower() + "s" + new_model.Meta.tablename = new_model.Meta.tablename or tablename + + if hasattr(new_model.Meta, "columns"): + columns = new_model.Meta.table.columns + pkname = new_model.Meta.pkname + else: + pkname, columns = sqlalchemy_columns_from_model_fields( + new_model.Meta.model_fields, new_model.Meta.tablename + ) + new_model.Meta.columns = columns + new_model.Meta.pkname = pkname + + if not new_model.Meta.pkname: + raise ModelDefinitionError("Table has to have a primary key.") + + return new_model + + +def populate_meta_sqlalchemy_table_if_required( + new_model: Type["Model"], +) -> Type["Model"]: + if not hasattr(new_model.Meta, "table"): + new_model.Meta.table = sqlalchemy.Table( + new_model.Meta.tablename, new_model.Meta.metadata, *new_model.Meta.columns + ) + return new_model + + def get_pydantic_base_orm_config() -> Type[BaseConfig]: class Config(BaseConfig): orm_mode = True @@ -127,46 +183,10 @@ class ModelMetaclass(pydantic.main.ModelMetaclass): if hasattr(new_model, "Meta"): - annotations = attrs.get("__annotations__") or new_model.__annotations__ - attrs["__annotations__"] = annotations - attrs = populate_pydantic_default_values(attrs) - - attrs["__module__"] = attrs["__module__"] or bases[0].__module__ - attrs["__annotations__"] = ( - attrs["__annotations__"] or bases[0].__annotations__ - ) - - tablename = name.lower() + "s" - new_model.Meta.tablename = new_model.Meta.tablename or tablename - - # sqlalchemy table creation - - model_fields = { - field_name: field - for field_name, field in attrs["__annotations__"].items() - if issubclass(field, BaseField) - } - - if hasattr(new_model.Meta, "columns"): - columns = new_model.Meta.table.columns - pkname = new_model.Meta.pkname - else: - pkname, columns = sqlalchemy_columns_from_model_fields( - model_fields, new_model.Meta.tablename - ) - - if not hasattr(new_model.Meta, "table"): - new_model.Meta.table = sqlalchemy.Table( - new_model.Meta.tablename, new_model.Meta.metadata, *columns - ) - - new_model.Meta.columns = columns - new_model.Meta.pkname = pkname - - if not pkname: - raise ModelDefinitionError("Table has to have a primary key.") - - new_model.Meta.model_fields = model_fields + attrs = extract_annotations_and_module(attrs, new_model, bases) + new_model = populate_meta_orm_model_fields(attrs, new_model) + new_model = populate_meta_tablename_columns_and_pk(name, new_model) + new_model = populate_meta_sqlalchemy_table_if_required(new_model) expand_reverse_relationships(new_model) new_model = super().__new__( # type: ignore