diff --git a/alphaswarm/core/llm/llm_function.py b/alphaswarm/core/llm/llm_function.py index 15452fa2..4f48eb8e 100644 --- a/alphaswarm/core/llm/llm_function.py +++ b/alphaswarm/core/llm/llm_function.py @@ -9,6 +9,7 @@ from litellm.types.utils import ModelResponse from pydantic import BaseModel +from ..prompt import PromptConfig from .message import Message litellm.modify_params = True # for calls with system message only for anthropic @@ -169,6 +170,7 @@ def __init__( user_prompt_template: Optional[str] = None, system_prompt_params: Optional[Dict[str, Any]] = None, max_retries: int = 3, + llm_params: Optional[Dict[str, Any]] = None, ) -> None: """Initialize an LLMFunctionTemplated instance. @@ -179,12 +181,15 @@ def __init__( user_prompt_template: Optional template for the user message system_prompt_params: Parameters for formatting the system prompt if any max_retries: Maximum number of retry attempts + llm_params: Additional keyword arguments to pass to the LLM client """ super().__init__(model_id=model_id, response_model=response_model, max_retries=max_retries) self.system_prompt_template = system_prompt_template self.system_prompt = self._format(system_prompt_template, system_prompt_params) self.user_prompt_template = user_prompt_template + self._llm_params = llm_params or {} + def execute_with_completion( self, user_prompt_params: Optional[Dict[str, Any]] = None, @@ -211,7 +216,7 @@ def execute_with_completion( user_prompt = self._format(self.user_prompt_template, user_prompt_params) messages.append(Message.user(user_prompt)) - return self._execute_with_completion(messages=messages, **kwargs) + return self._execute_with_completion(messages=messages, **self._llm_params, **kwargs) @classmethod def from_files( @@ -223,7 +228,7 @@ def from_files( system_prompt_params: Optional[Dict[str, Any]] = None, max_retries: int = 3, ) -> LLMFunctionTemplated[T_Response]: - """Create an instance from template files. + """Create an instance from template text files. Args: model_id: LiteLLM model ID to use @@ -250,6 +255,62 @@ def from_files( max_retries=max_retries, ) + @classmethod + def from_prompt_config( + cls, + response_model: Type[T_Response], + prompt_config: PromptConfig, + system_prompt_params: Optional[Dict[str, Any]] = None, + max_retries: int = 3, + ) -> LLMFunctionTemplated[T_Response]: + """Create an instance from prompt config object. + + Args: + response_model: Pydantic model class for structuring responses + prompt_config: PromptConfig object + system_prompt_params: Parameters for formatting the system prompt + max_retries: Maximum number of retry attempts + """ + system_prompt_template = prompt_config.prompt.system.get_template() + user_prompt_template = prompt_config.prompt.user.get_template() if prompt_config.prompt.user else None + + if prompt_config.llm is None: + raise ValueError("LLMConfig in PromptConfig is required to create an LLMFunction but was not set") + model_id = prompt_config.llm.model + + return cls( + model_id=model_id, + response_model=response_model, + system_prompt_template=system_prompt_template, + user_prompt_template=user_prompt_template, + system_prompt_params=system_prompt_params, + max_retries=max_retries, + llm_params=prompt_config.llm.params, + ) + + @classmethod + def from_prompt_config_file( + cls, + response_model: Type[T_Response], + prompt_config_path: str, + system_prompt_params: Optional[Dict[str, Any]] = None, + max_retries: int = 3, + ) -> LLMFunctionTemplated[T_Response]: + """Create an instance from prompt config file. + + Args: + response_model: Pydantic model class for structuring responses + prompt_config_path: Path to the prompt config yaml file + system_prompt_params: Parameters for formatting the system prompt + max_retries: Maximum number of retry attempts + """ + return cls.from_prompt_config( + response_model=response_model, + prompt_config=PromptConfig.from_yaml(prompt_config_path), + system_prompt_params=system_prompt_params, + max_retries=max_retries, + ) + @staticmethod def _format(template: str, params: Optional[Dict[str, Any]] = None) -> str: """Format the template string with the given optional parameters.""" diff --git a/alphaswarm/core/prompt/__init__.py b/alphaswarm/core/prompt/__init__.py new file mode 100644 index 00000000..6434ee7a --- /dev/null +++ b/alphaswarm/core/prompt/__init__.py @@ -0,0 +1,3 @@ +from .prompt import PromptConfig + +__all__ = ["PromptConfig"] diff --git a/alphaswarm/core/prompt/base.py b/alphaswarm/core/prompt/base.py new file mode 100644 index 00000000..e7c56ab7 --- /dev/null +++ b/alphaswarm/core/prompt/base.py @@ -0,0 +1,21 @@ +import abc +from typing import Annotated, Generic, Optional, TypeVar + +from pydantic import BaseModel, StringConstraints + +# helper class alias for str that's automatically stripped +StrippedStr = Annotated[str, StringConstraints(strip_whitespace=True)] + + +class PromptTemplateBase(BaseModel, abc.ABC): + @abc.abstractmethod + def get_template(self) -> str: + pass + + +T = TypeVar("T", bound="BaseModel") + + +class PromptPairBase(BaseModel, Generic[T]): + system: T + user: Optional[T] = None diff --git a/alphaswarm/core/prompt/prompt.py b/alphaswarm/core/prompt/prompt.py new file mode 100644 index 00000000..5d32e20d --- /dev/null +++ b/alphaswarm/core/prompt/prompt.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import Any, Dict, Literal, Optional, Union + +import yaml +from pydantic import BaseModel + +from .base import PromptPairBase, PromptTemplateBase, StrippedStr +from .structured import StructuredPromptPair + + +class PromptTemplate(PromptTemplateBase): + template: StrippedStr + + def get_template(self) -> str: + return self.template + + +class PromptPair(PromptPairBase[PromptTemplate]): + system: PromptTemplate + user: Optional[PromptTemplate] = None + + +class LLMConfig(BaseModel): + model: str + params: Optional[Dict[str, Any]] = None + + +class PromptConfig(BaseModel): + """ + Prompt configuration object. + Contains prompt pair, optional metadata, and optional LLM configuration. + If LLM configuration is specified, it could be used to generate an LLMFunction. + """ + + kind: Literal["Prompt", "StructuredPrompt"] + prompt: Union[PromptPair, StructuredPromptPair] + metadata: Optional[Dict[str, Any]] = None + llm: Optional[LLMConfig] = None + + @property + def has_llm_config(self) -> bool: + return self.llm is not None + + @classmethod + def from_yaml(cls, path: str) -> PromptConfig: + with open(path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + return cls(**data) diff --git a/alphaswarm/core/prompt/structured.py b/alphaswarm/core/prompt/structured.py new file mode 100644 index 00000000..76309192 --- /dev/null +++ b/alphaswarm/core/prompt/structured.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import abc +from typing import List, Mapping, Optional, Sequence, Type + +from pydantic import BaseModel, model_validator + +from .base import PromptPairBase, PromptTemplateBase, StrippedStr + + +class PromptSection(BaseModel): + name: str + content: Optional[StrippedStr] = None + sections: List[PromptSection] = [] + + +class PromptFormatterBase(abc.ABC): + def format(self, sections: Sequence[PromptSection]) -> str: + return "\n".join(self._format_section(section) for section in sections) + + @abc.abstractmethod + def _format_section(self, section: PromptSection) -> str: + pass + + +class StringPromptFormatter(PromptFormatterBase): + def __init__(self, section_prefix: str = "") -> None: + self.section_prefix = section_prefix + + def _format_section(self, section: PromptSection) -> str: + parts = [f"{self.section_prefix}{section.name}"] + if section.content: + parts.append(section.content) + parts.extend([self._format_section(sec) for sec in section.sections]) + return "\n".join(parts) + + +class MarkdownPromptFormatter(PromptFormatterBase): + def _format_section(self, section: PromptSection, indent: int = 1) -> str: + parts = ["", f"{'#' * indent} {section.name}", ""] + if section.content: + parts.append(section.content) + parts.extend([self._format_section(sec, indent + 1) for sec in section.sections]) + return "\n".join(parts).strip() + + +class XMLPromptFormatter(PromptFormatterBase): + INDENT_DIFF: str = " " + + def _format_section(self, section: PromptSection, indent: str = "") -> str: + name_snake_case = section.name.lower().replace(" ", "_") + parts = [f"{indent}<{name_snake_case}>"] + + if section.content: + content_lines = section.content.split("\n") + content = "\n".join([f"{indent}{self.INDENT_DIFF}{line}" for line in content_lines]) + parts.append(content) + + parts.extend([self._format_section(sec, indent + self.INDENT_DIFF) for sec in section.sections]) + parts.append(f"{indent}") + return "\n".join(parts) + + +FORMATTER_REGISTRY: Mapping[str, Type[PromptFormatterBase]] = { + "string": StringPromptFormatter, + "markdown": MarkdownPromptFormatter, + "xml": XMLPromptFormatter, +} + + +class StructuredPromptTemplate(PromptTemplateBase): + sections: List[PromptSection] + formatter: str = "string" + _formatter_obj: PromptFormatterBase = StringPromptFormatter() # default for mypy, will be overridden + + def get_template(self) -> str: + return self._formatter_obj.format(self.sections) + + def set_formatter(self, formatter: PromptFormatterBase) -> None: + self._formatter_obj = formatter + + @model_validator(mode="after") + def formatter_obj_validator(self) -> StructuredPromptTemplate: + formatter_obj = self.formatter_string_to_obj(self.formatter) + self.set_formatter(formatter_obj) + return self + + @staticmethod + def formatter_string_to_obj(formatter: str) -> PromptFormatterBase: + formatter = formatter.lower() + if formatter not in FORMATTER_REGISTRY: + raise ValueError( + f"Unknown formatter: `{formatter}`. Available formatters: {', '.join(FORMATTER_REGISTRY.keys())}" + ) + + formatter_cls = FORMATTER_REGISTRY[formatter] + return formatter_cls() + + +class StructuredPromptPair(PromptPairBase[StructuredPromptTemplate]): + system: StructuredPromptTemplate + user: Optional[StructuredPromptTemplate] = None diff --git a/tests/__init__.py b/tests/__init__.py index 4eb9068d..bde59cb5 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,14 @@ import os +from typing import Final +from enum import Enum + +DATA_PATH: Final[str] = os.path.join(os.path.dirname(__file__), "data") + + +class PromptPath(str, Enum): + basic = os.path.join(DATA_PATH, "prompts", "prompt.yaml") + structured = os.path.join(DATA_PATH, "prompts", "structured_prompt.yaml") def get_data_filename(filename: str) -> str: - return os.path.join(os.path.dirname(__file__), "data", filename) + return os.path.join(DATA_PATH, filename) diff --git a/tests/data/prompts/prompt.yaml b/tests/data/prompts/prompt.yaml new file mode 100644 index 00000000..eabd3803 --- /dev/null +++ b/tests/data/prompts/prompt.yaml @@ -0,0 +1,15 @@ +kind: Prompt +metadata: + description: | + This is a prompt doing abc +llm: + model: gpt-4o-mini + params: + temperature: 0.3 +prompt: + system: + template: | + You are a helpful assistant. + user: + template: | + Answer the following questions: {question} diff --git a/tests/data/prompts/structured_prompt.yaml b/tests/data/prompts/structured_prompt.yaml new file mode 100644 index 00000000..9e690dda --- /dev/null +++ b/tests/data/prompts/structured_prompt.yaml @@ -0,0 +1,25 @@ +kind: StructuredPrompt +metadata: + description: | + This is a prompt doing xyz +llm: + model: claude-3-5-haiku-20241022 + params: + temperature: 0.2 +prompt: + system: + sections: + - name: Instructions + content: | + You are a helpful assistant. + sections: + - name: Hints + content: | + Answer the question in a concise manner. + formatter: XML + user: + sections: + - name: Question + content: | + What's the capital of France? + formatter: XML diff --git a/tests/integration/core/llm/test_llm_function_from_files.py b/tests/integration/core/llm/test_llm_function_from_files.py index 8e019beb..f218f015 100644 --- a/tests/integration/core/llm/test_llm_function_from_files.py +++ b/tests/integration/core/llm/test_llm_function_from_files.py @@ -5,10 +5,16 @@ from pydantic import BaseModel, Field from alphaswarm.core.llm import LLMFunctionTemplated +from alphaswarm.core.prompt import PromptConfig +from tests import PromptPath dotenv.load_dotenv() +class Response(BaseModel): + answer: str = Field(..., description="The answer to the question") + + class SimpleResponse(BaseModel): reasoning: str = Field(..., description="Reasoning behind the response") number: int = Field(..., ge=1, le=10, description="The random number between 1 and 10.") @@ -55,3 +61,23 @@ def test_llm_function_from_user_file() -> None: result = llm_func.execute(user_prompt_params={"min_value": 3, "max_value": 8}) assert isinstance(result, SimpleResponse) assert 3 <= result.number <= 8 + + +def test_llm_function_from_prompt_config() -> None: + llm_func = LLMFunctionTemplated.from_prompt_config( + response_model=Response, + prompt_config=PromptConfig.from_yaml(PromptPath.basic), + ) + + result = llm_func.execute(user_prompt_params={"question": "What's the capital of France?"}) + assert "Paris" in result.answer + + +def test_llm_function_from_structured_prompt_config() -> None: + llm_func = LLMFunctionTemplated.from_prompt_config_file( + response_model=Response, + prompt_config_path=PromptPath.structured, + ) + + result = llm_func.execute() + assert "Paris" in result.answer diff --git a/tests/unit/core/llm/test_llm_function_from_files.py b/tests/unit/core/llm/test_llm_function_from_files.py index 856ada83..d3410eca 100644 --- a/tests/unit/core/llm/test_llm_function_from_files.py +++ b/tests/unit/core/llm/test_llm_function_from_files.py @@ -1,30 +1,74 @@ import tempfile -from typing import Any import pytest from pydantic import BaseModel from alphaswarm.core.llm import LLMFunctionTemplated +from alphaswarm.core.prompt import PromptConfig + +from tests import PromptPath class Response(BaseModel): test: str -def get_sample_llm_function_from_files(**kwargs: Any) -> LLMFunctionTemplated[Response]: - return LLMFunctionTemplated.from_files( - model_id="test", - response_model=Response, - **kwargs, - ) - - def test_execute_with_user_prompt_params_but_no_template_raises() -> None: with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", suffix=".txt", delete=True) as system_file: system_file.write("Sample system prompt") system_file.flush() - llm_func = get_sample_llm_function_from_files(system_prompt_path=system_file.name) + llm_func = LLMFunctionTemplated.from_files( + model_id="test", + response_model=Response, + system_prompt_path=system_file.name, + ) with pytest.raises(ValueError, match="User prompt params provided but no user prompt template exists"): llm_func.execute(user_prompt_params={"test": "value"}) + + +def test_from_prompt_config() -> None: + llm_func_v1 = LLMFunctionTemplated.from_prompt_config( + response_model=Response, + prompt_config=PromptConfig.from_yaml(PromptPath.basic), + ) + + llm_func_v2 = LLMFunctionTemplated.from_prompt_config_file( + response_model=Response, + prompt_config_path=PromptPath.basic, + ) + + assert llm_func_v1._model_id == llm_func_v2._model_id == "gpt-4o-mini" + assert llm_func_v1.system_prompt == llm_func_v2.system_prompt == "You are a helpful assistant." + expected_user_prompt = "Answer the following questions: {question}" + assert llm_func_v1.user_prompt_template == llm_func_v2.user_prompt_template == expected_user_prompt + + +def test_from_structured_prompt_config() -> None: + llm_func = LLMFunctionTemplated.from_prompt_config_file( + response_model=Response, + prompt_config_path=PromptPath.structured, + ) + + assert llm_func._model_id == "claude-3-5-haiku-20241022" + expected_system_prompt = "\n".join( + [ + "", + " You are a helpful assistant.", + " ", + " Answer the question in a concise manner.", + " ", + "", + ] + ) + assert llm_func.system_prompt == expected_system_prompt + + expected_user_prompt = "\n".join( + [ + "", + " What's the capital of France?", + "", + ] + ) + assert llm_func.user_prompt_template == expected_user_prompt diff --git a/tests/unit/core/prompt/__init__.py b/tests/unit/core/prompt/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/core/prompt/test_prompt.py b/tests/unit/core/prompt/test_prompt.py new file mode 100644 index 00000000..318498e1 --- /dev/null +++ b/tests/unit/core/prompt/test_prompt.py @@ -0,0 +1,329 @@ +import pytest +from pydantic import ValidationError + +from alphaswarm.core.prompt import PromptConfig +from alphaswarm.core.prompt.prompt import ( + PromptTemplate, + PromptPair, + LLMConfig, +) +from alphaswarm.core.prompt.structured import ( + PromptSection, + StructuredPromptTemplate, + StringPromptFormatter, + MarkdownPromptFormatter, + XMLPromptFormatter, + FORMATTER_REGISTRY, + StructuredPromptPair, +) +from tests import PromptPath + + +class TestPromptTemplate: + def test_get_template(self) -> None: + template = "Test template with {variable}" + prompt = PromptTemplate(template=template) + assert prompt.get_template() == template + + def test_whitespace_stripping(self) -> None: + template = " Template with whitespace " + prompt = PromptTemplate(template=template) + assert prompt.template == "Template with whitespace" + assert prompt.get_template() == "Template with whitespace" + + +class TestStructuredPromptTemplate: + def test_structured_prompt_template_with_string_formatter(self) -> None: + sections = [ + PromptSection(name="Introduction", content="This is an introduction"), + PromptSection(name="Instructions", content="Follow these instructions"), + ] + prompt = StructuredPromptTemplate(sections=sections, formatter="string") + expected = "\n".join(["Introduction", "This is an introduction", "Instructions", "Follow these instructions"]) + assert prompt.get_template() == expected + assert isinstance(prompt._formatter_obj, StringPromptFormatter) + + def test_structured_prompt_template_with_markdown_formatter(self) -> None: + sections = [ + PromptSection(name="Introduction", content="This is an introduction"), + PromptSection(name="Instructions", content="Follow these instructions"), + ] + prompt = StructuredPromptTemplate(sections=sections, formatter="markdown") + expected = "\n".join( + ["# Introduction", "", "This is an introduction", "# Instructions", "", "Follow these instructions"] + ) + assert prompt.get_template() == expected + assert isinstance(prompt._formatter_obj, MarkdownPromptFormatter) + + def test_structured_prompt_template_with_xml_formatter(self) -> None: + sections = [ + PromptSection(name="Introduction", content="This is an introduction"), + PromptSection(name="Instructions", content="Follow these instructions"), + ] + prompt = StructuredPromptTemplate(sections=sections, formatter="xml") + expected = "\n".join( + [ + "", + " This is an introduction", + "", + "", + " Follow these instructions", + "", + ] + ) + assert prompt.get_template() == expected + assert isinstance(prompt._formatter_obj, XMLPromptFormatter) + + def test_nested_sections(self) -> None: + sections = [ + PromptSection( + name="Main Section", + content="Main content", + sections=[PromptSection(name="Subsection", content="Subsection content")], + ) + ] + + prompt = StructuredPromptTemplate(sections=sections, formatter="string") + expected_string = "\n".join(["Main Section", "Main content", "Subsection", "Subsection content"]) + assert prompt.get_template() == expected_string + + prompt.set_formatter(MarkdownPromptFormatter()) + expected_md = "\n".join(["# Main Section", "", "Main content", "## Subsection", "", "Subsection content"]) + assert prompt.get_template() == expected_md + + prompt.set_formatter(XMLPromptFormatter()) + expected_xml = "\n".join( + [ + "", + " Main content", + " ", + " Subsection content", + " ", + "", + ] + ) + assert prompt.get_template() == expected_xml + + def test_invalid_formatter(self) -> None: + sections = [PromptSection(name="Test", content="Test content")] + + with pytest.raises(ValueError, match="Unknown formatter"): + StructuredPromptTemplate(sections=sections, formatter="invalid_formatter") + + def test_set_formatter_manually(self) -> None: + sections = [PromptSection(name="Test", content="Test content")] + prompt = StructuredPromptTemplate(sections=sections) + + prompt.set_formatter(XMLPromptFormatter()) + assert isinstance(prompt._formatter_obj, XMLPromptFormatter) + expected = "\n".join(["", " Test content", ""]) + assert prompt.get_template() == expected + + +class TestPromptFormatters: + def test_string_prompt_formatter_with_custom_prefix(self) -> None: + formatter = StringPromptFormatter(section_prefix=">> ") + sections = [ + PromptSection(name="Section1", content="Content1"), + PromptSection(name="Section2", content="Content2"), + ] + expected = "\n".join([">> Section1", "Content1", ">> Section2", "Content2"]) + assert formatter.format(sections) == expected + + def test_markdown_formatter_nested_headings(self) -> None: + formatter = MarkdownPromptFormatter() + sections = [ + PromptSection( + name="Level 1", + content="Content 1", + sections=[ + PromptSection( + name="Level 2", + content="Content 2", + sections=[PromptSection(name="Level 3", content="Content 3")], + ) + ], + ) + ] + expected = "\n".join( + ["# Level 1", "", "Content 1", "## Level 2", "", "Content 2", "### Level 3", "", "Content 3"] + ) + assert formatter.format(sections) == expected + + def test_xml_formatter_multiline_content(self) -> None: + formatter = XMLPromptFormatter() + sections = [PromptSection(name="Section", content="Line 1\nLine 2\nLine 3")] + expected = "\n".join(["
", " Line 1", " Line 2", " Line 3", "
"]) + assert formatter.format(sections) == expected + + def test_formatter_registry(self) -> None: + assert FORMATTER_REGISTRY["string"] == StringPromptFormatter + assert FORMATTER_REGISTRY["markdown"] == MarkdownPromptFormatter + assert FORMATTER_REGISTRY["xml"] == XMLPromptFormatter + + +class TestPromptPair: + def test_prompt_pair_with_system_only(self) -> None: + system_prompt = PromptTemplate(template="You are a helpful assistant.") + pair = PromptPair(system=system_prompt) + assert pair.system == system_prompt + assert pair.user is None + + def test_prompt_pair_with_system_and_user(self) -> None: + system_prompt = PromptTemplate(template="You are a helpful assistant.") + user_prompt = PromptTemplate(template="Help me with {task}.") + pair = PromptPair(system=system_prompt, user=user_prompt) + assert pair.system == system_prompt + assert pair.user == user_prompt + + +class TestStructuredPromptPair: + def test_structured_prompt_pair_system_only(self) -> None: + system_prompt = StructuredPromptTemplate( + sections=[PromptSection(name="System", content="You are a helpful assistant.")] + ) + pair = StructuredPromptPair(system=system_prompt) + assert pair.system == system_prompt + assert pair.user is None + + def test_structured_prompt_pair_with_system_and_user(self) -> None: + system_prompt = StructuredPromptTemplate( + sections=[PromptSection(name="System", content="You are a helpful assistant.")] + ) + user_prompt = StructuredPromptTemplate(sections=[PromptSection(name="User", content="Help me with this task.")]) + pair = StructuredPromptPair(system=system_prompt, user=user_prompt) + assert pair.system == system_prompt + assert pair.user == user_prompt + + +class TestLLMConfig: + def test_llm_config_with_model_only(self) -> None: + config = LLMConfig(model="gpt-4o") + assert config.model == "gpt-4o" + assert config.params is None + + def test_llm_config_with_params(self) -> None: + params = {"temperature": 0.7, "max_tokens": 100, "another_param": "value"} + config = LLMConfig(model="gpt-4o", params=params) + assert config.model == "gpt-4o" + assert config.params == params + + +class TestPromptConfig: + def test_prompt_config_initialization(self) -> None: + prompt_pair = PromptPair( + system=PromptTemplate(template="You are a helpful assistant."), + user=PromptTemplate(template="Help me with {task}."), + ) + metadata = {"version": "1.0", "author": "John Doe", "created_at": "2025-02-25"} + + config = PromptConfig(kind="Prompt", prompt=prompt_pair, metadata=metadata, llm=LLMConfig(model="gpt-4o")) + + assert config.kind == "Prompt" + assert config.prompt == prompt_pair + assert config.metadata == metadata + assert config.llm is not None + assert config.llm.model == "gpt-4o" + + def test_prompt_config_with_structured_prompt(self) -> None: + system_prompt = StructuredPromptTemplate( + sections=[PromptSection(name="System", content="You are a helpful assistant.")] + ) + prompt_pair = StructuredPromptPair(system=system_prompt) + + config = PromptConfig( + kind="StructuredPrompt", prompt=prompt_pair, metadata={"version": "1.0"}, llm=LLMConfig(model="gpt-4o") + ) + + assert config.kind == "StructuredPrompt" + assert isinstance(config.prompt, StructuredPromptPair) + assert config.has_llm_config is True + + def test_with_empty_metadata_and_llm(self) -> None: + prompt_pair = PromptPair( + system=PromptTemplate(template="You are a helpful assistant."), + user=PromptTemplate(template="Help me with {task}."), + ) + config = PromptConfig(kind="Prompt", prompt=prompt_pair) + + assert config.metadata is None + assert config.llm is None + + def test_prompt_config_invalid_kind(self) -> None: + system_prompt = PromptTemplate(template="You are a helpful assistant.") + prompt_pair = PromptPair(system=system_prompt) + + with pytest.raises(ValidationError): + PromptConfig(kind="InvalidKind", prompt=prompt_pair) # type: ignore + + def test_mixed_prompt_pair_validation(self) -> None: + system_prompt = PromptTemplate(template="System template") + user_prompt = StructuredPromptTemplate(sections=[PromptSection(name="User", content="User content")]) + + with pytest.raises(ValidationError): + StructuredPromptPair(system=system_prompt, user=user_prompt) # type: ignore + + def test_from_dict(self) -> None: + data = { + "kind": "Prompt", + "prompt": { + "system": {"template": "You are a helpful assistant."}, + "user": {"template": "Help me with this task."}, + }, + "metadata": {"version": "0.0.1"}, + "llm": {"model": "gpt-4o", "params": {"temperature": 0.7}}, + } + + config = PromptConfig(**data) # type: ignore + + assert config.kind == "Prompt" + assert isinstance(config.prompt, PromptPair) + assert isinstance(config.prompt.system, PromptTemplate) + assert config.prompt.system.template == "You are a helpful assistant." + assert isinstance(config.prompt.user, PromptTemplate) + assert config.prompt.user.template == "Help me with this task." + assert config.metadata == {"version": "0.0.1"} + assert config.llm is not None + assert config.llm.model == "gpt-4o" + assert config.llm.params == {"temperature": 0.7} + + def test_prompt_from_file(self) -> None: + config = PromptConfig.from_yaml(PromptPath.basic) + + assert config.kind == "Prompt" + assert isinstance(config.prompt, PromptPair) + assert isinstance(config.prompt.system, PromptTemplate) + assert config.prompt.system.get_template() == "You are a helpful assistant." + assert isinstance(config.prompt.user, PromptTemplate) + assert config.prompt.user.get_template() == "Answer the following questions: {question}" + assert config.metadata == {"description": "This is a prompt doing abc\n"} + assert config.has_llm_config is True + assert config.llm is not None + assert config.llm.model == "gpt-4o-mini" + assert config.llm.params == {"temperature": 0.3} + + def test_structured_prompt_from_file(self) -> None: + config = PromptConfig.from_yaml(PromptPath.structured) + + assert config.kind == "StructuredPrompt" + assert isinstance(config.prompt, StructuredPromptPair) + assert isinstance(config.prompt.system, StructuredPromptTemplate) + assert config.prompt.system.get_template() == "\n".join( + [ + "", + " You are a helpful assistant.", + " ", + " Answer the question in a concise manner.", + " ", + "", + ] + ) + assert isinstance(config.prompt.user, StructuredPromptTemplate) + assert config.prompt.user.get_template() == "\n".join( + ["", " What's the capital of France?", ""] + ) + assert config.metadata == {"description": "This is a prompt doing xyz\n"} + assert config.has_llm_config is True + assert config.llm is not None + assert config.llm.model == "claude-3-5-haiku-20241022" + assert config.llm.params == {"temperature": 0.2}