From b57554fe4d6041bdc79239c673b11e71c5c72a44 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Mon, 17 Mar 2025 14:03:28 -0400 Subject: [PATCH 1/3] Initial with_reasoning implementation --- alphaswarm/core/llm/__init__.py | 2 ++ alphaswarm/core/llm/utils.py | 23 +++++++++++++++++++++++ tests/unit/core/llm/test_utils.py | 19 +++++++++++++++++++ 3 files changed, 44 insertions(+) create mode 100644 alphaswarm/core/llm/utils.py create mode 100644 tests/unit/core/llm/test_utils.py diff --git a/alphaswarm/core/llm/__init__.py b/alphaswarm/core/llm/__init__.py index cd04423b..3d819e60 100644 --- a/alphaswarm/core/llm/__init__.py +++ b/alphaswarm/core/llm/__init__.py @@ -7,6 +7,7 @@ PythonLLMFunction, ) from .message import CacheControl, ContentBlock, ImageContentBlock, ImageURL, Message, TextContentBlock +from .utils import with_reasoning __all__ = [ "LLMFunction", @@ -21,4 +22,5 @@ "Message", "TextContentBlock", "PythonLLMFunction", + "with_reasoning", ] diff --git a/alphaswarm/core/llm/utils.py b/alphaswarm/core/llm/utils.py new file mode 100644 index 00000000..c213e421 --- /dev/null +++ b/alphaswarm/core/llm/utils.py @@ -0,0 +1,23 @@ +from typing import Optional, Type, TypeVar + +from pydantic import BaseModel, Field, create_model + +T = TypeVar("T", bound=BaseModel) + + +def with_reasoning(cls: Type[T], description: Optional[str] = None) -> Type[T]: + """ + Decorator that adds a 'reasoning' field to a Pydantic model to support Chain-of-Thought pattern. + The reasoning field will be placed first in the schema. + """ + + original_fields = cls.__annotations__.copy() + + description = description or "Your reasoning to arrive at the answer" + new_fields = {"reasoning": (str, Field(description=description))} + + for field_name, field_type in original_fields.items(): + new_fields[field_name] = (field_type, cls.model_fields[field_name] if field_name in cls.model_fields else None) + + return create_model(cls.__name__, __doc__=cls.__doc__, **new_fields) + # return cast(Type[T], new_cls) diff --git a/tests/unit/core/llm/test_utils.py b/tests/unit/core/llm/test_utils.py new file mode 100644 index 00000000..f611a64e --- /dev/null +++ b/tests/unit/core/llm/test_utils.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel, Field + +from alphaswarm.core.llm import with_reasoning + + +def test_with_reasoning() -> None: + @with_reasoning + class Schema(BaseModel): + a: float = Field(..., description="Parameter a") + b: float = Field(3.14, description="Parameter b") + + class ExpectedSchema(BaseModel): + reasoning: str = Field(..., description="Your reasoning to arrive at the answer") + a: float = Field(..., description="Parameter a") + b: float = Field(3.14, description="Parameter b") + + print(Schema.model_json_schema()) + print(ExpectedSchema.model_json_schema()) + assert Schema.model_json_schema()["properties"] == ExpectedSchema.model_json_schema()["properties"] From babb6abc5932671fcc51f9faab20b52edd24d628 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Mon, 17 Mar 2025 15:59:46 -0400 Subject: [PATCH 2/3] Fix implementation + more tests --- alphaswarm/core/llm/utils.py | 22 ++++++----- tests/unit/core/llm/test_utils.py | 63 +++++++++++++++++++++++++++---- 2 files changed, 69 insertions(+), 16 deletions(-) diff --git a/alphaswarm/core/llm/utils.py b/alphaswarm/core/llm/utils.py index c213e421..810159ee 100644 --- a/alphaswarm/core/llm/utils.py +++ b/alphaswarm/core/llm/utils.py @@ -1,23 +1,27 @@ -from typing import Optional, Type, TypeVar +from typing import Callable, Optional, Type, TypeVar from pydantic import BaseModel, Field, create_model T = TypeVar("T", bound=BaseModel) -def with_reasoning(cls: Type[T], description: Optional[str] = None) -> Type[T]: +def with_reasoning(description: Optional[str] = None) -> Callable[[Type[T]], Type[T]]: """ Decorator that adds a 'reasoning' field to a Pydantic model to support Chain-of-Thought pattern. The reasoning field will be placed first in the schema. """ - original_fields = cls.__annotations__.copy() + def decorator(cls: Type[T]) -> Type[T]: + original_fields = cls.__annotations__.copy() - description = description or "Your reasoning to arrive at the answer" - new_fields = {"reasoning": (str, Field(description=description))} + new_fields = {"reasoning": (str, Field(description=description or "Your reasoning to arrive at the answer"))} - for field_name, field_type in original_fields.items(): - new_fields[field_name] = (field_type, cls.model_fields[field_name] if field_name in cls.model_fields else None) + for field_name, field_type in original_fields.items(): + new_fields[field_name] = ( + field_type, + cls.model_fields[field_name] if field_name in cls.model_fields else None, + ) - return create_model(cls.__name__, __doc__=cls.__doc__, **new_fields) - # return cast(Type[T], new_cls) + return create_model(cls.__name__, __doc__=cls.__doc__, **new_fields) # type: ignore + + return decorator diff --git a/tests/unit/core/llm/test_utils.py b/tests/unit/core/llm/test_utils.py index f611a64e..c809ed20 100644 --- a/tests/unit/core/llm/test_utils.py +++ b/tests/unit/core/llm/test_utils.py @@ -1,19 +1,68 @@ +from typing import Any, Dict, Literal, List, Type + from pydantic import BaseModel, Field from alphaswarm.core.llm import with_reasoning -def test_with_reasoning() -> None: - @with_reasoning - class Schema(BaseModel): +def schema_properties(schema: Type[BaseModel]) -> Dict[str, Any]: + return schema.model_json_schema()["properties"] + + +def schema_properties_keys(schema: Type[BaseModel]) -> List[str]: + return list(schema.model_json_schema()["properties"].keys()) + + +def test_with_reasoning_default_description() -> None: + @with_reasoning() + class Model(BaseModel): a: float = Field(..., description="Parameter a") b: float = Field(3.14, description="Parameter b") - class ExpectedSchema(BaseModel): + class ExpectedModel(BaseModel): reasoning: str = Field(..., description="Your reasoning to arrive at the answer") a: float = Field(..., description="Parameter a") b: float = Field(3.14, description="Parameter b") - print(Schema.model_json_schema()) - print(ExpectedSchema.model_json_schema()) - assert Schema.model_json_schema()["properties"] == ExpectedSchema.model_json_schema()["properties"] + assert schema_properties(Model) == schema_properties(ExpectedModel) + + +def test_with_reasoning_custom_description() -> None: + @with_reasoning(description="Custom reasoning description") + class Model(BaseModel): + other: Literal["a", "b", "c"] = Field(..., description="Other field") + + class ExpectedModel(BaseModel): + reasoning: str = Field(..., description="Custom reasoning description") + other: Literal["a", "b", "c"] = Field(..., description="Other field") + + assert schema_properties(Model) == schema_properties(ExpectedModel) + + +def test_with_reasoning_field_order() -> None: + @with_reasoning() + class Model(BaseModel): + a: float = Field(..., description="First field") + b: str = Field(..., description="Second field") + c: int = Field(..., description="Third field") + + properties = schema_properties_keys(Model) + + assert properties[0] == "reasoning" + assert properties[1:] == ["a", "b", "c"] + + +def test_with_reasoning_nested_fields() -> None: + @with_reasoning() + class NestedModel(BaseModel): + x: int = Field(..., description="Nested field") + + @with_reasoning() + class Model(BaseModel): + nested: NestedModel + other: str = Field(..., description="Other field") + + properties = schema_properties_keys(Model) + nested_properties = schema_properties_keys(NestedModel) + assert properties == ["reasoning", "nested", "other"] + assert nested_properties == ["reasoning", "x"] From 45987b392f170bf98cf132d53b9fa05ebb945292 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Mon, 17 Mar 2025 16:06:50 -0400 Subject: [PATCH 3/3] Integration test --- tests/integration/core/llm/test_llm_function.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/integration/core/llm/test_llm_function.py b/tests/integration/core/llm/test_llm_function.py index 3bf760ca..7f4e273c 100644 --- a/tests/integration/core/llm/test_llm_function.py +++ b/tests/integration/core/llm/test_llm_function.py @@ -5,7 +5,7 @@ import requests from litellm.types.utils import Usage -from alphaswarm.core.llm import ImageURL, LLMFunction, Message +from alphaswarm.core.llm import ImageURL, LLMFunction, Message, with_reasoning from pydantic import BaseModel, Field from tests import get_data_filename @@ -119,3 +119,16 @@ def test_llm_function_with_image() -> None: assert isinstance(result, TestResponse) assert "eth" in result.content.lower() assert "sol" in result.content.lower() + + +def test_llm_function_with_reasoning() -> None: + @with_reasoning(description="Reasoning behind the response") + class ReasoningResponse(BaseModel): + content: str = Field(..., description="The content of the response") + + llm_func = get_llm_function(response_model=ReasoningResponse, system_message="What's the capital of Great Britain?") + + result = llm_func.execute() + assert isinstance(result, ReasoningResponse) + assert isinstance(result.reasoning, str) # type: ignore + assert "London" in result.content