Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 60 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,38 +112,76 @@ model_instance_3 = curried_model_3(x="1", y=2)(z=("3", 4))
print(model_instance_3)
```

#### init_model_from_kwargs
#### FlatInitModel

The `init_model_from_kwargs` constructor allows to initialize (potentially nested) models from (flat) kwargs.
The `FlatInitModel` constructor allows to instantiate a potentially deeply nested Pydantic model from flat kwargs.

```python
class SimpleModel(BaseModel):
x: int
y: int = 3
from lupl import FlatInitModel
from pydantic import BaseModel

class DeeplyNestedModel(BaseModel):
z: int

class NestedModel(BaseModel):
a: str
b: SimpleModel
y: int
deeply_nested: DeeplyNestedModel

class Model(BaseModel):
x: int
nested: NestedModel

constructor = FlatInitModel(model=Model)

instance: Model = constructor(x=1, y=2, z=3)
instance.model_dump() # {'x': 1, 'nested': {'y': 2, 'deeply_nested': {'z': 3}}}
```


`FlatInitModel` also handles model union types by processing the first model type of the union.

A common use case for model union types is e.g. to assign a default value to a model union typed field in case a nested model instance does not meet certain criteria, i.e. fails a predicate.

The `model_bool` parameter in `lupl.ConfigDict` allows to specify the condition for *model truthiness* - if the existential condition of a model is met, the model instance gets assigned to the model field, else the constructor falls back to the default value.


The default condition for model truthiness is that *any* model field must be truthy for the model to be considered truthy.

The `model_bool` parameter takes either

class ComplexModel(BaseModel):
p: str
q: NestedModel
- a callable object of arity 1 that receives the model instance at runtime,
- a `str` denoting a field of the model that must be truthy in order for the model to be truthy
- a `set[str]` denoting fields of the model, all of which must be truthy for the model to be truthy.


# p='p value' q=NestedModel(a='a value', b=SimpleModel(x=1, y=2))
model_instance_1 = init_model_from_kwargs(
ComplexModel, x=1, y=2, a="a value", p="p value"
)
The following example defines the truth condition for `DeeplyNestedModel` to be `gt3`. `NestedModel` defines a model union type with a default value - if the `model_bool` predicate fails, the constructor falls back to the default:

# p='p value' q=NestedModel(a='a value', b=SimpleModel(x=1, y=3))
model_instance_2 = init_model_from_kwargs(
ComplexModel, p="p value", q=NestedModel(a="a value", b=SimpleModel(x=1))
)
```python
from lupl import ConfigDict, FlatInitModel
from pydantic import BaseModel

class DeeplyNestedModel(BaseModel):
model_config = ConfigDict(model_bool=lambda model: model.z > 3)

# p='p value' q=NestedModel(a='a value', b=SimpleModel(x=1, y=3))
model_instance_3 = init_model_from_kwargs(
ComplexModel, p="p value", q=init_model_from_kwargs(NestedModel, a="a value", x=1)
)
z: int

class NestedModel(BaseModel):
y: int
deeply_nested: DeeplyNestedModel | str = "default"

class Model(BaseModel):
x: int
nested: NestedModel

constructor = FlatInitModel(model=Model)

