Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 63 additions & 2 deletions alphaswarm/core/llm/llm_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from litellm.types.utils import ModelResponse
from pydantic import BaseModel

from ..prompt import PromptConfig
from .message import Message

litellm.modify_params = True # for calls with system message only for anthropic
Expand Down Expand Up @@ -169,6 +170,7 @@ def __init__(
user_prompt_template: Optional[str] = None,
system_prompt_params: Optional[Dict[str, Any]] = None,
max_retries: int = 3,
llm_params: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize an LLMFunctionTemplated instance.

Expand All @@ -179,12 +181,15 @@ def __init__(
user_prompt_template: Optional template for the user message
system_prompt_params: Parameters for formatting the system prompt if any
max_retries: Maximum number of retry attempts
llm_params: Additional keyword arguments to pass to the LLM client
"""
super().__init__(model_id=model_id, response_model=response_model, max_retries=max_retries)
self.system_prompt_template = system_prompt_template
self.system_prompt = self._format(system_prompt_template, system_prompt_params)
self.user_prompt_template = user_prompt_template

self._llm_params = llm_params or {}

def execute_with_completion(
self,
user_prompt_params: Optional[Dict[str, Any]] = None,
Expand All @@ -211,7 +216,7 @@ def execute_with_completion(

user_prompt = self._format(self.user_prompt_template, user_prompt_params)
messages.append(Message.user(user_prompt))
return self._execute_with_completion(messages=messages, **kwargs)
return self._execute_with_completion(messages=messages, **self._llm_params, **kwargs)

@classmethod
def from_files(
Expand All @@ -223,7 +228,7 @@ def from_files(
system_prompt_params: Optional[Dict[str, Any]] = None,
max_retries: int = 3,
) -> LLMFunctionTemplated[T_Response]:
"""Create an instance from template files.
"""Create an instance from template text files.

Args:
model_id: LiteLLM model ID to use
Expand All @@ -250,6 +255,62 @@ def from_files(
max_retries=max_retries,
)

@classmethod
def from_prompt_config(
    cls,
    response_model: Type[T_Response],
    prompt_config: PromptConfig,
    system_prompt_params: Optional[Dict[str, Any]] = None,
    max_retries: int = 3,
) -> LLMFunctionTemplated[T_Response]:
    """Create an instance from a PromptConfig object.

    Args:
        response_model: Pydantic model class for structuring responses
        prompt_config: PromptConfig object
        system_prompt_params: Parameters for formatting the system prompt
        max_retries: Maximum number of retry attempts

    Raises:
        ValueError: if the config carries no LLM section (model id is required).
    """
    # The LLM section supplies the model id, so it is mandatory here.
    if prompt_config.llm is None:
        raise ValueError("LLMConfig in PromptConfig is required to create an LLMFunction but was not set")

    user_prompt = prompt_config.prompt.user
    return cls(
        model_id=prompt_config.llm.model,
        response_model=response_model,
        system_prompt_template=prompt_config.prompt.system.get_template(),
        user_prompt_template=user_prompt.get_template() if user_prompt is not None else None,
        system_prompt_params=system_prompt_params,
        max_retries=max_retries,
        llm_params=prompt_config.llm.params,
    )

@classmethod
def from_prompt_config_file(
    cls,
    response_model: Type[T_Response],
    prompt_config_path: str,
    system_prompt_params: Optional[Dict[str, Any]] = None,
    max_retries: int = 3,
) -> LLMFunctionTemplated[T_Response]:
    """Create an instance from a prompt config yaml file.

    Args:
        response_model: Pydantic model class for structuring responses
        prompt_config_path: Path to the prompt config yaml file
        system_prompt_params: Parameters for formatting the system prompt
        max_retries: Maximum number of retry attempts
    """
    # Parse the yaml into a PromptConfig, then delegate to from_prompt_config.
    config = PromptConfig.from_yaml(prompt_config_path)
    return cls.from_prompt_config(
        response_model=response_model,
        prompt_config=config,
        system_prompt_params=system_prompt_params,
        max_retries=max_retries,
    )

@staticmethod
def _format(template: str, params: Optional[Dict[str, Any]] = None) -> str:
"""Format the template string with the given optional parameters."""
Expand Down
3 changes: 3 additions & 0 deletions alphaswarm/core/prompt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Public API of the prompt package."""

from .prompt import PromptConfig

__all__ = ["PromptConfig"]
21 changes: 21 additions & 0 deletions alphaswarm/core/prompt/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import abc
from typing import Annotated, Generic, Optional, TypeVar

from pydantic import BaseModel, StringConstraints

# helper class alias for str that's automatically stripped
StrippedStr = Annotated[str, StringConstraints(strip_whitespace=True)]


class PromptTemplateBase(BaseModel, abc.ABC):
    """Base class for prompt templates; subclasses render themselves to a template string."""

    @abc.abstractmethod
    def get_template(self) -> str:
        """Return the prompt template as a single string."""
        pass


T = TypeVar("T", bound="BaseModel")


class PromptPairBase(BaseModel, Generic[T]):
    """System/user prompt pair; the user prompt is optional."""

    system: T
    user: Optional[T] = None
49 changes: 49 additions & 0 deletions alphaswarm/core/prompt/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

from typing import Any, Dict, Literal, Optional, Union

import yaml
from pydantic import BaseModel

from .base import PromptPairBase, PromptTemplateBase, StrippedStr
from .structured import StructuredPromptPair


class PromptTemplate(PromptTemplateBase):
    """Plain-text prompt template stored verbatim (surrounding whitespace stripped)."""

    template: StrippedStr

    def get_template(self) -> str:
        """Return the stored template string."""
        return self.template


class PromptPair(PromptPairBase[PromptTemplate]):
    """System/user pair of plain-text prompt templates."""

    system: PromptTemplate
    user: Optional[PromptTemplate] = None


class LLMConfig(BaseModel):
    """LLM settings attached to a prompt config."""

    # LiteLLM model id (used as model_id when building an LLMFunction).
    model: str
    # Extra keyword arguments forwarded to the LLM client (e.g. temperature).
    params: Optional[Dict[str, Any]] = None


class PromptConfig(BaseModel):
    """
    Prompt configuration object.
    Contains prompt pair, optional metadata, and optional LLM configuration.
    If LLM configuration is specified, it could be used to generate an LLMFunction.
    """

    # Discriminates plain-text prompts from structured (section-based) ones.
    kind: Literal["Prompt", "StructuredPrompt"]
    prompt: Union[PromptPair, StructuredPromptPair]
    metadata: Optional[Dict[str, Any]] = None
    llm: Optional[LLMConfig] = None

    @property
    def has_llm_config(self) -> bool:
        """Whether an LLM configuration is attached to this prompt config."""
        return self.llm is not None

    @classmethod
    def from_yaml(cls, path: str) -> PromptConfig:
        """Load and validate a PromptConfig from a yaml file.

        Args:
            path: Path to the yaml file.

        Raises:
            pydantic.ValidationError: if the yaml content does not match the schema,
                including an empty file (yaml.safe_load returns None for it).
        """
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # model_validate produces a clear ValidationError for non-mapping content
        # (e.g. an empty yaml file -> None), where cls(**data) would raise TypeError.
        return cls.model_validate(data)
102 changes: 102 additions & 0 deletions alphaswarm/core/prompt/structured.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from __future__ import annotations

import abc
from typing import List, Mapping, Optional, Sequence, Type

from pydantic import BaseModel, model_validator

from .base import PromptPairBase, PromptTemplateBase, StrippedStr


class PromptSection(BaseModel):
    """A named, possibly nested, section of a structured prompt."""

    name: str
    # Section body text; whitespace-stripped, may be omitted for container sections.
    content: Optional[StrippedStr] = None
    # Nested subsections. NOTE: pydantic copies field defaults per instance, so the
    # mutable [] default is safe here (unlike in a plain Python class).
    sections: List[PromptSection] = []


class PromptFormatterBase(abc.ABC):
    """Base class for rendering a sequence of PromptSections into prompt text."""

    def format(self, sections: Sequence[PromptSection]) -> str:
        """Render all top-level sections and join them with newlines."""
        return "\n".join(self._format_section(section) for section in sections)

    @abc.abstractmethod
    def _format_section(self, section: PromptSection) -> str:
        """Render a single section (and, recursively, its subsections)."""
        pass


class StringPromptFormatter(PromptFormatterBase):
    """Renders sections as plain text, prefixing every section name with a fixed string."""

    def __init__(self, section_prefix: str = "") -> None:
        self.section_prefix = section_prefix

    def _format_section(self, section: PromptSection) -> str:
        # Name line first, then content (if any), then each subsection in order.
        lines = [f"{self.section_prefix}{section.name}"]
        if section.content:
            lines.append(section.content)
        for child in section.sections:
            lines.append(self._format_section(child))
        return "\n".join(lines)


class MarkdownPromptFormatter(PromptFormatterBase):
    """Renders sections as markdown headings, one heading level deeper per nesting level."""

    def _format_section(self, section: PromptSection, indent: int = 1) -> str:
        heading = "#" * indent + f" {section.name}"
        # Blank lines around the heading; outermost result is stripped below.
        lines = ["", heading, ""]
        if section.content:
            lines.append(section.content)
        for child in section.sections:
            lines.append(self._format_section(child, indent + 1))
        return "\n".join(lines).strip()


class XMLPromptFormatter(PromptFormatterBase):
    """Renders sections as XML-style tags; section names become snake_case tag names."""

    # Indentation added for each nesting level.
    INDENT_DIFF: str = "    "

    def _format_section(self, section: PromptSection, indent: str = "") -> str:
        # Tag derived from the section name, e.g. "My Section" -> <my_section>.
        name_snake_case = section.name.lower().replace(" ", "_")
        parts = [f"{indent}<{name_snake_case}>"]

        if section.content:
            # Indent every content line one level deeper than the enclosing tag.
            content_lines = section.content.split("\n")
            content = "\n".join([f"{indent}{self.INDENT_DIFF}{line}" for line in content_lines])
            parts.append(content)

        # Recurse into subsections with the accumulated indent.
        parts.extend([self._format_section(sec, indent + self.INDENT_DIFF) for sec in section.sections])
        parts.append(f"{indent}</{name_snake_case}>")
        return "\n".join(parts)


# Maps formatter names (as written in prompt yaml files, lowercased) to formatter classes.
FORMATTER_REGISTRY: Mapping[str, Type[PromptFormatterBase]] = {
    "string": StringPromptFormatter,
    "markdown": MarkdownPromptFormatter,
    "xml": XMLPromptFormatter,
}


class StructuredPromptTemplate(PromptTemplateBase):
    """Prompt template built from nested sections, rendered by a pluggable formatter."""

    sections: List[PromptSection]
    # Formatter name; must be a key of FORMATTER_REGISTRY (matched case-insensitively).
    formatter: str = "string"
    # Private attribute resolved from `formatter` by the validator below.
    _formatter_obj: PromptFormatterBase = StringPromptFormatter()  # default for mypy, will be overridden

    def get_template(self) -> str:
        """Render all sections with the configured formatter."""
        return self._formatter_obj.format(self.sections)

    def set_formatter(self, formatter: PromptFormatterBase) -> None:
        """Override the formatter instance used for rendering."""
        self._formatter_obj = formatter

    @model_validator(mode="after")
    def formatter_obj_validator(self) -> StructuredPromptTemplate:
        # Resolve the formatter name into an instance right after model validation.
        formatter_obj = self.formatter_string_to_obj(self.formatter)
        self.set_formatter(formatter_obj)
        return self

    @staticmethod
    def formatter_string_to_obj(formatter: str) -> PromptFormatterBase:
        """Instantiate the formatter registered under `formatter` (case-insensitive).

        Raises:
            ValueError: if the name is not present in FORMATTER_REGISTRY.
        """
        formatter = formatter.lower()
        if formatter not in FORMATTER_REGISTRY:
            raise ValueError(
                f"Unknown formatter: `{formatter}`. Available formatters: {', '.join(FORMATTER_REGISTRY.keys())}"
            )

        formatter_cls = FORMATTER_REGISTRY[formatter]
        return formatter_cls()


class StructuredPromptPair(PromptPairBase[StructuredPromptTemplate]):
    """System/user pair of structured (section-based) prompt templates."""

    system: StructuredPromptTemplate
    user: Optional[StructuredPromptTemplate] = None
11 changes: 10 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
import os
from typing import Final
from enum import Enum

DATA_PATH: Final[str] = os.path.join(os.path.dirname(__file__), "data")


class PromptPath(str, Enum):
    """Paths to prompt-config yaml fixtures; str subclass so members are directly usable as paths."""

    basic = os.path.join(DATA_PATH, "prompts", "prompt.yaml")
    structured = os.path.join(DATA_PATH, "prompts", "structured_prompt.yaml")


def get_data_filename(filename: str) -> str:
    """Return the absolute path of `filename` inside the tests data directory."""
    # The diff rendering left both the old and the new return statements in the text;
    # the second was unreachable. Keep the DATA_PATH-based version (the PR's intent) —
    # DATA_PATH already points at the same "data" directory next to this file.
    return os.path.join(DATA_PATH, filename)
15 changes: 15 additions & 0 deletions tests/data/prompts/prompt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Test fixture: plain-text prompt config (kind: Prompt) with an attached LLM section.
kind: Prompt
metadata:
  description: |
    This is a prompt doing abc
llm:
  model: gpt-4o-mini
  params:
    temperature: 0.3
prompt:
  system:
    template: |
      You are a helpful assistant.
  user:
    template: |
      Answer the following questions: {question}
25 changes: 25 additions & 0 deletions tests/data/prompts/structured_prompt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Test fixture: structured (section-based) prompt config using the XML formatter.
kind: StructuredPrompt
metadata:
  description: |
    This is a prompt doing xyz
llm:
  model: claude-3-5-haiku-20241022
  params:
    temperature: 0.2
prompt:
  system:
    sections:
      - name: Instructions
        content: |
          You are a helpful assistant.
        sections:
          - name: Hints
            content: |
              Answer the question in a concise manner.
    formatter: XML
  user:
    sections:
      - name: Question
        content: |
          What's the capital of France?
    formatter: XML
26 changes: 26 additions & 0 deletions tests/integration/core/llm/test_llm_function_from_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@
from pydantic import BaseModel, Field

from alphaswarm.core.llm import LLMFunctionTemplated
from alphaswarm.core.prompt import PromptConfig
from tests import PromptPath

dotenv.load_dotenv()


class Response(BaseModel):
    """Structured LLM answer used by the prompt-config integration tests."""

    answer: str = Field(..., description="The answer to the question")

class SimpleResponse(BaseModel):
reasoning: str = Field(..., description="Reasoning behind the response")
number: int = Field(..., ge=1, le=10, description="The random number between 1 and 10.")
Expand Down Expand Up @@ -55,3 +61,23 @@ def test_llm_function_from_user_file() -> None:
result = llm_func.execute(user_prompt_params={"min_value": 3, "max_value": 8})
assert isinstance(result, SimpleResponse)
assert 3 <= result.number <= 8


def test_llm_function_from_prompt_config() -> None:
    """End-to-end: build an LLMFunction from a parsed PromptConfig and run it."""
    config = PromptConfig.from_yaml(PromptPath.basic)
    llm_func = LLMFunctionTemplated.from_prompt_config(response_model=Response, prompt_config=config)

    result = llm_func.execute(user_prompt_params={"question": "What's the capital of France?"})
    assert "Paris" in result.answer


def test_llm_function_from_structured_prompt_config() -> None:
    """End-to-end: build an LLMFunction straight from a structured prompt yaml file."""
    llm_func = LLMFunctionTemplated.from_prompt_config_file(
        response_model=Response,
        prompt_config_path=PromptPath.structured,
    )

    answer = llm_func.execute().answer
    assert "Paris" in answer
Loading