diff --git a/alphaswarm/core/llm/llm_function.py b/alphaswarm/core/llm/llm_function.py
index 15452fa2..4f48eb8e 100644
--- a/alphaswarm/core/llm/llm_function.py
+++ b/alphaswarm/core/llm/llm_function.py
@@ -9,6 +9,7 @@
from litellm.types.utils import ModelResponse
from pydantic import BaseModel
+from ..prompt import PromptConfig
from .message import Message
litellm.modify_params = True # for calls with system message only for anthropic
@@ -169,6 +170,7 @@ def __init__(
user_prompt_template: Optional[str] = None,
system_prompt_params: Optional[Dict[str, Any]] = None,
max_retries: int = 3,
+ llm_params: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize an LLMFunctionTemplated instance.
@@ -179,12 +181,15 @@ def __init__(
user_prompt_template: Optional template for the user message
system_prompt_params: Parameters for formatting the system prompt if any
max_retries: Maximum number of retry attempts
+ llm_params: Additional keyword arguments to pass to the LLM client
"""
super().__init__(model_id=model_id, response_model=response_model, max_retries=max_retries)
self.system_prompt_template = system_prompt_template
self.system_prompt = self._format(system_prompt_template, system_prompt_params)
self.user_prompt_template = user_prompt_template
+ self._llm_params = llm_params or {}
+
def execute_with_completion(
self,
user_prompt_params: Optional[Dict[str, Any]] = None,
@@ -211,7 +216,7 @@ def execute_with_completion(
user_prompt = self._format(self.user_prompt_template, user_prompt_params)
messages.append(Message.user(user_prompt))
- return self._execute_with_completion(messages=messages, **kwargs)
+ return self._execute_with_completion(messages=messages, **self._llm_params, **kwargs)
@classmethod
def from_files(
@@ -223,7 +228,7 @@ def from_files(
system_prompt_params: Optional[Dict[str, Any]] = None,
max_retries: int = 3,
) -> LLMFunctionTemplated[T_Response]:
- """Create an instance from template files.
+ """Create an instance from template text files.
Args:
model_id: LiteLLM model ID to use
@@ -250,6 +255,62 @@ def from_files(
max_retries=max_retries,
)
+ @classmethod
+ def from_prompt_config(
+ cls,
+ response_model: Type[T_Response],
+ prompt_config: PromptConfig,
+ system_prompt_params: Optional[Dict[str, Any]] = None,
+ max_retries: int = 3,
+ ) -> LLMFunctionTemplated[T_Response]:
+ """Create an instance from prompt config object.
+
+ Args:
+ response_model: Pydantic model class for structuring responses
+ prompt_config: PromptConfig object
+ system_prompt_params: Parameters for formatting the system prompt
+ max_retries: Maximum number of retry attempts
+ """
+ system_prompt_template = prompt_config.prompt.system.get_template()
+ user_prompt_template = prompt_config.prompt.user.get_template() if prompt_config.prompt.user else None
+
+ if prompt_config.llm is None:
+ raise ValueError("LLMConfig in PromptConfig is required to create an LLMFunction but was not set")
+ model_id = prompt_config.llm.model
+
+ return cls(
+ model_id=model_id,
+ response_model=response_model,
+ system_prompt_template=system_prompt_template,
+ user_prompt_template=user_prompt_template,
+ system_prompt_params=system_prompt_params,
+ max_retries=max_retries,
+ llm_params=prompt_config.llm.params,
+ )
+
+ @classmethod
+ def from_prompt_config_file(
+ cls,
+ response_model: Type[T_Response],
+ prompt_config_path: str,
+ system_prompt_params: Optional[Dict[str, Any]] = None,
+ max_retries: int = 3,
+ ) -> LLMFunctionTemplated[T_Response]:
+ """Create an instance from prompt config file.
+
+ Args:
+ response_model: Pydantic model class for structuring responses
+ prompt_config_path: Path to the prompt config yaml file
+ system_prompt_params: Parameters for formatting the system prompt
+ max_retries: Maximum number of retry attempts
+ """
+ return cls.from_prompt_config(
+ response_model=response_model,
+ prompt_config=PromptConfig.from_yaml(prompt_config_path),
+ system_prompt_params=system_prompt_params,
+ max_retries=max_retries,
+ )
+
@staticmethod
def _format(template: str, params: Optional[Dict[str, Any]] = None) -> str:
"""Format the template string with the given optional parameters."""
diff --git a/alphaswarm/core/prompt/__init__.py b/alphaswarm/core/prompt/__init__.py
new file mode 100644
index 00000000..6434ee7a
--- /dev/null
+++ b/alphaswarm/core/prompt/__init__.py
@@ -0,0 +1,3 @@
+from .prompt import PromptConfig
+
+__all__ = ["PromptConfig"]
diff --git a/alphaswarm/core/prompt/base.py b/alphaswarm/core/prompt/base.py
new file mode 100644
index 00000000..e7c56ab7
--- /dev/null
+++ b/alphaswarm/core/prompt/base.py
@@ -0,0 +1,21 @@
+import abc
+from typing import Annotated, Generic, Optional, TypeVar
+
+from pydantic import BaseModel, StringConstraints
+
+# helper class alias for str that's automatically stripped
+StrippedStr = Annotated[str, StringConstraints(strip_whitespace=True)]
+
+
+class PromptTemplateBase(BaseModel, abc.ABC):
+ @abc.abstractmethod
+ def get_template(self) -> str:
+ pass
+
+
+T = TypeVar("T", bound="BaseModel")
+
+
+class PromptPairBase(BaseModel, Generic[T]):
+ system: T
+ user: Optional[T] = None
diff --git a/alphaswarm/core/prompt/prompt.py b/alphaswarm/core/prompt/prompt.py
new file mode 100644
index 00000000..5d32e20d
--- /dev/null
+++ b/alphaswarm/core/prompt/prompt.py
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+from typing import Any, Dict, Literal, Optional, Union
+
+import yaml
+from pydantic import BaseModel
+
+from .base import PromptPairBase, PromptTemplateBase, StrippedStr
+from .structured import StructuredPromptPair
+
+
+class PromptTemplate(PromptTemplateBase):
+ template: StrippedStr
+
+ def get_template(self) -> str:
+ return self.template
+
+
+class PromptPair(PromptPairBase[PromptTemplate]):
+ system: PromptTemplate
+ user: Optional[PromptTemplate] = None
+
+
+class LLMConfig(BaseModel):
+ model: str
+ params: Optional[Dict[str, Any]] = None
+
+
+class PromptConfig(BaseModel):
+ """
+ Prompt configuration object.
+ Contains prompt pair, optional metadata, and optional LLM configuration.
+ If LLM configuration is specified, it could be used to generate an LLMFunction.
+ """
+
+ kind: Literal["Prompt", "StructuredPrompt"]
+ prompt: Union[PromptPair, StructuredPromptPair]
+ metadata: Optional[Dict[str, Any]] = None
+ llm: Optional[LLMConfig] = None
+
+ @property
+ def has_llm_config(self) -> bool:
+ return self.llm is not None
+
+ @classmethod
+ def from_yaml(cls, path: str) -> PromptConfig:
+ with open(path, "r", encoding="utf-8") as f:
+ data = yaml.safe_load(f)
+ return cls(**data)
diff --git a/alphaswarm/core/prompt/structured.py b/alphaswarm/core/prompt/structured.py
new file mode 100644
index 00000000..76309192
--- /dev/null
+++ b/alphaswarm/core/prompt/structured.py
@@ -0,0 +1,102 @@
+from __future__ import annotations
+
+import abc
+from typing import List, Mapping, Optional, Sequence, Type
+
+from pydantic import BaseModel, model_validator
+
+from .base import PromptPairBase, PromptTemplateBase, StrippedStr
+
+
+class PromptSection(BaseModel):
+ name: str
+ content: Optional[StrippedStr] = None
+ sections: List[PromptSection] = []
+
+
+class PromptFormatterBase(abc.ABC):
+ def format(self, sections: Sequence[PromptSection]) -> str:
+ return "\n".join(self._format_section(section) for section in sections)
+
+ @abc.abstractmethod
+ def _format_section(self, section: PromptSection) -> str:
+ pass
+
+
+class StringPromptFormatter(PromptFormatterBase):
+ def __init__(self, section_prefix: str = "") -> None:
+ self.section_prefix = section_prefix
+
+ def _format_section(self, section: PromptSection) -> str:
+ parts = [f"{self.section_prefix}{section.name}"]
+ if section.content:
+ parts.append(section.content)
+ parts.extend([self._format_section(sec) for sec in section.sections])
+ return "\n".join(parts)
+
+
+class MarkdownPromptFormatter(PromptFormatterBase):
+ def _format_section(self, section: PromptSection, indent: int = 1) -> str:
+ parts = ["", f"{'#' * indent} {section.name}", ""]
+ if section.content:
+ parts.append(section.content)
+ parts.extend([self._format_section(sec, indent + 1) for sec in section.sections])
+ return "\n".join(parts).strip()
+
+
+class XMLPromptFormatter(PromptFormatterBase):
+    INDENT_DIFF: str = "    "
+
+ def _format_section(self, section: PromptSection, indent: str = "") -> str:
+ name_snake_case = section.name.lower().replace(" ", "_")
+ parts = [f"{indent}<{name_snake_case}>"]
+
+ if section.content:
+ content_lines = section.content.split("\n")
+ content = "\n".join([f"{indent}{self.INDENT_DIFF}{line}" for line in content_lines])
+ parts.append(content)
+
+ parts.extend([self._format_section(sec, indent + self.INDENT_DIFF) for sec in section.sections])
+        parts.append(f"{indent}</{name_snake_case}>")
+ return "\n".join(parts)
+
+
+FORMATTER_REGISTRY: Mapping[str, Type[PromptFormatterBase]] = {
+ "string": StringPromptFormatter,
+ "markdown": MarkdownPromptFormatter,
+ "xml": XMLPromptFormatter,
+}
+
+
+class StructuredPromptTemplate(PromptTemplateBase):
+ sections: List[PromptSection]
+ formatter: str = "string"
+ _formatter_obj: PromptFormatterBase = StringPromptFormatter() # default for mypy, will be overridden
+
+ def get_template(self) -> str:
+ return self._formatter_obj.format(self.sections)
+
+ def set_formatter(self, formatter: PromptFormatterBase) -> None:
+ self._formatter_obj = formatter
+
+ @model_validator(mode="after")
+ def formatter_obj_validator(self) -> StructuredPromptTemplate:
+ formatter_obj = self.formatter_string_to_obj(self.formatter)
+ self.set_formatter(formatter_obj)
+ return self
+
+ @staticmethod
+ def formatter_string_to_obj(formatter: str) -> PromptFormatterBase:
+ formatter = formatter.lower()
+ if formatter not in FORMATTER_REGISTRY:
+ raise ValueError(
+ f"Unknown formatter: `{formatter}`. Available formatters: {', '.join(FORMATTER_REGISTRY.keys())}"
+ )
+
+ formatter_cls = FORMATTER_REGISTRY[formatter]
+ return formatter_cls()
+
+
+class StructuredPromptPair(PromptPairBase[StructuredPromptTemplate]):
+ system: StructuredPromptTemplate
+ user: Optional[StructuredPromptTemplate] = None
diff --git a/tests/__init__.py b/tests/__init__.py
index 4eb9068d..bde59cb5 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,5 +1,14 @@
import os
+from typing import Final
+from enum import Enum
+
+DATA_PATH: Final[str] = os.path.join(os.path.dirname(__file__), "data")
+
+
+class PromptPath(str, Enum):
+ basic = os.path.join(DATA_PATH, "prompts", "prompt.yaml")
+ structured = os.path.join(DATA_PATH, "prompts", "structured_prompt.yaml")
def get_data_filename(filename: str) -> str:
- return os.path.join(os.path.dirname(__file__), "data", filename)
+ return os.path.join(DATA_PATH, filename)
diff --git a/tests/data/prompts/prompt.yaml b/tests/data/prompts/prompt.yaml
new file mode 100644
index 00000000..eabd3803
--- /dev/null
+++ b/tests/data/prompts/prompt.yaml
@@ -0,0 +1,15 @@
+kind: Prompt
+metadata:
+ description: |
+ This is a prompt doing abc
+llm:
+ model: gpt-4o-mini
+ params:
+ temperature: 0.3
+prompt:
+ system:
+ template: |
+ You are a helpful assistant.
+ user:
+ template: |
+ Answer the following questions: {question}
diff --git a/tests/data/prompts/structured_prompt.yaml b/tests/data/prompts/structured_prompt.yaml
new file mode 100644
index 00000000..9e690dda
--- /dev/null
+++ b/tests/data/prompts/structured_prompt.yaml
@@ -0,0 +1,25 @@
+kind: StructuredPrompt
+metadata:
+ description: |
+ This is a prompt doing xyz
+llm:
+ model: claude-3-5-haiku-20241022
+ params:
+ temperature: 0.2
+prompt:
+ system:
+ sections:
+ - name: Instructions
+ content: |
+ You are a helpful assistant.
+ sections:
+ - name: Hints
+ content: |
+ Answer the question in a concise manner.
+ formatter: XML
+ user:
+ sections:
+ - name: Question
+ content: |
+ What's the capital of France?
+ formatter: XML
diff --git a/tests/integration/core/llm/test_llm_function_from_files.py b/tests/integration/core/llm/test_llm_function_from_files.py
index 8e019beb..f218f015 100644
--- a/tests/integration/core/llm/test_llm_function_from_files.py
+++ b/tests/integration/core/llm/test_llm_function_from_files.py
@@ -5,10 +5,16 @@
from pydantic import BaseModel, Field
from alphaswarm.core.llm import LLMFunctionTemplated
+from alphaswarm.core.prompt import PromptConfig
+from tests import PromptPath
dotenv.load_dotenv()
+class Response(BaseModel):
+ answer: str = Field(..., description="The answer to the question")
+
+
class SimpleResponse(BaseModel):
reasoning: str = Field(..., description="Reasoning behind the response")
number: int = Field(..., ge=1, le=10, description="The random number between 1 and 10.")
@@ -55,3 +61,23 @@ def test_llm_function_from_user_file() -> None:
result = llm_func.execute(user_prompt_params={"min_value": 3, "max_value": 8})
assert isinstance(result, SimpleResponse)
assert 3 <= result.number <= 8
+
+
+def test_llm_function_from_prompt_config() -> None:
+ llm_func = LLMFunctionTemplated.from_prompt_config(
+ response_model=Response,
+ prompt_config=PromptConfig.from_yaml(PromptPath.basic),
+ )
+
+ result = llm_func.execute(user_prompt_params={"question": "What's the capital of France?"})
+ assert "Paris" in result.answer
+
+
+def test_llm_function_from_structured_prompt_config() -> None:
+ llm_func = LLMFunctionTemplated.from_prompt_config_file(
+ response_model=Response,
+ prompt_config_path=PromptPath.structured,
+ )
+
+ result = llm_func.execute()
+ assert "Paris" in result.answer
diff --git a/tests/unit/core/llm/test_llm_function_from_files.py b/tests/unit/core/llm/test_llm_function_from_files.py
index 856ada83..d3410eca 100644
--- a/tests/unit/core/llm/test_llm_function_from_files.py
+++ b/tests/unit/core/llm/test_llm_function_from_files.py
@@ -1,30 +1,74 @@
import tempfile
-from typing import Any
import pytest
from pydantic import BaseModel
from alphaswarm.core.llm import LLMFunctionTemplated
+from alphaswarm.core.prompt import PromptConfig
+
+from tests import PromptPath
class Response(BaseModel):
test: str
-def get_sample_llm_function_from_files(**kwargs: Any) -> LLMFunctionTemplated[Response]:
- return LLMFunctionTemplated.from_files(
- model_id="test",
- response_model=Response,
- **kwargs,
- )
-
-
def test_execute_with_user_prompt_params_but_no_template_raises() -> None:
with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", suffix=".txt", delete=True) as system_file:
system_file.write("Sample system prompt")
system_file.flush()
- llm_func = get_sample_llm_function_from_files(system_prompt_path=system_file.name)
+ llm_func = LLMFunctionTemplated.from_files(
+ model_id="test",
+ response_model=Response,
+ system_prompt_path=system_file.name,
+ )
with pytest.raises(ValueError, match="User prompt params provided but no user prompt template exists"):
llm_func.execute(user_prompt_params={"test": "value"})
+
+
+def test_from_prompt_config() -> None:
+ llm_func_v1 = LLMFunctionTemplated.from_prompt_config(
+ response_model=Response,
+ prompt_config=PromptConfig.from_yaml(PromptPath.basic),
+ )
+
+ llm_func_v2 = LLMFunctionTemplated.from_prompt_config_file(
+ response_model=Response,
+ prompt_config_path=PromptPath.basic,
+ )
+
+ assert llm_func_v1._model_id == llm_func_v2._model_id == "gpt-4o-mini"
+ assert llm_func_v1.system_prompt == llm_func_v2.system_prompt == "You are a helpful assistant."
+ expected_user_prompt = "Answer the following questions: {question}"
+ assert llm_func_v1.user_prompt_template == llm_func_v2.user_prompt_template == expected_user_prompt
+
+
+def test_from_structured_prompt_config() -> None:
+ llm_func = LLMFunctionTemplated.from_prompt_config_file(
+ response_model=Response,
+ prompt_config_path=PromptPath.structured,
+ )
+
+ assert llm_func._model_id == "claude-3-5-haiku-20241022"
+    expected_system_prompt = "\n".join(
+        [
+            "<instructions>",
+            "    You are a helpful assistant.",
+            "    <hints>",
+            "        Answer the question in a concise manner.",
+            "    </hints>",
+            "</instructions>",
+        ]
+    )
+ assert llm_func.system_prompt == expected_system_prompt
+
+    expected_user_prompt = "\n".join(
+        [
+            "<question>",
+            "    What's the capital of France?",
+            "</question>",
+        ]
+    )
+ assert llm_func.user_prompt_template == expected_user_prompt
diff --git a/tests/unit/core/prompt/__init__.py b/tests/unit/core/prompt/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/core/prompt/test_prompt.py b/tests/unit/core/prompt/test_prompt.py
new file mode 100644
index 00000000..318498e1
--- /dev/null
+++ b/tests/unit/core/prompt/test_prompt.py
@@ -0,0 +1,329 @@
+import pytest
+from pydantic import ValidationError
+
+from alphaswarm.core.prompt import PromptConfig
+from alphaswarm.core.prompt.prompt import (
+ PromptTemplate,
+ PromptPair,
+ LLMConfig,
+)
+from alphaswarm.core.prompt.structured import (
+ PromptSection,
+ StructuredPromptTemplate,
+ StringPromptFormatter,
+ MarkdownPromptFormatter,
+ XMLPromptFormatter,
+ FORMATTER_REGISTRY,
+ StructuredPromptPair,
+)
+from tests import PromptPath
+
+
+class TestPromptTemplate:
+ def test_get_template(self) -> None:
+ template = "Test template with {variable}"
+ prompt = PromptTemplate(template=template)
+ assert prompt.get_template() == template
+
+ def test_whitespace_stripping(self) -> None:
+ template = " Template with whitespace "
+ prompt = PromptTemplate(template=template)
+ assert prompt.template == "Template with whitespace"
+ assert prompt.get_template() == "Template with whitespace"
+
+
+class TestStructuredPromptTemplate:
+ def test_structured_prompt_template_with_string_formatter(self) -> None:
+ sections = [
+ PromptSection(name="Introduction", content="This is an introduction"),
+ PromptSection(name="Instructions", content="Follow these instructions"),
+ ]
+ prompt = StructuredPromptTemplate(sections=sections, formatter="string")
+ expected = "\n".join(["Introduction", "This is an introduction", "Instructions", "Follow these instructions"])
+ assert prompt.get_template() == expected
+ assert isinstance(prompt._formatter_obj, StringPromptFormatter)
+
+ def test_structured_prompt_template_with_markdown_formatter(self) -> None:
+ sections = [
+ PromptSection(name="Introduction", content="This is an introduction"),
+ PromptSection(name="Instructions", content="Follow these instructions"),
+ ]
+ prompt = StructuredPromptTemplate(sections=sections, formatter="markdown")
+ expected = "\n".join(
+ ["# Introduction", "", "This is an introduction", "# Instructions", "", "Follow these instructions"]
+ )
+ assert prompt.get_template() == expected
+ assert isinstance(prompt._formatter_obj, MarkdownPromptFormatter)
+
+ def test_structured_prompt_template_with_xml_formatter(self) -> None:
+ sections = [
+ PromptSection(name="Introduction", content="This is an introduction"),
+ PromptSection(name="Instructions", content="Follow these instructions"),
+ ]
+ prompt = StructuredPromptTemplate(sections=sections, formatter="xml")
+        expected = "\n".join(
+            [
+                "<introduction>",
+                "    This is an introduction",
+                "</introduction>",
+                "<instructions>",
+                "    Follow these instructions",
+                "</instructions>",
+            ]
+        )
+ assert prompt.get_template() == expected
+ assert isinstance(prompt._formatter_obj, XMLPromptFormatter)
+
+ def test_nested_sections(self) -> None:
+ sections = [
+ PromptSection(
+ name="Main Section",
+ content="Main content",
+ sections=[PromptSection(name="Subsection", content="Subsection content")],
+ )
+ ]
+
+ prompt = StructuredPromptTemplate(sections=sections, formatter="string")
+ expected_string = "\n".join(["Main Section", "Main content", "Subsection", "Subsection content"])
+ assert prompt.get_template() == expected_string
+
+ prompt.set_formatter(MarkdownPromptFormatter())
+ expected_md = "\n".join(["# Main Section", "", "Main content", "## Subsection", "", "Subsection content"])
+ assert prompt.get_template() == expected_md
+
+ prompt.set_formatter(XMLPromptFormatter())
+        expected_xml = "\n".join(
+            [
+                "<main_section>",
+                "    Main content",
+                "    <subsection>",
+                "        Subsection content",
+                "    </subsection>",
+                "</main_section>",
+            ]
+        )
+ assert prompt.get_template() == expected_xml
+
+ def test_invalid_formatter(self) -> None:
+ sections = [PromptSection(name="Test", content="Test content")]
+
+ with pytest.raises(ValueError, match="Unknown formatter"):
+ StructuredPromptTemplate(sections=sections, formatter="invalid_formatter")
+
+ def test_set_formatter_manually(self) -> None:
+ sections = [PromptSection(name="Test", content="Test content")]
+ prompt = StructuredPromptTemplate(sections=sections)
+
+ prompt.set_formatter(XMLPromptFormatter())
+ assert isinstance(prompt._formatter_obj, XMLPromptFormatter)
+        expected = "\n".join(["<test>", "    Test content", "</test>"])
+ assert prompt.get_template() == expected
+
+
+class TestPromptFormatters:
+ def test_string_prompt_formatter_with_custom_prefix(self) -> None:
+ formatter = StringPromptFormatter(section_prefix=">> ")
+ sections = [
+ PromptSection(name="Section1", content="Content1"),
+ PromptSection(name="Section2", content="Content2"),
+ ]
+ expected = "\n".join([">> Section1", "Content1", ">> Section2", "Content2"])
+ assert formatter.format(sections) == expected
+
+ def test_markdown_formatter_nested_headings(self) -> None:
+ formatter = MarkdownPromptFormatter()
+ sections = [
+ PromptSection(
+ name="Level 1",
+ content="Content 1",
+ sections=[
+ PromptSection(
+ name="Level 2",
+ content="Content 2",
+ sections=[PromptSection(name="Level 3", content="Content 3")],
+ )
+ ],
+ )
+ ]
+ expected = "\n".join(
+ ["# Level 1", "", "Content 1", "## Level 2", "", "Content 2", "### Level 3", "", "Content 3"]
+ )
+ assert formatter.format(sections) == expected
+
+ def test_xml_formatter_multiline_content(self) -> None:
+ formatter = XMLPromptFormatter()
+ sections = [PromptSection(name="Section", content="Line 1\nLine 2\nLine 3")]
+        expected = "\n".join(["<section>", "    Line 1", "    Line 2", "    Line 3", "</section>"])
+ assert formatter.format(sections) == expected
+
+ def test_formatter_registry(self) -> None:
+ assert FORMATTER_REGISTRY["string"] == StringPromptFormatter
+ assert FORMATTER_REGISTRY["markdown"] == MarkdownPromptFormatter
+ assert FORMATTER_REGISTRY["xml"] == XMLPromptFormatter
+
+
+class TestPromptPair:
+ def test_prompt_pair_with_system_only(self) -> None:
+ system_prompt = PromptTemplate(template="You are a helpful assistant.")
+ pair = PromptPair(system=system_prompt)
+ assert pair.system == system_prompt
+ assert pair.user is None
+
+ def test_prompt_pair_with_system_and_user(self) -> None:
+ system_prompt = PromptTemplate(template="You are a helpful assistant.")
+ user_prompt = PromptTemplate(template="Help me with {task}.")
+ pair = PromptPair(system=system_prompt, user=user_prompt)
+ assert pair.system == system_prompt
+ assert pair.user == user_prompt
+
+
+class TestStructuredPromptPair:
+ def test_structured_prompt_pair_system_only(self) -> None:
+ system_prompt = StructuredPromptTemplate(
+ sections=[PromptSection(name="System", content="You are a helpful assistant.")]
+ )
+ pair = StructuredPromptPair(system=system_prompt)
+ assert pair.system == system_prompt
+ assert pair.user is None
+
+ def test_structured_prompt_pair_with_system_and_user(self) -> None:
+ system_prompt = StructuredPromptTemplate(
+ sections=[PromptSection(name="System", content="You are a helpful assistant.")]
+ )
+ user_prompt = StructuredPromptTemplate(sections=[PromptSection(name="User", content="Help me with this task.")])
+ pair = StructuredPromptPair(system=system_prompt, user=user_prompt)
+ assert pair.system == system_prompt
+ assert pair.user == user_prompt
+
+
+class TestLLMConfig:
+ def test_llm_config_with_model_only(self) -> None:
+ config = LLMConfig(model="gpt-4o")
+ assert config.model == "gpt-4o"
+ assert config.params is None
+
+ def test_llm_config_with_params(self) -> None:
+ params = {"temperature": 0.7, "max_tokens": 100, "another_param": "value"}
+ config = LLMConfig(model="gpt-4o", params=params)
+ assert config.model == "gpt-4o"
+ assert config.params == params
+
+
+class TestPromptConfig:
+ def test_prompt_config_initialization(self) -> None:
+ prompt_pair = PromptPair(
+ system=PromptTemplate(template="You are a helpful assistant."),
+ user=PromptTemplate(template="Help me with {task}."),
+ )
+ metadata = {"version": "1.0", "author": "John Doe", "created_at": "2025-02-25"}
+
+ config = PromptConfig(kind="Prompt", prompt=prompt_pair, metadata=metadata, llm=LLMConfig(model="gpt-4o"))
+
+ assert config.kind == "Prompt"
+ assert config.prompt == prompt_pair
+ assert config.metadata == metadata
+ assert config.llm is not None
+ assert config.llm.model == "gpt-4o"
+
+ def test_prompt_config_with_structured_prompt(self) -> None:
+ system_prompt = StructuredPromptTemplate(
+ sections=[PromptSection(name="System", content="You are a helpful assistant.")]
+ )
+ prompt_pair = StructuredPromptPair(system=system_prompt)
+
+ config = PromptConfig(
+ kind="StructuredPrompt", prompt=prompt_pair, metadata={"version": "1.0"}, llm=LLMConfig(model="gpt-4o")
+ )
+
+ assert config.kind == "StructuredPrompt"
+ assert isinstance(config.prompt, StructuredPromptPair)
+ assert config.has_llm_config is True
+
+ def test_with_empty_metadata_and_llm(self) -> None:
+ prompt_pair = PromptPair(
+ system=PromptTemplate(template="You are a helpful assistant."),
+ user=PromptTemplate(template="Help me with {task}."),
+ )
+ config = PromptConfig(kind="Prompt", prompt=prompt_pair)
+
+ assert config.metadata is None
+ assert config.llm is None
+
+ def test_prompt_config_invalid_kind(self) -> None:
+ system_prompt = PromptTemplate(template="You are a helpful assistant.")
+ prompt_pair = PromptPair(system=system_prompt)
+
+ with pytest.raises(ValidationError):
+ PromptConfig(kind="InvalidKind", prompt=prompt_pair) # type: ignore
+
+ def test_mixed_prompt_pair_validation(self) -> None:
+ system_prompt = PromptTemplate(template="System template")
+ user_prompt = StructuredPromptTemplate(sections=[PromptSection(name="User", content="User content")])
+
+ with pytest.raises(ValidationError):
+ StructuredPromptPair(system=system_prompt, user=user_prompt) # type: ignore
+
+ def test_from_dict(self) -> None:
+ data = {
+ "kind": "Prompt",
+ "prompt": {
+ "system": {"template": "You are a helpful assistant."},
+ "user": {"template": "Help me with this task."},
+ },
+ "metadata": {"version": "0.0.1"},
+ "llm": {"model": "gpt-4o", "params": {"temperature": 0.7}},
+ }
+
+ config = PromptConfig(**data) # type: ignore
+
+ assert config.kind == "Prompt"
+ assert isinstance(config.prompt, PromptPair)
+ assert isinstance(config.prompt.system, PromptTemplate)
+ assert config.prompt.system.template == "You are a helpful assistant."
+ assert isinstance(config.prompt.user, PromptTemplate)
+ assert config.prompt.user.template == "Help me with this task."
+ assert config.metadata == {"version": "0.0.1"}
+ assert config.llm is not None
+ assert config.llm.model == "gpt-4o"
+ assert config.llm.params == {"temperature": 0.7}
+
+ def test_prompt_from_file(self) -> None:
+ config = PromptConfig.from_yaml(PromptPath.basic)
+
+ assert config.kind == "Prompt"
+ assert isinstance(config.prompt, PromptPair)
+ assert isinstance(config.prompt.system, PromptTemplate)
+ assert config.prompt.system.get_template() == "You are a helpful assistant."
+ assert isinstance(config.prompt.user, PromptTemplate)
+ assert config.prompt.user.get_template() == "Answer the following questions: {question}"
+ assert config.metadata == {"description": "This is a prompt doing abc\n"}
+ assert config.has_llm_config is True
+ assert config.llm is not None
+ assert config.llm.model == "gpt-4o-mini"
+ assert config.llm.params == {"temperature": 0.3}
+
+ def test_structured_prompt_from_file(self) -> None:
+ config = PromptConfig.from_yaml(PromptPath.structured)
+
+ assert config.kind == "StructuredPrompt"
+ assert isinstance(config.prompt, StructuredPromptPair)
+ assert isinstance(config.prompt.system, StructuredPromptTemplate)
+        assert config.prompt.system.get_template() == "\n".join(
+            [
+                "<instructions>",
+                "    You are a helpful assistant.",
+                "    <hints>",
+                "        Answer the question in a concise manner.",
+                "    </hints>",
+                "</instructions>",
+            ]
+        )
+ assert isinstance(config.prompt.user, StructuredPromptTemplate)
+ assert config.prompt.user.get_template() == "\n".join(
+ ["", " What's the capital of France?", ""]
+ )
+ assert config.metadata == {"description": "This is a prompt doing xyz\n"}
+ assert config.has_llm_config is True
+ assert config.llm is not None
+ assert config.llm.model == "claude-3-5-haiku-20241022"
+ assert config.llm.params == {"temperature": 0.2}