instance: Model = constructor(x=1, y=2, z=3)
instance.model_dump() # {'x': 1, 'nested': {'y': 2, 'deeply_nested': 'default'}}
```

If the existential condition of the model is met, the model instance gets assigned:

```python
instance: Model = constructor(x=1, y=2, z=4)
instance.model_dump() # {'x': 1, 'nested': {'y': 2, 'deeply_nested': {'z': 4}}}
```
3 changes: 2 additions & 1 deletion lupl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from lupl.compose_router import ComposeRouter
from lupl.ichunk import ichunk
from lupl.pydantic_tools.curry_model import CurryModel, validate_model_field
from lupl.pydantic_tools.model_constructors import init_model_from_kwargs
from lupl.pydantic_tools.flat_init.flat_init_model import FlatInitModel
from lupl.pydantic_tools.flat_init.utils import ConfigDict
from lupl.pydantic_tools.mutual_constraint_validator import _MutualConstraintMixin
66 changes: 66 additions & 0 deletions lupl/pydantic_tools/flat_init/flat_init_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import types
from typing import Any, TypeGuard, get_args, get_origin
import typing

from lupl import CurryModel
from lupl.pydantic_tools.flat_init.utils import (
ModelBoolPredicate,
_TModelBoolValue,
_is_pydantic_model_static_type,
_is_pydantic_model_union_static_type,
get_model_bool_predicate,
)
from pydantic import BaseModel


class FlatInitModel[_TModel: BaseModel]:
"""Model constructor for initializing a potentially nested Pydantic model from flat kwargs.

Nested model fields of a given model are recursively resolved;
for model union fields, the first model type of the union is processed.
"""

def __init__(self, model: type[_TModel], fail_fast: bool = True):
self.model = model
self.fail_fast = fail_fast

self._curried_model = CurryModel(model=model, fail_fast=fail_fast)
self._model_bool_value: _TModelBoolValue | None = self.model.model_config.get(
"model_bool", None
)

def __call__(self, **kwargs) -> _TModel:
"""Run a FlatInitModel constructor to instantiate a Pydantic model from flat kwargs."""
for field_name, field_info in self.model.model_fields.items():
if _is_pydantic_model_static_type(field_info.annotation):
nested_model = field_info.annotation
field_value = FlatInitModel(
model=nested_model, fail_fast=self.fail_fast
)(**kwargs)

elif _is_pydantic_model_union_static_type(
model_union := field_info.annotation
):
nested_model_type: type[BaseModel] = next(
filter(_is_pydantic_model_static_type, get_args(model_union))
)
nested_model_instance = FlatInitModel(
model=nested_model_type, fail_fast=self.fail_fast
)(**kwargs)
model_bool_predicate: ModelBoolPredicate = get_model_bool_predicate(
model=nested_model_type
)
field_value = (
nested_model_instance
if model_bool_predicate(nested_model_instance)
else field_info.default
)
else:
field_value = kwargs.get(field_name, field_info.default)

self._curried_model(**{field_name: field_value})

model_instance = self._curried_model()

assert isinstance(model_instance, self.model)
return model_instance
90 changes: 90 additions & 0 deletions lupl/pydantic_tools/flat_init/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Utils for the lupl.FlatInitModel constructor."""

from types import UnionType
from typing import (
Any,
Protocol,
TypeAlias,
TypeGuard,
Union,
cast,
get_args,
get_origin,
runtime_checkable,
)

from pydantic import BaseModel, ConfigDict as PydanticConfigDict


@runtime_checkable
class ModelBoolPredicate[_TModel: BaseModel](Protocol):
"""Type for model_bool predicate functions."""

def __call__(self, model: _TModel) -> bool: ...


_TModelBoolValue: TypeAlias = ModelBoolPredicate | str | set[str]


class ConfigDict(PydanticConfigDict, total=False):
model_bool: _TModelBoolValue


def default_model_bool_predicate(model: BaseModel) -> bool:
"""Default predicate for determining model truthiness.

Adheres to ModelBoolPredicate.
"""
return any(dict(model).values())


def _get_model_bool_predicate_from_config_value(
model_bool_value: _TModelBoolValue,
) -> ModelBoolPredicate:
"""Get a model_bool predicate function given the value of the model_bool config setting."""
match model_bool_value:
case ModelBoolPredicate():
return model_bool_value
case str():
return lambda model: bool(dict(model)[model_bool_value])
case set():
return lambda model: all(map(lambda k: dict(model)[k], model_bool_value))
case _:
msg = (
f"Expected type {_TModelBoolValue} for model_bool config setting. "
f"Got '{model_bool_value}'."
)
raise ValueError(msg)


def get_model_bool_predicate(model: type[BaseModel] | BaseModel) -> ModelBoolPredicate:
"""Get the applicable model_bool predicate function given a model."""
_missing = object()
if (model_bool_value := model.model_config.get("model_bool", _missing)) is _missing:
model_bool_predicate = default_model_bool_predicate
else:
model_bool_predicate = _get_model_bool_predicate_from_config_value(
# cast and see what happens at runtime...
cast(_TModelBoolValue, model_bool_value)
)

return model_bool_predicate


def _is_pydantic_model_static_type(obj: Any) -> TypeGuard[type[BaseModel]]:
"""Check if object is a Pydantic model type."""
return (
isinstance(obj, type) and issubclass(obj, BaseModel) and (obj is not BaseModel)
)


def _is_pydantic_model_union_static_type(
obj: Any,
) -> TypeGuard[UnionType]:
"""Check if object is a union type of a Pydantic model."""
is_union_type: bool = get_origin(obj) in (UnionType, Union)
has_any_model: bool = any(
_is_pydantic_model_static_type(obj) for obj in get_args(obj)
)

return is_union_type and has_any_model
57 changes: 0 additions & 57 deletions lupl/pydantic_tools/model_constructors.py

This file was deleted.

16 changes: 0 additions & 16 deletions tests/tests_pydantic_tools/test_init_model_from_kwargs.py

This file was deleted.

Loading