From 8a6c49cf570b49f862f649a9d3d4d03e7d16257e Mon Sep 17 00:00:00 2001 From: Ethan Date: Thu, 13 Feb 2025 20:45:44 +0000 Subject: [PATCH] forecasting agent example and tests --- .../forecasting/price_forecasting_tool.py | 90 +++++++++++++++++++ .../price_forecasting_system_prompt.md | 17 ++++ .../prompts/price_forecasting_user_prompt.md | 12 +++ examples/agents/price_forecaster.py | 45 ++++++++++ .../integration/tools/forecasting/__init__.py | 0 .../test_price_forecasting_tool.py | 45 ++++++++++ 6 files changed, 209 insertions(+) create mode 100644 alphaswarm/tools/forecasting/price_forecasting_tool.py create mode 100644 alphaswarm/tools/forecasting/prompts/price_forecasting_system_prompt.md create mode 100644 alphaswarm/tools/forecasting/prompts/price_forecasting_user_prompt.md create mode 100644 examples/agents/price_forecaster.py create mode 100644 tests/integration/tools/forecasting/__init__.py create mode 100644 tests/integration/tools/forecasting/test_price_forecasting_tool.py diff --git a/alphaswarm/tools/forecasting/price_forecasting_tool.py b/alphaswarm/tools/forecasting/price_forecasting_tool.py new file mode 100644 index 00000000..cbb8b4f8 --- /dev/null +++ b/alphaswarm/tools/forecasting/price_forecasting_tool.py @@ -0,0 +1,90 @@ +import os +from datetime import datetime +from decimal import Decimal +from typing import Any, List, Optional + +from alphaswarm.config import BASE_PATH +from alphaswarm.core.llm.llm_function import LLMFunctionFromPromptFiles +from alphaswarm.services.alchemy import HistoricalPriceBySymbol +from pydantic import BaseModel, Field +from smolagents import Tool + + +class PriceForecast(BaseModel): + timestamp: datetime = Field(description="The timestamp of the forecast") + price: Decimal = Field(description="The forecasted median price of the token") + lower_confidence_bound: Decimal = Field(description="The lower confidence bound of the forecast") + upper_confidence_bound: Decimal = Field(description="The upper confidence bound of the forecast") + + +class PriceForecastResponse(BaseModel): + reasoning: str = Field(description="The reasoning behind the forecast") + forecast: List[PriceForecast] = Field(description="The forecasted prices of the token") + + +class PriceForecastingTool(Tool): + name = "PriceForecastingTool" + description = """Forecast the price of a token based on historical price data and supporting context retrieved using other tools. + + Returns a `PriceForecastResponse` object. + + The `PriceForecastResponse` object has the following fields: + - reasoning: The reasoning behind the forecast + - historical_price_data: HistoricalPriceBySymbol object passed as input to the tool + - forecast: A list of `PriceForecast` objects, each containing a timestamp, a forecasted price, a lower confidence bound, and an upper confidence bound + + A `PriceForecast` object has the following fields: + - timestamp: The timestamp of the forecast + - price: The forecasted median price of the token + - lower_confidence_bound: The lower confidence bound of the forecast + - upper_confidence_bound: The upper confidence bound of the forecast + """ + inputs = { + "historical_price_data": { + "type": "object", + "description": "Historical price data for the token; output of AlchemyPriceHistoryBySymbol tool", + }, + "forecast_horizon": { + "type": "string", + "description": "Instructions for the forecast horizon", + }, + "supporting_context": { + "type": "object", + "description": """A list of strings, each representing an element of context to support the forecast. + Each element should include a source and a timeframe, e.g.: '...details... [Source: Web Search, Timeframe: last 2 days]'""", + "nullable": True, + }, + } + output_type = "object" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + # Init the LLMFunction + self._llm_function = LLMFunctionFromPromptFiles( + model_id="anthropic/claude-3-5-sonnet-20241022", + response_model=PriceForecastResponse, + system_prompt_path=os.path.join( + BASE_PATH, "alphaswarm", "tools", "forecasting", "prompts", "price_forecasting_system_prompt.md" + ), + user_prompt_path=os.path.join( + BASE_PATH, "alphaswarm", "tools", "forecasting", "prompts", "price_forecasting_user_prompt.md" + ), + ) + + def forward( + self, + historical_price_data: HistoricalPriceBySymbol, + forecast_horizon: str, + supporting_context: Optional[List[str]] = None, + ) -> PriceForecastResponse: + response: PriceForecastResponse = self._llm_function.execute( + user_prompt_params={ + "supporting_context": ( + supporting_context if supporting_context is not None else "No additional context provided" + ), + "historical_price_data": str(historical_price_data), + "forecast_horizon": forecast_horizon, + } + ) + return response diff --git a/alphaswarm/tools/forecasting/prompts/price_forecasting_system_prompt.md b/alphaswarm/tools/forecasting/prompts/price_forecasting_system_prompt.md new file mode 100644 index 00000000..58217292 --- /dev/null +++ b/alphaswarm/tools/forecasting/prompts/price_forecasting_system_prompt.md @@ -0,0 +1,17 @@ +You are a specialized forecasting agent. +Your role is to analyze historical price data and supporting context to make token price predictions. + +You will be given a set of historical price data about one or more tokens, and a forecast horizon. + +You may optionally be given additional supporting context about the token, market, or other relevant information. +Make sure to factor in any background knowledge, satisfy any constraints, and respect any scenarios. + +Your output must include: +- Your reasoning about the forecast +- Your predictions for the prices at the forecast horizon +- Each prediction must include a timestamp and a price with lower and upper confidence bounds + +For the first forecast data point, use the last timestamp in the historical data so there is no gap between the historical data and the forecast (keep lower and upper confidence bounds the same as the last historical data point). + +Your reasoning should justify the direction, magnitude, and confidence bounds of the forecast. +If you are not confident in your ability to make an accurate prediction, your forecast, including the confidence bounds, should reflect that. \ No newline at end of file diff --git a/alphaswarm/tools/forecasting/prompts/price_forecasting_user_prompt.md b/alphaswarm/tools/forecasting/prompts/price_forecasting_user_prompt.md new file mode 100644 index 00000000..3698d378 --- /dev/null +++ b/alphaswarm/tools/forecasting/prompts/price_forecasting_user_prompt.md @@ -0,0 +1,12 @@ +# Inputs + +## Supporting Context +{supporting_context} + +## Historical Price Data +{historical_price_data} + +## Forecast Horizon +{forecast_horizon} + +Now please predict the value at the forecast horizon with your reasoning. \ No newline at end of file diff --git a/examples/agents/price_forecaster.py b/examples/agents/price_forecaster.py new file mode 100644 index 00000000..3f81d752 --- /dev/null +++ b/examples/agents/price_forecaster.py @@ -0,0 +1,45 @@ +import asyncio +from typing import List + +import dotenv +from alphaswarm.agent.agent import AlphaSwarmAgent +from alphaswarm.agent.clients import TerminalClient +from alphaswarm.tools.alchemy import AlchemyPriceHistoryBySymbol +from alphaswarm.tools.cookie.cookie_metrics import CookieMetricsBySymbol, CookieMetricsPaged +from alphaswarm.tools.forecasting.price_forecasting_tool import PriceForecastingTool +from smolagents import Tool + + +class ForecastingAgent(AlphaSwarmAgent): + def __init__(self) -> None: + tools: List[Tool] = [ + AlchemyPriceHistoryBySymbol(), + CookieMetricsBySymbol(), + CookieMetricsPaged(), + PriceForecastingTool(), + ] + + hints = """P.S. Here are some hints to help you succeed: + - Use the `AlchemyPriceHistoryBySymbol` tool to get the historical price data for the token + - Use the `CookieMetricsBySymbol` tool to get metrics about the subject token + - Use the `CookieMetricsPaged` tool to get a broader market overview of related AI agent tokens + - Use the `PriceForecastingTool` once you have gathered the necessary data to produce a forecast + - Please respond with the output of the `PriceForecastingTool` directly -- we don't need to reformat it. + """ + + super().__init__(tools=tools, model_id="anthropic/claude-3-5-sonnet-20241022", hints=hints) + + +async def main() -> None: + dotenv.load_dotenv() + + agent = ForecastingAgent() + + terminal = TerminalClient("AlphaSwarm terminal", agent) + await asyncio.gather( + terminal.start(), + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/integration/tools/forecasting/__init__.py b/tests/integration/tools/forecasting/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/tools/forecasting/test_price_forecasting_tool.py b/tests/integration/tools/forecasting/test_price_forecasting_tool.py new file mode 100644 index 00000000..8ade1ec7 --- /dev/null +++ b/tests/integration/tools/forecasting/test_price_forecasting_tool.py @@ -0,0 +1,45 @@ +from datetime import datetime, timedelta, timezone +from decimal import Decimal + +import pytest + +from alphaswarm.services.alchemy import AlchemyClient +from alphaswarm.tools.forecasting.price_forecasting_tool import PriceForecastingTool + + +@pytest.fixture +def price_forecasting_tool() -> PriceForecastingTool: + return PriceForecastingTool() + + +def test_price_forecasting_tool(price_forecasting_tool: PriceForecastingTool, alchemy_client: AlchemyClient) -> None: + # Get historical price data for USDC + end = datetime.now(timezone.utc) + start = end - timedelta(days=7) # Get a week of historical data + historical_data = alchemy_client.get_historical_prices_by_symbol( + symbol="USDC", start_time=start, end_time=end, interval="1h" + ) + + # Call the forecasting tool + forecast_horizon = "24 hours" + supporting_context = ["USDC has maintained strong stability near $1 [Source: Market Data, Timeframe: last 7 days]"] + + result = price_forecasting_tool.forward( + historical_price_data=historical_data, forecast_horizon=forecast_horizon, supporting_context=supporting_context + ) + + # Verify response structure and basic expectations + assert result is not None + assert result.reasoning is not None + assert len(result.reasoning) > 0 + assert len(result.forecast) > 0 + + # Verify forecast data + for forecast in result.forecast: + assert isinstance(forecast.timestamp, datetime) + assert isinstance(forecast.price, Decimal) + assert isinstance(forecast.lower_confidence_bound, Decimal) + assert isinstance(forecast.upper_confidence_bound, Decimal) + assert forecast.lower_confidence_bound <= forecast.price <= forecast.upper_confidence_bound + # Since we're forecasting USDC which is a stablecoin, expect values close to 1 + assert Decimal("0.9") <= forecast.price <= Decimal("1.1")