From 8c7101cbc9387821a4b8a6e2351e9c11a18cd6df Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Wed, 26 Feb 2025 15:07:44 -0700 Subject: [PATCH 01/13] Base implementation of PromptConfig --- alphaswarm/core/prompt/__init__.py | 3 + alphaswarm/core/prompt/prompt.py | 49 ++++++++++++ tests/unit/core/prompt/__init__.py | 0 tests/unit/core/prompt/test_prompt.py | 105 ++++++++++++++++++++++++++ 4 files changed, 157 insertions(+) create mode 100644 alphaswarm/core/prompt/__init__.py create mode 100644 alphaswarm/core/prompt/prompt.py create mode 100644 tests/unit/core/prompt/__init__.py create mode 100644 tests/unit/core/prompt/test_prompt.py diff --git a/alphaswarm/core/prompt/__init__.py b/alphaswarm/core/prompt/__init__.py new file mode 100644 index 00000000..6434ee7a --- /dev/null +++ b/alphaswarm/core/prompt/__init__.py @@ -0,0 +1,3 @@ +from .prompt import PromptConfig + +__all__ = ["PromptConfig"] diff --git a/alphaswarm/core/prompt/prompt.py b/alphaswarm/core/prompt/prompt.py new file mode 100644 index 00000000..cd1feeba --- /dev/null +++ b/alphaswarm/core/prompt/prompt.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import abc +from typing import Any, Dict, Optional + +import yaml +from pydantic import BaseModel, field_validator + + +class PromptTemplateBase(BaseModel, abc.ABC): + @abc.abstractmethod + def get_template(self) -> str: + pass + + +class PromptTemplate(PromptTemplateBase): + template: str + + def get_template(self) -> str: + return self.template + + +class PromptPair(BaseModel): + system: PromptTemplate # TODO: should be base class + user: Optional[PromptTemplate] = None + + +class LLMConfig(BaseModel): + model: str + params: Optional[Dict[str, Any]] = None + + +class PromptConfig(BaseModel): + kind: str + prompt: PromptPair + metadata: Optional[Dict[str, Any]] = None + llm: Optional[LLMConfig] = None + + @field_validator("kind") + @classmethod + def validate_kind(cls, kind: str) -> str: + if kind not in ["Prompt"]: + raise ValueError(f"Invalid kind: {kind}") + return kind + + @classmethod + def from_yaml(cls, yaml_str: str) -> PromptConfig: + data = yaml.safe_load(yaml_str) + return cls(**data) diff --git a/tests/unit/core/prompt/__init__.py b/tests/unit/core/prompt/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/core/prompt/test_prompt.py b/tests/unit/core/prompt/test_prompt.py new file mode 100644 index 00000000..f5f90886 --- /dev/null +++ b/tests/unit/core/prompt/test_prompt.py @@ -0,0 +1,105 @@ +import pytest +from alphaswarm.core.prompt import PromptConfig +from alphaswarm.core.prompt.prompt import ( + PromptTemplate, + PromptPair, + LLMConfig, +) + + +class TestPromptTemplate: + def test_prompt_template(self) -> None: + template = "This is a test template with {variable}" + prompt = PromptTemplate(template=template) + assert prompt.template == template + + +class TestPromptPair: + def test_prompt_pair_with_system_only(self) -> None: + system_prompt = PromptTemplate(template="You are a helpful assistant.") + pair = PromptPair(system=system_prompt) + assert pair.system == system_prompt + assert pair.user is None + + def test_prompt_pair_with_system_and_user(self) -> None: + system_prompt = PromptTemplate(template="You are a helpful assistant.") + user_prompt = PromptTemplate(template="Help me with {task}.") + pair = PromptPair(system=system_prompt, user=user_prompt) + assert pair.system == system_prompt + assert pair.user == user_prompt + + +class TestLLMConfig: + def test_llm_config_with_model_only(self) -> None: + config = LLMConfig(model="gpt-4o") + assert config.model == "gpt-4o" + assert config.params is None + + def test_llm_config_with_params(self) -> None: + params = {"temperature": 0.7, "max_tokens": 100, "another_param": "value"} + config = LLMConfig(model="gpt-4o", params=params) + assert config.model == "gpt-4o" + assert config.params == params + + +class TestPromptConfig: + def test_prompt_config_initialization(self) -> None: + prompt_pair = PromptPair( + system=PromptTemplate(template="You are a helpful assistant."), + user=PromptTemplate(template="Help me with {task}."), + ) + metadata = {"version": "1.0", "author": "John Doe", "created_at": "2025-02-25"} + + config = PromptConfig(kind="Prompt", prompt=prompt_pair, metadata=metadata, llm=LLMConfig(model="gpt-4o")) + + assert config.kind == "Prompt" + assert config.prompt == prompt_pair + assert config.metadata == metadata + assert config.llm is not None + assert config.llm.model == "gpt-4o" + + def test_with_empty_metadata_and_llm(self) -> None: + prompt_pair = PromptPair( + system=PromptTemplate(template="You are a helpful assistant."), + user=PromptTemplate(template="Help me with {task}."), + ) + config = PromptConfig(kind="Prompt", prompt=prompt_pair) + + assert config.metadata is None + assert config.llm is None + + def test_prompt_config_invalid_kind(self) -> None: + system_prompt = PromptTemplate(template="You are a helpful assistant.") + prompt_pair = PromptPair(system=system_prompt) + + with pytest.raises(ValueError, match="Invalid kind: InvalidKind"): + PromptConfig(kind="InvalidKind", prompt=prompt_pair) + + def test_from_yaml(self) -> None: + yaml_str = """ + kind: Prompt + prompt: + system: + template: You are a helpful assistant. + user: + template: Help me with this task. + metadata: + version: "0.0.1" + llm: + model: gpt-4o + params: + temperature: 0.7 + """ + + config = PromptConfig.from_yaml(yaml_str) + + assert config.kind == "Prompt" + assert isinstance(config.prompt, PromptPair) + assert isinstance(config.prompt.system, PromptTemplate) + assert config.prompt.system.template == "You are a helpful assistant." + assert isinstance(config.prompt.user, PromptTemplate) + assert config.prompt.user.template == "Help me with this task." + assert config.metadata == {"version": "0.0.1"} + assert config.llm is not None + assert config.llm.model == "gpt-4o" + assert config.llm.params == {"temperature": 0.7} From 98385e58c1279e93d5bc99de5a7e668c7bdec076 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Wed, 26 Feb 2025 16:05:38 -0700 Subject: [PATCH 02/13] Add example files for test --- alphaswarm/core/prompt/prompt.py | 10 ++++-- tests/__init__.py | 11 +++++- tests/data/prompts/prompt.yaml | 15 ++++++++ tests/data/prompts/structured_prompt.yaml | 24 +++++++++++++ tests/unit/core/prompt/test_prompt.py | 44 ++++++++++++++--------- 5 files changed, 84 insertions(+), 20 deletions(-) create mode 100644 tests/data/prompts/prompt.yaml create mode 100644 tests/data/prompts/structured_prompt.yaml diff --git a/alphaswarm/core/prompt/prompt.py b/alphaswarm/core/prompt/prompt.py index cd1feeba..d9a02391 100644 --- a/alphaswarm/core/prompt/prompt.py +++ b/alphaswarm/core/prompt/prompt.py @@ -16,6 +16,11 @@ def get_template(self) -> str: class PromptTemplate(PromptTemplateBase): template: str + @field_validator("template") + @classmethod + def strip_template(cls, template: str) -> str: + return template.strip() + def get_template(self) -> str: return self.template @@ -44,6 +49,7 @@ def validate_kind(cls, kind: str) -> str: return kind @classmethod - def from_yaml(cls, yaml_str: str) -> PromptConfig: - data = yaml.safe_load(yaml_str) + def from_yaml(cls, path: str) -> PromptConfig: + with open(path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) return cls(**data) diff --git a/tests/__init__.py b/tests/__init__.py index 4eb9068d..bde59cb5 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,14 @@ import os +from typing import Final +from enum import Enum + +DATA_PATH: Final[str] = os.path.join(os.path.dirname(__file__), "data") + + +class PromptPath(str, Enum): + basic = os.path.join(DATA_PATH, "prompts", "prompt.yaml") + structured = os.path.join(DATA_PATH, "prompts", "structured_prompt.yaml") def get_data_filename(filename: str) -> str: - return os.path.join(os.path.dirname(__file__), "data", filename) + return os.path.join(DATA_PATH, filename) diff --git a/tests/data/prompts/prompt.yaml b/tests/data/prompts/prompt.yaml new file mode 100644 index 00000000..eabd3803 --- /dev/null +++ b/tests/data/prompts/prompt.yaml @@ -0,0 +1,15 @@ +kind: Prompt +metadata: + description: | + This is a prompt doing abc +llm: + model: gpt-4o-mini + params: + temperature: 0.3 +prompt: + system: + template: | + You are a helpful assistant. + user: + template: | + Answer the following questions: {question} diff --git a/tests/data/prompts/structured_prompt.yaml b/tests/data/prompts/structured_prompt.yaml new file mode 100644 index 00000000..9916b0a0 --- /dev/null +++ b/tests/data/prompts/structured_prompt.yaml @@ -0,0 +1,24 @@ +kind: StructuredPrompt +metadata: + description: | + This is a prompt doing xyz +llm: + model: claude-3-5-haiku-20241022 + params: + temperature: 0.2 +prompt: + system: + sections: + - name: Instructions + content: | + You are a helpful assistant. + sections: + - name: Hints + content: | + Answer the question in a concise manner. + user: + sections: + - name: Question + content: | + What's the capital of France? + formatter: XML diff --git a/tests/unit/core/prompt/test_prompt.py b/tests/unit/core/prompt/test_prompt.py index f5f90886..a96f4d99 100644 --- a/tests/unit/core/prompt/test_prompt.py +++ b/tests/unit/core/prompt/test_prompt.py @@ -5,6 +5,7 @@ PromptPair, LLMConfig, ) +from tests import PromptPath class TestPromptTemplate: @@ -75,23 +76,18 @@ def test_prompt_config_invalid_kind(self) -> None: with pytest.raises(ValueError, match="Invalid kind: InvalidKind"): PromptConfig(kind="InvalidKind", prompt=prompt_pair) - def test_from_yaml(self) -> None: - yaml_str = """ - kind: Prompt - prompt: - system: - template: You are a helpful assistant. - user: - template: Help me with this task. - metadata: - version: "0.0.1" - llm: - model: gpt-4o - params: - temperature: 0.7 - """ - - config = PromptConfig.from_yaml(yaml_str) + def test_from_dict(self) -> None: + data = { + "kind": "Prompt", + "prompt": { + "system": {"template": "You are a helpful assistant."}, + "user": {"template": "Help me with this task."}, + }, + "metadata": {"version": "0.0.1"}, + "llm": {"model": "gpt-4o", "params": {"temperature": 0.7}}, + } + + config = PromptConfig(**data) # type: ignore assert config.kind == "Prompt" assert isinstance(config.prompt, PromptPair) @@ -103,3 +99,17 @@ def test_from_yaml(self) -> None: assert config.llm is not None assert config.llm.model == "gpt-4o" assert config.llm.params == {"temperature": 0.7} + + def test_from_file(self) -> None: + config = PromptConfig.from_yaml(PromptPath.basic) + + assert config.kind == "Prompt" + assert isinstance(config.prompt, PromptPair) + assert isinstance(config.prompt.system, PromptTemplate) + assert config.prompt.system.template == "You are a helpful assistant." + assert isinstance(config.prompt.user, PromptTemplate) + assert config.prompt.user.template == "Answer the following questions: {question}" + assert config.metadata == {"description": "This is a prompt doing abc\n"} + assert config.llm is not None + assert config.llm.model == "gpt-4o-mini" + assert config.llm.params == {"temperature": 0.3} From da543fa26e8d4603d097c2995f4859dbca902cd4 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Wed, 26 Feb 2025 16:18:04 -0700 Subject: [PATCH 03/13] Enhance LLMFunctionTemplated --- alphaswarm/core/llm/llm_function.py | 59 ++++++++++++++++++- .../core/llm/test_llm_function_from_files.py | 38 ++++++++---- 2 files changed, 86 insertions(+), 11 deletions(-) diff --git a/alphaswarm/core/llm/llm_function.py b/alphaswarm/core/llm/llm_function.py index 15452fa2..ae849052 100644 --- a/alphaswarm/core/llm/llm_function.py +++ b/alphaswarm/core/llm/llm_function.py @@ -9,6 +9,7 @@ from litellm.types.utils import ModelResponse from pydantic import BaseModel +from ..prompt import PromptConfig from .message import Message litellm.modify_params = True # for calls with system message only for anthropic @@ -223,7 +224,7 @@ def from_files( system_prompt_params: Optional[Dict[str, Any]] = None, max_retries: int = 3, ) -> LLMFunctionTemplated[T_Response]: - """Create an instance from template files. + """Create an instance from template text files. Args: model_id: LiteLLM model ID to use @@ -250,6 +251,62 @@ def from_files( max_retries=max_retries, ) + @classmethod + def from_prompt_config( + cls, + response_model: Type[T_Response], + prompt_config: PromptConfig, + system_prompt_params: Optional[Dict[str, Any]] = None, + max_retries: int = 3, + ) -> LLMFunctionTemplated[T_Response]: + """Create an instance from prompt config object. + + Args: + response_model: Pydantic model class for structuring responses + prompt_config: PromptConfig object + system_prompt_params: Parameters for formatting the system prompt + max_retries: Maximum number of retry attempts + """ + system_prompt_template = prompt_config.prompt.system.template + user_prompt_template = prompt_config.prompt.user.template if prompt_config.prompt.user else None + + if prompt_config.llm is None: + raise ValueError("LLMConfig not set in PromptConfig") + model_id = prompt_config.llm.model + # TODO: pass kwargs in the __init__ + + return cls( + model_id=model_id, + response_model=response_model, + system_prompt_template=system_prompt_template, + user_prompt_template=user_prompt_template, + system_prompt_params=system_prompt_params, + max_retries=max_retries, + ) + + @classmethod + def from_prompt_config_file( + cls, + response_model: Type[T_Response], + prompt_config_path: str, + system_prompt_params: Optional[Dict[str, Any]] = None, + max_retries: int = 3, + ) -> LLMFunctionTemplated[T_Response]: + """Create an instance from prompt config file. + + Args: + response_model: Pydantic model class for structuring responses + prompt_config_path: Path to the prompt config yaml file + system_prompt_params: Parameters for formatting the system prompt + max_retries: Maximum number of retry attempts + """ + return cls.from_prompt_config( + response_model=response_model, + prompt_config=PromptConfig.from_yaml(prompt_config_path), + system_prompt_params=system_prompt_params, + max_retries=max_retries, + ) + @staticmethod def _format(template: str, params: Optional[Dict[str, Any]] = None) -> str: """Format the template string with the given optional parameters.""" diff --git a/tests/unit/core/llm/test_llm_function_from_files.py b/tests/unit/core/llm/test_llm_function_from_files.py index 856ada83..8f7723e9 100644 --- a/tests/unit/core/llm/test_llm_function_from_files.py +++ b/tests/unit/core/llm/test_llm_function_from_files.py @@ -1,30 +1,48 @@ import tempfile -from typing import Any import pytest from pydantic import BaseModel from alphaswarm.core.llm import LLMFunctionTemplated +from alphaswarm.core.prompt import PromptConfig + +from tests import PromptPath class Response(BaseModel): test: str -def get_sample_llm_function_from_files(**kwargs: Any) -> LLMFunctionTemplated[Response]: - return LLMFunctionTemplated.from_files( - model_id="test", - response_model=Response, - **kwargs, - ) - - def test_execute_with_user_prompt_params_but_no_template_raises() -> None: with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", suffix=".txt", delete=True) as system_file: system_file.write("Sample system prompt") system_file.flush() - llm_func = get_sample_llm_function_from_files(system_prompt_path=system_file.name) + llm_func = LLMFunctionTemplated.from_files( + model_id="test", + response_model=Response, + system_prompt_path=system_file.name, + ) with pytest.raises(ValueError, match="User prompt params provided but no user prompt template exists"): llm_func.execute(user_prompt_params={"test": "value"}) + + +def test_from_prompt_config() -> None: + llm_func_v1 = LLMFunctionTemplated.from_prompt_config( + response_model=Response, + prompt_config=PromptConfig.from_yaml(PromptPath.basic), + ) + + llm_func_v2 = LLMFunctionTemplated.from_prompt_config_file( + response_model=Response, + prompt_config_path=PromptPath.basic, + ) + + assert llm_func_v1._model_id == llm_func_v2._model_id == "gpt-4o-mini" + assert llm_func_v1.system_prompt == llm_func_v2.system_prompt == "You are a helpful assistant." + assert ( + llm_func_v1.user_prompt_template + == llm_func_v2.user_prompt_template + == "Answer the following questions: {question}" + ) From ded2b1db0d4f8476850ab153b02a1a529bc460f8 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Wed, 26 Feb 2025 16:46:38 -0700 Subject: [PATCH 04/13] WIP structured prompt implementation --- alphaswarm/core/llm/llm_function.py | 4 +- alphaswarm/core/prompt/prompt.py | 119 ++++++++++++++++-- .../core/llm/test_llm_function_from_files.py | 25 ++++ 3 files changed, 134 insertions(+), 14 deletions(-) diff --git a/alphaswarm/core/llm/llm_function.py b/alphaswarm/core/llm/llm_function.py index ae849052..e50b0775 100644 --- a/alphaswarm/core/llm/llm_function.py +++ b/alphaswarm/core/llm/llm_function.py @@ -267,8 +267,8 @@ def from_prompt_config( system_prompt_params: Parameters for formatting the system prompt max_retries: Maximum number of retry attempts """ - system_prompt_template = prompt_config.prompt.system.template - user_prompt_template = prompt_config.prompt.user.template if prompt_config.prompt.user else None + system_prompt_template = prompt_config.prompt.system.get_template() + user_prompt_template = prompt_config.prompt.user.get_template() if prompt_config.prompt.user else None if prompt_config.llm is None: raise ValueError("LLMConfig not set in PromptConfig") diff --git a/alphaswarm/core/prompt/prompt.py b/alphaswarm/core/prompt/prompt.py index d9a02391..03eadcfc 100644 --- a/alphaswarm/core/prompt/prompt.py +++ b/alphaswarm/core/prompt/prompt.py @@ -1,10 +1,73 @@ from __future__ import annotations import abc -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Mapping, Optional, Sequence, Type, Union import yaml -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, model_validator + + +class PromptSection(BaseModel): + name: str + content: Optional[str] = None + sections: List[PromptSection] = [] + + @field_validator("content") + @classmethod + def strip_content(cls, content: Optional[str]) -> Optional[str]: + if isinstance(content, str): + return content.strip() + return content + + +class PromptFormatter(abc.ABC): + def format(self, sections: Sequence[PromptSection]) -> str: + return "\n".join(self._format_section(section) for section in sections) + + @abc.abstractmethod + def _format_section(self, section: PromptSection) -> str: + pass + + +class StringPromptFormatter(PromptFormatter): + def __init__(self, section_prefix: str = ""): + self.section_prefix = section_prefix + + def _format_section(self, section: PromptSection) -> str: + parts = [f"{self.section_prefix}{section.name}"] + if section.content: + parts.append(section.content) + parts.extend([self._format_section(sec) for sec in section.sections]) + return "\n".join(parts) + + +class MarkdownPromptFormatter(PromptFormatter): + def _format_section(self, section: PromptSection, indent: int = 1) -> str: + parts = [f"{'#' * indent} {section.name}", ""] + if section.content: + parts.extend([section.content, ""]) + parts.extend([self._format_section(sec, indent + 1) for sec in section.sections]) + return "\n".join(parts) + + +class XMLPromptFormatter(PromptFormatter): + INDENT_DIFF: str = " " + + def to_snake_case(self, string: str) -> str: + return string.lower().replace(" ", "_") + + def _format_section(self, section: PromptSection, indent: str = "") -> str: + name_snake_case = self.to_snake_case(section.name) + parts = [f"{indent}<{name_snake_case}>"] + + if section.content: + content_lines = section.content.split("\n") + content = "\n".join([f"{indent}{self.INDENT_DIFF}{line}" for line in content_lines]) + parts.append(content) + + parts.extend([self._format_section(sec, indent + self.INDENT_DIFF) for sec in section.sections]) + parts.append(f"{indent}") + return "\n".join(parts) class PromptTemplateBase(BaseModel, abc.ABC): @@ -25,29 +88,61 @@ def get_template(self) -> str: return self.template +class StructuredPromptTemplate(PromptTemplateBase): + sections: List[PromptSection] + + def get_template(self) -> str: + return self._formatter.format(self.sections) + + class PromptPair(BaseModel): - system: PromptTemplate # TODO: should be base class + system: PromptTemplate user: Optional[PromptTemplate] = None +FORMATTER_REGISTRY: Mapping[str, Type[PromptFormatter]] = { + "string": StringPromptFormatter, + "markdown": MarkdownPromptFormatter, + "xml": XMLPromptFormatter, +} + + +class StructuredPromptPair(BaseModel): + system: StructuredPromptTemplate + user: Optional[StructuredPromptTemplate] = None + formatter: str = "string" + + @staticmethod + def resolve_formatter(formatter: Union[str, PromptFormatter]) -> PromptFormatter: + # TODO: save in _formatter + if isinstance(formatter, PromptFormatter): + return formatter + if formatter.lower() in FORMATTER_REGISTRY: + return FORMATTER_REGISTRY[formatter.lower()]() + raise ValueError( + f"Unknown formatter: `{formatter}`. Available formatters: {', '.join(FORMATTER_REGISTRY.keys())}" + ) + + @model_validator(mode="after") + def set_formatter(self): + self._formatter = self.resolve_formatter(self.formatter) + self.system._formatter = self._formatter + if self.user: + self.user._formatter = self._formatter + + return self + + class LLMConfig(BaseModel): model: str params: Optional[Dict[str, Any]] = None class PromptConfig(BaseModel): - kind: str - prompt: PromptPair + prompt: Union[PromptPair, StructuredPromptPair] metadata: Optional[Dict[str, Any]] = None llm: Optional[LLMConfig] = None - @field_validator("kind") - @classmethod - def validate_kind(cls, kind: str) -> str: - if kind not in ["Prompt"]: - raise ValueError(f"Invalid kind: {kind}") - return kind - @classmethod def from_yaml(cls, path: str) -> PromptConfig: with open(path, "r", encoding="utf-8") as f: diff --git a/tests/unit/core/llm/test_llm_function_from_files.py b/tests/unit/core/llm/test_llm_function_from_files.py index 8f7723e9..a587a5fb 100644 --- a/tests/unit/core/llm/test_llm_function_from_files.py +++ b/tests/unit/core/llm/test_llm_function_from_files.py @@ -46,3 +46,28 @@ def test_from_prompt_config() -> None: == llm_func_v2.user_prompt_template == "Answer the following questions: {question}" ) + + +def test_from_structured_prompt_config() -> None: + # TODO tests prompt + llm_func = LLMFunctionTemplated.from_prompt_config_file( + response_model=Response, + prompt_config_path=PromptPath.structured, + ) + + assert llm_func._model_id == "claude-3-5-haiku-20241022" + assert ( + llm_func.system_prompt + == """ + You are a helpful assistant. + + Answer the question in a concise manner. + +""" + ) + assert ( + llm_func.user_prompt_template + == """ + What's the capital of France? +""" + ) From 752b51171f9d60c63f1937705c262f8cee6e5bfc Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Wed, 26 Feb 2025 17:05:57 -0700 Subject: [PATCH 05/13] Code restructuring --- alphaswarm/core/prompt/base.py | 13 +++ alphaswarm/core/prompt/prompt.py | 115 +-------------------------- alphaswarm/core/prompt/structured.py | 108 +++++++++++++++++++++++++ 3 files changed, 125 insertions(+), 111 deletions(-) create mode 100644 alphaswarm/core/prompt/base.py create mode 100644 alphaswarm/core/prompt/structured.py diff --git a/alphaswarm/core/prompt/base.py b/alphaswarm/core/prompt/base.py new file mode 100644 index 00000000..d9bccb45 --- /dev/null +++ b/alphaswarm/core/prompt/base.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import abc + +from pydantic import BaseModel + + +class PromptTemplateBase(BaseModel, abc.ABC): + @abc.abstractmethod + def get_template(self) -> str: + pass + +# TODO: base prompt pair diff --git a/alphaswarm/core/prompt/prompt.py b/alphaswarm/core/prompt/prompt.py index 03eadcfc..e44e0604 100644 --- a/alphaswarm/core/prompt/prompt.py +++ b/alphaswarm/core/prompt/prompt.py @@ -1,79 +1,12 @@ from __future__ import annotations -import abc -from typing import Any, Dict, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Dict, Optional, Union import yaml -from pydantic import BaseModel, field_validator, model_validator +from pydantic import BaseModel, field_validator - -class PromptSection(BaseModel): - name: str - content: Optional[str] = None - sections: List[PromptSection] = [] - - @field_validator("content") - @classmethod - def strip_content(cls, content: Optional[str]) -> Optional[str]: - if isinstance(content, str): - return content.strip() - return content - - -class PromptFormatter(abc.ABC): - def format(self, sections: Sequence[PromptSection]) -> str: - return "\n".join(self._format_section(section) for section in sections) - - @abc.abstractmethod - def _format_section(self, section: PromptSection) -> str: - pass - - -class StringPromptFormatter(PromptFormatter): - def __init__(self, section_prefix: str = ""): - self.section_prefix = section_prefix - - def _format_section(self, section: PromptSection) -> str: - parts = [f"{self.section_prefix}{section.name}"] - if section.content: - parts.append(section.content) - parts.extend([self._format_section(sec) for sec in section.sections]) - return "\n".join(parts) - - -class MarkdownPromptFormatter(PromptFormatter): - def _format_section(self, section: PromptSection, indent: int = 1) -> str: - parts = [f"{'#' * indent} {section.name}", ""] - if section.content: - parts.extend([section.content, ""]) - parts.extend([self._format_section(sec, indent + 1) for sec in section.sections]) - return "\n".join(parts) - - -class XMLPromptFormatter(PromptFormatter): - INDENT_DIFF: str = " " - - def to_snake_case(self, string: str) -> str: - return string.lower().replace(" ", "_") - - def _format_section(self, section: PromptSection, indent: str = "") -> str: - name_snake_case = self.to_snake_case(section.name) - parts = [f"{indent}<{name_snake_case}>"] - - if section.content: - content_lines = section.content.split("\n") - content = "\n".join([f"{indent}{self.INDENT_DIFF}{line}" for line in content_lines]) - parts.append(content) - - parts.extend([self._format_section(sec, indent + self.INDENT_DIFF) for sec in section.sections]) - parts.append(f"{indent}") - return "\n".join(parts) - - -class PromptTemplateBase(BaseModel, abc.ABC): - @abc.abstractmethod - def get_template(self) -> str: - pass +from .base import PromptTemplateBase +from .structured import StructuredPromptPair class PromptTemplate(PromptTemplateBase): @@ -88,51 +21,11 @@ def get_template(self) -> str: return self.template -class StructuredPromptTemplate(PromptTemplateBase): - sections: List[PromptSection] - - def get_template(self) -> str: - return self._formatter.format(self.sections) - - class PromptPair(BaseModel): system: PromptTemplate user: Optional[PromptTemplate] = None -FORMATTER_REGISTRY: Mapping[str, Type[PromptFormatter]] = { - "string": StringPromptFormatter, - "markdown": MarkdownPromptFormatter, - "xml": XMLPromptFormatter, -} - - -class StructuredPromptPair(BaseModel): - system: StructuredPromptTemplate - user: Optional[StructuredPromptTemplate] = None - formatter: str = "string" - - @staticmethod - def resolve_formatter(formatter: Union[str, PromptFormatter]) -> PromptFormatter: - # TODO: save in _formatter - if isinstance(formatter, PromptFormatter): - return formatter - if formatter.lower() in FORMATTER_REGISTRY: - return FORMATTER_REGISTRY[formatter.lower()]() - raise ValueError( - f"Unknown formatter: `{formatter}`. Available formatters: {', '.join(FORMATTER_REGISTRY.keys())}" - ) - - @model_validator(mode="after") - def set_formatter(self): - self._formatter = self.resolve_formatter(self.formatter) - self.system._formatter = self._formatter - if self.user: - self.user._formatter = self._formatter - - return self - - class LLMConfig(BaseModel): model: str params: Optional[Dict[str, Any]] = None diff --git a/alphaswarm/core/prompt/structured.py b/alphaswarm/core/prompt/structured.py new file mode 100644 index 00000000..dd18f8c5 --- /dev/null +++ b/alphaswarm/core/prompt/structured.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import abc +from typing import Optional, List, Sequence, Mapping, Type, Union + +from pydantic import BaseModel, field_validator, model_validator + +from .base import PromptTemplateBase + + +class PromptSection(BaseModel): + name: str + content: Optional[str] = None + sections: List[PromptSection] = [] + + @field_validator("content") + @classmethod + def strip_content(cls, content: Optional[str]) -> Optional[str]: + if isinstance(content, str): + return content.strip() + return content + + +class PromptFormatter(abc.ABC): + def format(self, sections: Sequence[PromptSection]) -> str: + return "\n".join(self._format_section(section) for section in sections) + + @abc.abstractmethod + def _format_section(self, section: PromptSection) -> str: + pass + + +class StringPromptFormatter(PromptFormatter): + def __init__(self, section_prefix: str = ""): + self.section_prefix = section_prefix + + def _format_section(self, section: PromptSection) -> str: + parts = [f"{self.section_prefix}{section.name}"] + if section.content: + parts.append(section.content) + parts.extend([self._format_section(sec) for sec in section.sections]) + return "\n".join(parts) + + +class MarkdownPromptFormatter(PromptFormatter): + def _format_section(self, section: PromptSection, indent: int = 1) -> str: + parts = [f"{'#' * indent} {section.name}", ""] + if section.content: + parts.extend([section.content, ""]) + parts.extend([self._format_section(sec, indent + 1) for sec in section.sections]) + return "\n".join(parts) + + +class XMLPromptFormatter(PromptFormatter): + INDENT_DIFF: str = " " + + def _format_section(self, section: PromptSection, indent: str = "") -> str: + name_snake_case = section.name.lower().replace(" ", "_") + parts = [f"{indent}<{name_snake_case}>"] + + if section.content: + content_lines = section.content.split("\n") + content = "\n".join([f"{indent}{self.INDENT_DIFF}{line}" for line in content_lines]) + parts.append(content) + + parts.extend([self._format_section(sec, indent + self.INDENT_DIFF) for sec in section.sections]) + parts.append(f"{indent}") + return "\n".join(parts) + + +FORMATTER_REGISTRY: Mapping[str, Type[PromptFormatter]] = { + "string": StringPromptFormatter, + "markdown": MarkdownPromptFormatter, + "xml": XMLPromptFormatter, +} + + +class StructuredPromptTemplate(PromptTemplateBase): + sections: List[PromptSection] + + def get_template(self) -> str: + return self._formatter.format(self.sections) + + +class StructuredPromptPair(BaseModel): + system: StructuredPromptTemplate + user: Optional[StructuredPromptTemplate] = None + formatter: str = "string" + + @staticmethod + def resolve_formatter(formatter: Union[str, PromptFormatter]) -> PromptFormatter: + # TODO: save in _formatter + if isinstance(formatter, PromptFormatter): + return formatter + if formatter.lower() in FORMATTER_REGISTRY: + return FORMATTER_REGISTRY[formatter.lower()]() + raise ValueError( + f"Unknown formatter: `{formatter}`. Available formatters: {', '.join(FORMATTER_REGISTRY.keys())}" + ) + + @model_validator(mode="after") + def set_formatter(self): + self._formatter = self.resolve_formatter(self.formatter) + self.system._formatter = self._formatter + if self.user: + self.user._formatter = self._formatter + + return self From 65d23a31bc388fa7b4d66bb697d9c9b6672b6fe7 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Wed, 26 Feb 2025 20:37:18 -0700 Subject: [PATCH 06/13] WIP --- alphaswarm/core/prompt/base.py | 6 +++- alphaswarm/core/prompt/prompt.py | 7 ++-- alphaswarm/core/prompt/structured.py | 54 ++++++++++++++++------------ 3 files changed, 40 insertions(+), 27 deletions(-) diff --git a/alphaswarm/core/prompt/base.py b/alphaswarm/core/prompt/base.py index d9bccb45..bd20f800 100644 --- a/alphaswarm/core/prompt/base.py +++ b/alphaswarm/core/prompt/base.py @@ -1,6 +1,7 @@ from __future__ import annotations import abc +from typing import Any, Optional from pydantic import BaseModel @@ -10,4 +11,7 @@ class PromptTemplateBase(BaseModel, abc.ABC): def get_template(self) -> str: pass -# TODO: base prompt pair + +class PromptPairBase(BaseModel): + system: Any + user: Optional[Any] = None diff --git a/alphaswarm/core/prompt/prompt.py b/alphaswarm/core/prompt/prompt.py index e44e0604..f715f2be 100644 --- a/alphaswarm/core/prompt/prompt.py +++ b/alphaswarm/core/prompt/prompt.py @@ -5,7 +5,7 @@ import yaml from pydantic import BaseModel, field_validator -from .base import PromptTemplateBase +from .base import PromptPairBase, PromptTemplateBase from .structured import StructuredPromptPair @@ -15,20 +15,21 @@ class PromptTemplate(PromptTemplateBase): @field_validator("template") @classmethod def strip_template(cls, template: str) -> str: + # TODO: use StringConstraints in these cases return template.strip() def get_template(self) -> str: return self.template -class PromptPair(BaseModel): +class PromptPair(PromptPairBase): system: PromptTemplate user: Optional[PromptTemplate] = None class LLMConfig(BaseModel): model: str - params: Optional[Dict[str, Any]] = None + params: Dict[str, Any] = {} class PromptConfig(BaseModel): diff --git a/alphaswarm/core/prompt/structured.py b/alphaswarm/core/prompt/structured.py index dd18f8c5..5ac67440 100644 --- a/alphaswarm/core/prompt/structured.py +++ b/alphaswarm/core/prompt/structured.py @@ -1,7 +1,7 @@ from __future__ import annotations import abc -from typing import Optional, List, Sequence, Mapping, Type, Union +from typing import List, Mapping, Optional, Sequence, Type from pydantic import BaseModel, field_validator, model_validator @@ -21,7 +21,7 @@ def strip_content(cls, content: Optional[str]) -> Optional[str]: return content -class PromptFormatter(abc.ABC): +class PromptFormatterBase(abc.ABC): def format(self, sections: Sequence[PromptSection]) -> str: return "\n".join(self._format_section(section) for section in sections) @@ -30,7 +30,7 @@ def _format_section(self, section: PromptSection) -> str: pass -class StringPromptFormatter(PromptFormatter): +class StringPromptFormatter(PromptFormatterBase): def __init__(self, section_prefix: str = ""): self.section_prefix = section_prefix @@ -42,7 +42,7 @@ def _format_section(self, section: PromptSection) -> str: return "\n".join(parts) -class MarkdownPromptFormatter(PromptFormatter): +class MarkdownPromptFormatter(PromptFormatterBase): def _format_section(self, section: PromptSection, indent: int = 1) -> str: parts = [f"{'#' * indent} {section.name}", ""] if section.content: @@ -51,7 +51,7 @@ def _format_section(self, section: PromptSection, indent: int = 1) -> str: return "\n".join(parts) -class XMLPromptFormatter(PromptFormatter): +class XMLPromptFormatter(PromptFormatterBase): INDENT_DIFF: str = " " def _format_section(self, section: PromptSection, indent: str = "") -> str: @@ -68,7 +68,7 @@ def _format_section(self, section: PromptSection, indent: str = "") -> str: return "\n".join(parts) -FORMATTER_REGISTRY: Mapping[str, Type[PromptFormatter]] = { +FORMATTER_REGISTRY: Mapping[str, Type[PromptFormatterBase]] = { "string": StringPromptFormatter, "markdown": MarkdownPromptFormatter, "xml": XMLPromptFormatter, @@ -77,6 +77,10 @@ def _format_section(self, section: PromptSection, indent: str = "") -> str: class StructuredPromptTemplate(PromptTemplateBase): sections: List[PromptSection] + _formatter: PromptFormatterBase + + def set_formatter(self, formatter: PromptFormatterBase) -> None: + self._formatter = formatter def get_template(self) -> str: return self._formatter.format(self.sections) @@ -86,23 +90,27 @@ class StructuredPromptPair(BaseModel): system: StructuredPromptTemplate user: Optional[StructuredPromptTemplate] = None formatter: str = "string" - - @staticmethod - def resolve_formatter(formatter: Union[str, PromptFormatter]) -> PromptFormatter: - # TODO: save in _formatter - if isinstance(formatter, PromptFormatter): - return formatter - if formatter.lower() in FORMATTER_REGISTRY: - return FORMATTER_REGISTRY[formatter.lower()]() - raise ValueError( - f"Unknown formatter: `{formatter}`. Available formatters: {', '.join(FORMATTER_REGISTRY.keys())}" - ) + _formatter: PromptFormatterBase @model_validator(mode="after") - def set_formatter(self): - self._formatter = self.resolve_formatter(self.formatter) - self.system._formatter = self._formatter - if self.user: - self.user._formatter = self._formatter - + def formatter_validator(self) -> StructuredPromptPair: + formatter: PromptFormatterBase = self.formatter_string_to_obj(self.formatter) + self.set_formatter(formatter) return self + + @staticmethod + def formatter_string_to_obj(formatter: str) -> PromptFormatterBase: + formatter = formatter.lower() + if formatter not in FORMATTER_REGISTRY: + raise ValueError( + f"Unknown formatter: `{formatter}`. Available formatters: {', '.join(FORMATTER_REGISTRY.keys())}" + ) + + formatter_cls = FORMATTER_REGISTRY[formatter] + return formatter_cls() + + def set_formatter(self, formatter: PromptFormatterBase) -> None: + self._formatter = formatter + self.system.set_formatter(self._formatter) + if self.user: + self.user.set_formatter(self._formatter) From ddde5c046f60f12835fbce443edf5cdda2172ad0 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Wed, 5 Mar 2025 15:39:35 -0500 Subject: [PATCH 07/13] Code clean up with StringConstraints --- alphaswarm/core/prompt/base.py | 2 -- alphaswarm/core/prompt/prompt.py | 12 +++--------- alphaswarm/core/prompt/structured.py | 17 +++++------------ 3 files changed, 8 insertions(+), 23 deletions(-) diff --git a/alphaswarm/core/prompt/base.py b/alphaswarm/core/prompt/base.py index bd20f800..e5675ef6 100644 --- a/alphaswarm/core/prompt/base.py +++ b/alphaswarm/core/prompt/base.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import abc from typing import Any, Optional diff --git a/alphaswarm/core/prompt/prompt.py b/alphaswarm/core/prompt/prompt.py index f715f2be..5266e9c7 100644 --- a/alphaswarm/core/prompt/prompt.py +++ b/alphaswarm/core/prompt/prompt.py @@ -1,22 +1,16 @@ from __future__ import annotations -from typing import Any, Dict, Optional, Union +from typing import Annotated, Any, Dict, Optional, Union import yaml -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, StringConstraints from .base import PromptPairBase, PromptTemplateBase from .structured import StructuredPromptPair class PromptTemplate(PromptTemplateBase): - template: str - - @field_validator("template") - @classmethod - def strip_template(cls, template: str) -> str: - # TODO: use StringConstraints in these cases - return template.strip() + template: Annotated[str, StringConstraints(strip_whitespace=True)] def get_template(self) -> str: return self.template diff --git a/alphaswarm/core/prompt/structured.py b/alphaswarm/core/prompt/structured.py index 5ac67440..de59f2db 100644 --- a/alphaswarm/core/prompt/structured.py +++ b/alphaswarm/core/prompt/structured.py @@ -1,25 +1,18 @@ from __future__ import annotations import abc -from typing import List, Mapping, Optional, Sequence, Type +from typing import Annotated, List, Mapping, Optional, Sequence, Type -from pydantic import BaseModel, field_validator, model_validator +from pydantic import BaseModel, StringConstraints, model_validator -from .base import PromptTemplateBase +from .base import PromptPairBase, PromptTemplateBase class PromptSection(BaseModel): name: str - content: Optional[str] = None + content: Optional[Annotated[str, StringConstraints(strip_whitespace=True)]] = None sections: List[PromptSection] = [] - @field_validator("content") - @classmethod - def strip_content(cls, content: Optional[str]) -> Optional[str]: - if isinstance(content, str): - return content.strip() - return content - class PromptFormatterBase(abc.ABC): def format(self, sections: Sequence[PromptSection]) -> str: @@ -86,7 +79,7 @@ def get_template(self) -> str: return self._formatter.format(self.sections) -class StructuredPromptPair(BaseModel): +class StructuredPromptPair(PromptPairBase): system: StructuredPromptTemplate user: Optional[StructuredPromptTemplate] = None formatter: str = "string" From 8bfdd20faaba15476b4b3dfe57a9b402d57059eb Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Wed, 12 Mar 2025 15:28:55 -0400 Subject: [PATCH 08/13] Updates - move formatter to prompt template instead of object --- alphaswarm/core/prompt/base.py | 7 ++-- alphaswarm/core/prompt/prompt.py | 21 +++++++++--- alphaswarm/core/prompt/structured.py | 39 ++++++++++------------- tests/data/prompts/structured_prompt.yaml | 3 +- tests/unit/core/prompt/test_prompt.py | 6 ++-- 5 files changed, 43 insertions(+), 33 deletions(-) diff --git a/alphaswarm/core/prompt/base.py b/alphaswarm/core/prompt/base.py index e5675ef6..b00d6d3f 100644 --- a/alphaswarm/core/prompt/base.py +++ b/alphaswarm/core/prompt/base.py @@ -1,7 +1,10 @@ import abc -from typing import Any, Optional +from typing import Annotated, Any, Optional -from pydantic import BaseModel +from pydantic import BaseModel, StringConstraints + +# helper class alias for str that's automatically stripped +StrippedStr = Annotated[str, StringConstraints(strip_whitespace=True)] class PromptTemplateBase(BaseModel, abc.ABC): diff --git a/alphaswarm/core/prompt/prompt.py b/alphaswarm/core/prompt/prompt.py index 5266e9c7..a994a517 100644 --- a/alphaswarm/core/prompt/prompt.py +++ b/alphaswarm/core/prompt/prompt.py @@ -1,16 +1,16 @@ from __future__ import annotations -from typing import Annotated, Any, Dict, Optional, Union +from typing import Any, Dict, Literal, Optional, Union import yaml -from pydantic import BaseModel, StringConstraints +from pydantic import BaseModel -from .base import PromptPairBase, PromptTemplateBase +from .base import PromptPairBase, PromptTemplateBase, StrippedStr from .structured import StructuredPromptPair class PromptTemplate(PromptTemplateBase): - template: Annotated[str, StringConstraints(strip_whitespace=True)] + template: StrippedStr def get_template(self) -> str: return self.template @@ -23,14 +23,25 @@ class PromptPair(PromptPairBase): class LLMConfig(BaseModel): model: str - params: Dict[str, Any] = {} + params: Optional[Dict[str, Any]] = None class PromptConfig(BaseModel): + """ + Prompt configuration object. + Contains prompt pair, optional metadata, and optional LLM configuration. + If LLM configuration is specified, it could be used to generate an LLMFunction. + """ + + kind: Literal["Prompt", "StructuredPrompt"] prompt: Union[PromptPair, StructuredPromptPair] metadata: Optional[Dict[str, Any]] = None llm: Optional[LLMConfig] = None + @property + def has_llm_config(self) -> bool: + return self.llm is not None + @classmethod def from_yaml(cls, path: str) -> PromptConfig: with open(path, "r", encoding="utf-8") as f: diff --git a/alphaswarm/core/prompt/structured.py b/alphaswarm/core/prompt/structured.py index de59f2db..7e0ea728 100644 --- a/alphaswarm/core/prompt/structured.py +++ b/alphaswarm/core/prompt/structured.py @@ -1,16 +1,16 @@ from __future__ import annotations import abc -from typing import Annotated, List, Mapping, Optional, Sequence, Type +from typing import List, Mapping, Optional, Sequence, Type -from pydantic import BaseModel, StringConstraints, model_validator +from pydantic import BaseModel, model_validator -from .base import PromptPairBase, PromptTemplateBase +from .base import PromptPairBase, PromptTemplateBase, StrippedStr class PromptSection(BaseModel): name: str - content: Optional[Annotated[str, StringConstraints(strip_whitespace=True)]] = None + content: Optional[StrippedStr] = None sections: List[PromptSection] = [] @@ -70,25 +70,19 @@ def _format_section(self, section: PromptSection, indent: str = "") -> str: class StructuredPromptTemplate(PromptTemplateBase): sections: List[PromptSection] - _formatter: PromptFormatterBase - - def set_formatter(self, formatter: PromptFormatterBase) -> None: - self._formatter = formatter + formatter: str = "string" + _formatter_obj: PromptFormatterBase def get_template(self) -> str: - return self._formatter.format(self.sections) - + return self._formatter_obj.format(self.sections) -class StructuredPromptPair(PromptPairBase): - system: StructuredPromptTemplate - user: Optional[StructuredPromptTemplate] = None - formatter: str = "string" - _formatter: PromptFormatterBase + def set_formatter(self, formatter: PromptFormatterBase) -> None: + self._formatter_obj = formatter @model_validator(mode="after") - def formatter_validator(self) -> StructuredPromptPair: - formatter: PromptFormatterBase = self.formatter_string_to_obj(self.formatter) - self.set_formatter(formatter) + def formatter_obj_validator(self) -> StructuredPromptTemplate: + formatter_obj = self.formatter_string_to_obj(self.formatter) + self.set_formatter(formatter_obj) return self @staticmethod @@ -102,8 +96,7 @@ def formatter_string_to_obj(formatter: str) -> PromptFormatterBase: formatter_cls = FORMATTER_REGISTRY[formatter] return formatter_cls() - def set_formatter(self, formatter: PromptFormatterBase) -> None: - self._formatter = formatter - self.system.set_formatter(self._formatter) - if self.user: - self.user.set_formatter(self._formatter) + +class StructuredPromptPair(PromptPairBase): + system: StructuredPromptTemplate + user: Optional[StructuredPromptTemplate] = None diff --git a/tests/data/prompts/structured_prompt.yaml b/tests/data/prompts/structured_prompt.yaml index 9916b0a0..9e690dda 100644 --- a/tests/data/prompts/structured_prompt.yaml +++ b/tests/data/prompts/structured_prompt.yaml @@ -16,9 +16,10 @@ prompt: - name: Hints content: | Answer the question in a concise manner. + formatter: XML user: sections: - name: Question content: | What's the capital of France? - formatter: XML + formatter: XML diff --git a/tests/unit/core/prompt/test_prompt.py b/tests/unit/core/prompt/test_prompt.py index a96f4d99..3cc4bf66 100644 --- a/tests/unit/core/prompt/test_prompt.py +++ b/tests/unit/core/prompt/test_prompt.py @@ -1,4 +1,6 @@ import pytest +from pydantic import ValidationError + from alphaswarm.core.prompt import PromptConfig from alphaswarm.core.prompt.prompt import ( PromptTemplate, @@ -73,8 +75,8 @@ def test_prompt_config_invalid_kind(self) -> None: system_prompt = PromptTemplate(template="You are a helpful assistant.") prompt_pair = PromptPair(system=system_prompt) - with pytest.raises(ValueError, match="Invalid kind: InvalidKind"): - PromptConfig(kind="InvalidKind", prompt=prompt_pair) + with pytest.raises(ValidationError): + PromptConfig(kind="InvalidKind", prompt=prompt_pair) # type: ignore def test_from_dict(self) -> None: data = { From eaa61ca067d07ad759ce067c23782458d3e50f40 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Wed, 12 Mar 2025 15:57:54 -0400 Subject: [PATCH 09/13] Add tests for structured prompt --- alphaswarm/core/prompt/structured.py | 8 +- tests/unit/core/prompt/test_prompt.py | 231 +++++++++++++++++++++++++- 2 files changed, 229 insertions(+), 10 deletions(-) diff --git a/alphaswarm/core/prompt/structured.py b/alphaswarm/core/prompt/structured.py index 7e0ea728..efcf95bd 100644 --- a/alphaswarm/core/prompt/structured.py +++ b/alphaswarm/core/prompt/structured.py @@ -37,11 +37,11 @@ def _format_section(self, section: PromptSection) -> str: class MarkdownPromptFormatter(PromptFormatterBase): def _format_section(self, section: PromptSection, indent: int = 1) -> str: - parts = [f"{'#' * indent} {section.name}", ""] + parts = ["", f"{'#' * indent} {section.name}", ""] if section.content: - parts.extend([section.content, ""]) + parts.append(section.content) parts.extend([self._format_section(sec, indent + 1) for sec in section.sections]) - return "\n".join(parts) + return "\n".join(parts).strip() class XMLPromptFormatter(PromptFormatterBase): @@ -71,7 +71,7 @@ def _format_section(self, section: PromptSection, indent: str = "") -> str: class StructuredPromptTemplate(PromptTemplateBase): sections: List[PromptSection] formatter: str = "string" - _formatter_obj: PromptFormatterBase + _formatter_obj: PromptFormatterBase = StringPromptFormatter() # default for mypy, will be overridden def get_template(self) -> str: return self._formatter_obj.format(self.sections) diff --git a/tests/unit/core/prompt/test_prompt.py b/tests/unit/core/prompt/test_prompt.py index 3cc4bf66..5671c967 100644 --- a/tests/unit/core/prompt/test_prompt.py +++ b/tests/unit/core/prompt/test_prompt.py @@ -7,14 +7,166 @@ PromptPair, LLMConfig, ) +from alphaswarm.core.prompt.structured import ( + PromptSection, + StructuredPromptTemplate, + StringPromptFormatter, + MarkdownPromptFormatter, + XMLPromptFormatter, + FORMATTER_REGISTRY, + StructuredPromptPair, +) from tests import PromptPath class TestPromptTemplate: - def test_prompt_template(self) -> None: - template = "This is a test template with {variable}" + def test_get_template(self) -> None: + template = "Test template with {variable}" + prompt = PromptTemplate(template=template) + assert prompt.get_template() == template + + def test_whitespace_stripping(self) -> None: + template = " Template with whitespace " prompt = PromptTemplate(template=template) - assert prompt.template == template + assert prompt.template == "Template with whitespace" + assert prompt.get_template() == "Template with whitespace" + + +class TestStructuredPromptTemplate: + def test_structured_prompt_template_with_string_formatter(self) -> None: + sections = [ + PromptSection(name="Introduction", content="This is an introduction"), + PromptSection(name="Instructions", content="Follow these instructions"), + ] + prompt = StructuredPromptTemplate(sections=sections, formatter="string") + expected = "\n".join(["Introduction", "This is an introduction", "Instructions", "Follow these instructions"]) + assert prompt.get_template() == expected + assert isinstance(prompt._formatter_obj, StringPromptFormatter) + + def test_structured_prompt_template_with_markdown_formatter(self) -> None: + sections = [ + PromptSection(name="Introduction", content="This is an introduction"), + PromptSection(name="Instructions", content="Follow these instructions"), + ] + prompt = StructuredPromptTemplate(sections=sections, formatter="markdown") + expected = "\n".join( + ["# Introduction", "", "This is an introduction", "# Instructions", "", "Follow these instructions"] + ) + assert prompt.get_template() == expected + assert isinstance(prompt._formatter_obj, MarkdownPromptFormatter) + + def test_structured_prompt_template_with_xml_formatter(self) -> None: + sections = [ + PromptSection(name="Introduction", content="This is an introduction"), + PromptSection(name="Instructions", content="Follow these instructions"), + ] + prompt = StructuredPromptTemplate(sections=sections, formatter="xml") + expected = "\n".join( + [ + "", + " This is an introduction", + "", + "", + " Follow these instructions", + "", + ] + ) + assert prompt.get_template() == expected + assert isinstance(prompt._formatter_obj, XMLPromptFormatter) + + def test_nested_sections(self) -> None: + sections = [ + PromptSection( + name="Main Section", + content="Main content", + sections=[PromptSection(name="Subsection", content="Subsection content")], + ) + ] + + # Test with string formatter + prompt = StructuredPromptTemplate(sections=sections, formatter="string") + expected_string = "\n".join(["Main Section", "Main content", "Subsection", "Subsection content"]) + assert prompt.get_template() == expected_string + + # Test with markdown formatter + prompt.set_formatter(MarkdownPromptFormatter()) + expected_md = "\n".join(["# Main Section", "", "Main content", "## Subsection", "", "Subsection content"]) + assert prompt.get_template() == expected_md + + # Test with XML formatter + prompt.set_formatter(XMLPromptFormatter()) + expected_xml = "\n".join( + [ + "", + " Main content", + " ", + " Subsection content", + " ", + "", + ] + ) + assert prompt.get_template() == expected_xml + + def test_invalid_formatter(self) -> None: + sections = [PromptSection(name="Test", content="Test content")] + + with pytest.raises(ValueError, match="Unknown formatter"): + StructuredPromptTemplate(sections=sections, formatter="invalid_formatter") + + def test_set_formatter_manually(self) -> None: + sections = [PromptSection(name="Test", content="Test content")] + prompt = StructuredPromptTemplate(sections=sections, formatter="string") + + # Change formatter after initialization + prompt.set_formatter(XMLPromptFormatter()) + assert isinstance(prompt._formatter_obj, XMLPromptFormatter) + expected = "\n".join(["", " Test content", ""]) + assert prompt.get_template() == expected + + +class TestPromptFormatters: + def test_string_prompt_formatter_with_custom_prefix(self) -> None: + formatter = StringPromptFormatter(section_prefix=">> ") + sections = [ + PromptSection(name="Section1", content="Content1"), + PromptSection(name="Section2", content="Content2"), + ] + expected = "\n".join([">> Section1", "Content1", ">> Section2", "Content2"]) + assert formatter.format(sections) == expected + + def test_markdown_formatter_nested_headings(self) -> None: + formatter = MarkdownPromptFormatter() + sections = [ + PromptSection( + name="Level 1", + content="Content 1", + sections=[ + PromptSection( + name="Level 2", + content="Content 2", + sections=[PromptSection(name="Level 3", content="Content 3")], + ) + ], + ) + ] + expected = "\n".join( + ["# Level 1", "", "Content 1", "## Level 2", "", "Content 2", "### Level 3", "", "Content 3"] + ) + assert formatter.format(sections) == expected + + def test_xml_formatter_multiline_content(self) -> None: + formatter = XMLPromptFormatter() + sections = [PromptSection(name="Section", content="Line 1\nLine 2\nLine 3")] + expected = "\n".join(["
", " Line 1", " Line 2", " Line 3", "
"]) + assert formatter.format(sections) == expected + + def test_formatter_registry(self) -> None: + assert "string" in FORMATTER_REGISTRY + assert "markdown" in FORMATTER_REGISTRY + assert "xml" in FORMATTER_REGISTRY + assert FORMATTER_REGISTRY["string"] == StringPromptFormatter + assert FORMATTER_REGISTRY["markdown"] == MarkdownPromptFormatter + assert FORMATTER_REGISTRY["xml"] == XMLPromptFormatter class TestPromptPair: @@ -32,6 +184,25 @@ def test_prompt_pair_with_system_and_user(self) -> None: assert pair.user == user_prompt +class TestStructuredPromptPair: + def test_structured_prompt_pair_system_only(self) -> None: + system_prompt = StructuredPromptTemplate( + sections=[PromptSection(name="System", content="You are a helpful assistant.")] + ) + pair = StructuredPromptPair(system=system_prompt) + assert pair.system == system_prompt + assert pair.user is None + + def test_structured_prompt_pair_with_system_and_user(self) -> None: + system_prompt = StructuredPromptTemplate( + sections=[PromptSection(name="System", content="You are a helpful assistant.")] + ) + user_prompt = StructuredPromptTemplate(sections=[PromptSection(name="User", content="Help me with this task.")]) + pair = StructuredPromptPair(system=system_prompt, user=user_prompt) + assert pair.system == system_prompt + assert pair.user == user_prompt + + class TestLLMConfig: def test_llm_config_with_model_only(self) -> None: config = LLMConfig(model="gpt-4o") @@ -61,6 +232,20 @@ def test_prompt_config_initialization(self) -> None: assert config.llm is not None assert config.llm.model == "gpt-4o" + def test_prompt_config_with_structured_prompt(self) -> None: + system_prompt = StructuredPromptTemplate( + sections=[PromptSection(name="System", content="You are a helpful assistant.")] + ) + prompt_pair = StructuredPromptPair(system=system_prompt) + + config = PromptConfig( + kind="StructuredPrompt", prompt=prompt_pair, metadata={"version": "1.0"}, llm=LLMConfig(model="gpt-4o") + ) + + assert config.kind == "StructuredPrompt" + assert isinstance(config.prompt, StructuredPromptPair) + assert config.has_llm_config is True + def test_with_empty_metadata_and_llm(self) -> None: prompt_pair = PromptPair( system=PromptTemplate(template="You are a helpful assistant."), @@ -78,6 +263,13 @@ def test_prompt_config_invalid_kind(self) -> None: with pytest.raises(ValidationError): PromptConfig(kind="InvalidKind", prompt=prompt_pair) # type: ignore + def test_mixed_prompt_pair_validation(self) -> None: + system_prompt = PromptTemplate(template="System template") + user_prompt = StructuredPromptTemplate(sections=[PromptSection(name="User", content="User content")]) + + with pytest.raises(ValidationError): + StructuredPromptPair(system=system_prompt, user=user_prompt) # type: ignore + def test_from_dict(self) -> None: data = { "kind": "Prompt", @@ -102,16 +294,43 @@ def test_from_dict(self) -> None: assert config.llm.model == "gpt-4o" assert config.llm.params == {"temperature": 0.7} - def test_from_file(self) -> None: + def test_prompt_from_file(self) -> None: config = PromptConfig.from_yaml(PromptPath.basic) assert config.kind == "Prompt" assert isinstance(config.prompt, PromptPair) assert isinstance(config.prompt.system, PromptTemplate) - assert config.prompt.system.template == "You are a helpful assistant." + assert config.prompt.system.get_template() == "You are a helpful assistant." assert isinstance(config.prompt.user, PromptTemplate) - assert config.prompt.user.template == "Answer the following questions: {question}" + assert config.prompt.user.get_template() == "Answer the following questions: {question}" assert config.metadata == {"description": "This is a prompt doing abc\n"} + assert config.has_llm_config is True assert config.llm is not None assert config.llm.model == "gpt-4o-mini" assert config.llm.params == {"temperature": 0.3} + + def test_structured_prompt_from_file(self) -> None: + config = PromptConfig.from_yaml(PromptPath.structured) + + assert config.kind == "StructuredPrompt" + assert isinstance(config.prompt, StructuredPromptPair) + assert isinstance(config.prompt.system, StructuredPromptTemplate) + assert config.prompt.system.get_template() == "\n".join( + [ + "", + " You are a helpful assistant.", + " ", + " Answer the question in a concise manner.", + " ", + "", + ] + ) + assert isinstance(config.prompt.user, StructuredPromptTemplate) + assert config.prompt.user.get_template() == "\n".join( + ["", " What's the capital of France?", ""] + ) + assert config.metadata == {"description": "This is a prompt doing xyz\n"} + assert config.has_llm_config is True + assert config.llm is not None + assert config.llm.model == "claude-3-5-haiku-20241022" + assert config.llm.params == {"temperature": 0.2} From e1209e99325bbb641ca4b373d757db9885424c58 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Wed, 12 Mar 2025 16:09:55 -0400 Subject: [PATCH 10/13] Test tweaks for LLMFunction --- alphaswarm/core/llm/llm_function.py | 2 +- .../core/llm/test_llm_function_from_files.py | 26 +++++++++++++ .../core/llm/test_llm_function_from_files.py | 39 ++++++++++--------- 3 files changed, 47 insertions(+), 20 deletions(-) diff --git a/alphaswarm/core/llm/llm_function.py b/alphaswarm/core/llm/llm_function.py index e50b0775..3b684c53 100644 --- a/alphaswarm/core/llm/llm_function.py +++ b/alphaswarm/core/llm/llm_function.py @@ -271,7 +271,7 @@ def from_prompt_config( user_prompt_template = prompt_config.prompt.user.get_template() if prompt_config.prompt.user else None if prompt_config.llm is None: - raise ValueError("LLMConfig not set in PromptConfig") + raise ValueError("LLMConfig in PromptConfig is required to create an LLMFunction but was not set") model_id = prompt_config.llm.model # TODO: pass kwargs in the __init__ diff --git a/tests/integration/core/llm/test_llm_function_from_files.py b/tests/integration/core/llm/test_llm_function_from_files.py index 8e019beb..f218f015 100644 --- a/tests/integration/core/llm/test_llm_function_from_files.py +++ b/tests/integration/core/llm/test_llm_function_from_files.py @@ -5,10 +5,16 @@ from pydantic import BaseModel, Field from alphaswarm.core.llm import LLMFunctionTemplated +from alphaswarm.core.prompt import PromptConfig +from tests import PromptPath dotenv.load_dotenv() +class Response(BaseModel): + answer: str = Field(..., description="The answer to the question") + + class SimpleResponse(BaseModel): reasoning: str = Field(..., description="Reasoning behind the response") number: int = Field(..., ge=1, le=10, description="The random number between 1 and 10.") @@ -55,3 +61,23 @@ def test_llm_function_from_user_file() -> None: result = llm_func.execute(user_prompt_params={"min_value": 3, "max_value": 8}) assert isinstance(result, SimpleResponse) assert 3 <= result.number <= 8 + + +def test_llm_function_from_prompt_config() -> None: + llm_func = LLMFunctionTemplated.from_prompt_config( + response_model=Response, + prompt_config=PromptConfig.from_yaml(PromptPath.basic), + ) + + result = llm_func.execute(user_prompt_params={"question": "What's the capital of France?"}) + assert "Paris" in result.answer + + +def test_llm_function_from_structured_prompt_config() -> None: + llm_func = LLMFunctionTemplated.from_prompt_config_file( + response_model=Response, + prompt_config_path=PromptPath.structured, + ) + + result = llm_func.execute() + assert "Paris" in result.answer diff --git a/tests/unit/core/llm/test_llm_function_from_files.py b/tests/unit/core/llm/test_llm_function_from_files.py index a587a5fb..d3410eca 100644 --- a/tests/unit/core/llm/test_llm_function_from_files.py +++ b/tests/unit/core/llm/test_llm_function_from_files.py @@ -41,33 +41,34 @@ def test_from_prompt_config() -> None: assert llm_func_v1._model_id == llm_func_v2._model_id == "gpt-4o-mini" assert llm_func_v1.system_prompt == llm_func_v2.system_prompt == "You are a helpful assistant." - assert ( - llm_func_v1.user_prompt_template - == llm_func_v2.user_prompt_template - == "Answer the following questions: {question}" - ) + expected_user_prompt = "Answer the following questions: {question}" + assert llm_func_v1.user_prompt_template == llm_func_v2.user_prompt_template == expected_user_prompt def test_from_structured_prompt_config() -> None: - # TODO tests prompt llm_func = LLMFunctionTemplated.from_prompt_config_file( response_model=Response, prompt_config_path=PromptPath.structured, ) assert llm_func._model_id == "claude-3-5-haiku-20241022" - assert ( - llm_func.system_prompt - == """ - You are a helpful assistant. - - Answer the question in a concise manner. - -""" + expected_system_prompt = "\n".join( + [ + "", + " You are a helpful assistant.", + " ", + " Answer the question in a concise manner.", + " ", + "", + ] ) - assert ( - llm_func.user_prompt_template - == """ - What's the capital of France? -""" + assert llm_func.system_prompt == expected_system_prompt + + expected_user_prompt = "\n".join( + [ + "", + " What's the capital of France?", + "", + ] ) + assert llm_func.user_prompt_template == expected_user_prompt From 82489025b051d991533c8a08d921fcb01904c9f8 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Wed, 12 Mar 2025 16:13:46 -0400 Subject: [PATCH 11/13] Pass kwargs from prompt config into LLMFunctionTemplated --- alphaswarm/core/llm/llm_function.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/alphaswarm/core/llm/llm_function.py b/alphaswarm/core/llm/llm_function.py index 3b684c53..4f48eb8e 100644 --- a/alphaswarm/core/llm/llm_function.py +++ b/alphaswarm/core/llm/llm_function.py @@ -170,6 +170,7 @@ def __init__( user_prompt_template: Optional[str] = None, system_prompt_params: Optional[Dict[str, Any]] = None, max_retries: int = 3, + llm_params: Optional[Dict[str, Any]] = None, ) -> None: """Initialize an LLMFunctionTemplated instance. @@ -180,12 +181,15 @@ def __init__( user_prompt_template: Optional template for the user message system_prompt_params: Parameters for formatting the system prompt if any max_retries: Maximum number of retry attempts + llm_params: Additional keyword arguments to pass to the LLM client """ super().__init__(model_id=model_id, response_model=response_model, max_retries=max_retries) self.system_prompt_template = system_prompt_template self.system_prompt = self._format(system_prompt_template, system_prompt_params) self.user_prompt_template = user_prompt_template + self._llm_params = llm_params or {} + def execute_with_completion( self, user_prompt_params: Optional[Dict[str, Any]] = None, @@ -212,7 +216,7 @@ def execute_with_completion( user_prompt = self._format(self.user_prompt_template, user_prompt_params) messages.append(Message.user(user_prompt)) - return self._execute_with_completion(messages=messages, **kwargs) + return self._execute_with_completion(messages=messages, **self._llm_params, **kwargs) @classmethod def from_files( @@ -273,7 +277,6 @@ def from_prompt_config( if prompt_config.llm is None: raise ValueError("LLMConfig in PromptConfig is required to create an LLMFunction but was not set") model_id = prompt_config.llm.model - # TODO: pass kwargs in the __init__ return cls( model_id=model_id, @@ -282,6 +285,7 @@ def from_prompt_config( user_prompt_template=user_prompt_template, system_prompt_params=system_prompt_params, max_retries=max_retries, + llm_params=prompt_config.llm.params, ) @classmethod From 0128f5f54e219cf24206851cce40d9915366cd8e Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Wed, 12 Mar 2025 16:25:02 -0400 Subject: [PATCH 12/13] Minor cleanup --- tests/unit/core/prompt/test_prompt.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/unit/core/prompt/test_prompt.py b/tests/unit/core/prompt/test_prompt.py index 5671c967..318498e1 100644 --- a/tests/unit/core/prompt/test_prompt.py +++ b/tests/unit/core/prompt/test_prompt.py @@ -83,17 +83,14 @@ def test_nested_sections(self) -> None: ) ] - # Test with string formatter prompt = StructuredPromptTemplate(sections=sections, formatter="string") expected_string = "\n".join(["Main Section", "Main content", "Subsection", "Subsection content"]) assert prompt.get_template() == expected_string - # Test with markdown formatter prompt.set_formatter(MarkdownPromptFormatter()) expected_md = "\n".join(["# Main Section", "", "Main content", "## Subsection", "", "Subsection content"]) assert prompt.get_template() == expected_md - # Test with XML formatter prompt.set_formatter(XMLPromptFormatter()) expected_xml = "\n".join( [ @@ -115,9 +112,8 @@ def test_invalid_formatter(self) -> None: def test_set_formatter_manually(self) -> None: sections = [PromptSection(name="Test", content="Test content")] - prompt = StructuredPromptTemplate(sections=sections, formatter="string") + prompt = StructuredPromptTemplate(sections=sections) - # Change formatter after initialization prompt.set_formatter(XMLPromptFormatter()) assert isinstance(prompt._formatter_obj, XMLPromptFormatter) expected = "\n".join(["", " Test content", ""]) @@ -161,9 +157,6 @@ def test_xml_formatter_multiline_content(self) -> None: assert formatter.format(sections) == expected def test_formatter_registry(self) -> None: - assert "string" in FORMATTER_REGISTRY - assert "markdown" in FORMATTER_REGISTRY - assert "xml" in FORMATTER_REGISTRY assert FORMATTER_REGISTRY["string"] == StringPromptFormatter assert FORMATTER_REGISTRY["markdown"] == MarkdownPromptFormatter assert FORMATTER_REGISTRY["xml"] == XMLPromptFormatter From 9a12bcd1551b47563e8016de19d9e9aa74432a56 Mon Sep 17 00:00:00 2001 From: Dmytro Nikolaiev Date: Thu, 10 Apr 2025 13:27:26 -0400 Subject: [PATCH 13/13] Address comments --- alphaswarm/core/prompt/base.py | 11 +++++++---- alphaswarm/core/prompt/prompt.py | 2 +- alphaswarm/core/prompt/structured.py | 4 ++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/alphaswarm/core/prompt/base.py b/alphaswarm/core/prompt/base.py index b00d6d3f..e7c56ab7 100644 --- a/alphaswarm/core/prompt/base.py +++ b/alphaswarm/core/prompt/base.py @@ -1,5 +1,5 @@ import abc -from typing import Annotated, Any, Optional +from typing import Annotated, Generic, Optional, TypeVar from pydantic import BaseModel, StringConstraints @@ -13,6 +13,9 @@ def get_template(self) -> str: pass -class PromptPairBase(BaseModel): - system: Any - user: Optional[Any] = None +T = TypeVar("T", bound="BaseModel") + + +class PromptPairBase(BaseModel, Generic[T]): + system: T + user: Optional[T] = None diff --git a/alphaswarm/core/prompt/prompt.py b/alphaswarm/core/prompt/prompt.py index a994a517..5d32e20d 100644 --- a/alphaswarm/core/prompt/prompt.py +++ b/alphaswarm/core/prompt/prompt.py @@ -16,7 +16,7 @@ def get_template(self) -> str: return self.template -class PromptPair(PromptPairBase): +class PromptPair(PromptPairBase[PromptTemplate]): system: PromptTemplate user: Optional[PromptTemplate] = None diff --git a/alphaswarm/core/prompt/structured.py b/alphaswarm/core/prompt/structured.py index efcf95bd..76309192 100644 --- a/alphaswarm/core/prompt/structured.py +++ b/alphaswarm/core/prompt/structured.py @@ -24,7 +24,7 @@ def _format_section(self, section: PromptSection) -> str: class StringPromptFormatter(PromptFormatterBase): - def __init__(self, section_prefix: str = ""): + def __init__(self, section_prefix: str = "") -> None: self.section_prefix = section_prefix def _format_section(self, section: PromptSection) -> str: @@ -97,6 +97,6 @@ def formatter_string_to_obj(formatter: str) -> PromptFormatterBase: return formatter_cls() -class StructuredPromptPair(PromptPairBase): +class StructuredPromptPair(PromptPairBase[StructuredPromptTemplate]): system: StructuredPromptTemplate user: Optional[StructuredPromptTemplate] = None