Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions alphaswarm/core/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
PythonLLMFunction,
)
from .message import CacheControl, ContentBlock, ImageContentBlock, ImageURL, Message, TextContentBlock
from .utils import with_reasoning

__all__ = [
"LLMFunction",
Expand All @@ -21,4 +22,5 @@
"Message",
"TextContentBlock",
"PythonLLMFunction",
"with_reasoning",
]
27 changes: 27 additions & 0 deletions alphaswarm/core/llm/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Callable, Optional, Type, TypeVar

from pydantic import BaseModel, Field, create_model

T = TypeVar("T", bound=BaseModel)


def with_reasoning(description: Optional[str] = None) -> Callable[[Type[T]], Type[T]]:
"""
Decorator that adds a 'reasoning' field to a Pydantic model to support Chain-of-Thought pattern.
The reasoning field will be placed first in the schema.
"""

def decorator(cls: Type[T]) -> Type[T]:
original_fields = cls.__annotations__.copy()

new_fields = {"reasoning": (str, Field(description=description or "Your reasoning to arrive at the answer"))}

for field_name, field_type in original_fields.items():
new_fields[field_name] = (
field_type,
cls.model_fields[field_name] if field_name in cls.model_fields else None,
)

return create_model(cls.__name__, __doc__=cls.__doc__, **new_fields) # type: ignore

return decorator
15 changes: 14 additions & 1 deletion tests/integration/core/llm/test_llm_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import requests
from litellm.types.utils import Usage

from alphaswarm.core.llm import ImageURL, LLMFunction, Message
from alphaswarm.core.llm import ImageURL, LLMFunction, Message, with_reasoning
from pydantic import BaseModel, Field
from tests import get_data_filename

Expand Down Expand Up @@ -119,3 +119,16 @@ def test_llm_function_with_image() -> None:
assert isinstance(result, TestResponse)
assert "eth" in result.content.lower()
assert "sol" in result.content.lower()


def test_llm_function_with_reasoning() -> None:
@with_reasoning(description="Reasoning behind the response")
class ReasoningResponse(BaseModel):
content: str = Field(..., description="The content of the response")

llm_func = get_llm_function(response_model=ReasoningResponse, system_message="What's the capital of Great Britain?")

result = llm_func.execute()
assert isinstance(result, ReasoningResponse)
assert isinstance(result.reasoning, str) # type: ignore
assert "London" in result.content
68 changes: 68 additions & 0 deletions tests/unit/core/llm/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Any, Dict, Literal, List, Type

from pydantic import BaseModel, Field

from alphaswarm.core.llm import with_reasoning


def schema_properties(schema: Type[BaseModel]) -> Dict[str, Any]:
return schema.model_json_schema()["properties"]


def schema_properties_keys(schema: Type[BaseModel]) -> List[str]:
return list(schema.model_json_schema()["properties"].keys())


def test_with_reasoning_default_description() -> None:
@with_reasoning()
class Model(BaseModel):
a: float = Field(..., description="Parameter a")
b: float = Field(3.14, description="Parameter b")

class ExpectedModel(BaseModel):
reasoning: str = Field(..., description="Your reasoning to arrive at the answer")
a: float = Field(..., description="Parameter a")
b: float = Field(3.14, description="Parameter b")

assert schema_properties(Model) == schema_properties(ExpectedModel)


def test_with_reasoning_custom_description() -> None:
@with_reasoning(description="Custom reasoning description")
class Model(BaseModel):
other: Literal["a", "b", "c"] = Field(..., description="Other field")

class ExpectedModel(BaseModel):
reasoning: str = Field(..., description="Custom reasoning description")
other: Literal["a", "b", "c"] = Field(..., description="Other field")

assert schema_properties(Model) == schema_properties(ExpectedModel)


def test_with_reasoning_field_order() -> None:
@with_reasoning()
class Model(BaseModel):
a: float = Field(..., description="First field")
b: str = Field(..., description="Second field")
c: int = Field(..., description="Third field")

properties = schema_properties_keys(Model)

assert properties[0] == "reasoning"
assert properties[1:] == ["a", "b", "c"]


def test_with_reasoning_nested_fields() -> None:
@with_reasoning()
class NestedModel(BaseModel):
x: int = Field(..., description="Nested field")

@with_reasoning()
class Model(BaseModel):
nested: NestedModel
other: str = Field(..., description="Other field")

properties = schema_properties_keys(Model)
nested_properties = schema_properties_keys(NestedModel)
assert properties == ["reasoning", "nested", "other"]
assert nested_properties == ["reasoning", "x"]