diff --git a/.gitignore b/.gitignore
index 4f7a5b048..99735c7c7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,6 +19,9 @@ paper-imgs/
 testdata/enron-tiny.csv
 testdata/*/
 testdata/*.tar.gz
+tests/pytest/data/generator_messages/
+scripts/provider_stats/
+scripts/litellm_stats/
 
 # python artifacts
 *.egg-info
@@ -53,8 +56,14 @@ testdata/enron-eval/*.txt
 pyrightconfig.json
 myenv/
+pz-env/
 
 # abacus-research data
 abacus-research/cuad-data/*
 abacus-research/opt-profiling-data/*
 abacus-research/parse-answer-errors/*
+
+# stats
+scripts/litellm_stats/
+scripts/provider_stats/
+tests/pytest/data/generator_messages/
diff --git a/abacus-research/helper-scripts/mmqa-baseline.py b/abacus-research/helper-scripts/mmqa-baseline.py
index bd0af2a04..b74477740 100644
--- a/abacus-research/helper-scripts/mmqa-baseline.py
+++ b/abacus-research/helper-scripts/mmqa-baseline.py
@@ -7,7 +7,7 @@
 import numpy as np
 
 from openai import OpenAI
 
-from palimpzest.constants import MODEL_CARDS, Cardinality, Model
+from palimpzest.constants import Cardinality, Model
 from palimpzest.query.generators.generators import get_json_from_answer
 
@@ -109,8 +109,9 @@ def f1(preds: list | None, targets: list):
        completion = client.chat.completions.create(**payload)
 
        # compute total cost
-       usd_per_input_token = MODEL_CARDS[model_name]["usd_per_input_token"]
-       usd_per_output_token = MODEL_CARDS[model_name]["usd_per_output_token"]
+       model = Model(model_name)
+       usd_per_input_token = model.get_usd_per_input_token()
+       usd_per_output_token = model.get_usd_per_output_token()
        input_tokens = completion.usage.prompt_tokens
        output_tokens = completion.usage.completion_tokens
        total_cost += input_tokens * usd_per_input_token + output_tokens * usd_per_output_token
diff --git a/demos/caching-demo.py b/demos/caching-demo.py
new file mode 100644
index 000000000..6efb26bbe
--- /dev/null
+++ b/demos/caching-demo.py
@@ -0,0 +1,318 @@
+#!/usr/bin/env python3
+"""
+Realistic demo showcasing prompt caching capabilities in Palimpzest.
+
+This demo processes multiple employee travel requests against a comprehensive
+Corporate Travel Policy. The policy text (~2000 tokens) is included in the
+system prompt, creating a realistic scenario for prompt caching where a large
+static context is reused across multiple dynamic inputs.
+
+Workload:
+- Context: A lengthy 10-page Corporate Travel & Expense Policy.
+- Input: Short email requests from employees.
+- Task: Analyze each request for policy compliance, identifying violations and reimbursable amounts.
+
+Supported caching providers:
+- OpenAI (GPT-4o, GPT-4o-mini): Automatic prefix caching
+- Anthropic (Claude 4 Sonnet / Claude 4.5 Haiku): Explicit cache_control markers
+- Gemini: Implicit caching
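+
+Example invocation (flags defined below; requires the matching provider API key):
+    python demos/caching-demo.py --model gpt-4o-mini --num-records 10 --verbose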
+"""
+
+import argparse
+import os
+import time
+from typing import List
+
+from dotenv import load_dotenv
+
+import palimpzest as pz
+from palimpzest.constants import Model
+from palimpzest.core.lib.schemas import TextFile
+
+load_dotenv()
+
+# =============================================================================
+# MOCK DATA: CORPORATE TRAVEL POLICY (Static Context > 1024 tokens)
+# =============================================================================
+CORPORATE_TRAVEL_POLICY = """
+GLOBAL CORP TRAVEL & EXPENSE POLICY (v2024.1)
+
+SECTION 1: OVERVIEW AND PHILOSOPHY
+Global Corp expects employees to act responsibly and professionally when incurring and submitting costs.
+The company will reimburse employees for reasonable and necessary expenses incurred during approved business travel.
+This policy applies to all employees, contractors, and consultants.
+
+SECTION 2: AIR TRAVEL
+2.1 Booking Window: All domestic flights must be booked at least 14 days in advance. International flights must be booked 21 days in advance.
+2.2 Class of Service:
+   - Economy Class: Required for all domestic flights under 6 hours.
+   - Premium Economy: Allowed for domestic flights over 6 hours or international flights under 8 hours.
+   - Business Class: Allowed for international flights exceeding 8 hours duration.
+   - First Class: Strictly prohibited unless approved by the CEO.
+2.3 Ancillary Fees:
+   - Checked Bags: Up to two bags reimbursed for trips > 3 days. One bag for trips <= 3 days.
+   - Wi-Fi: Reimbursed only if business justification is provided (e.g., "urgent client deadline").
+   - Seat Selection: Fees > $50 require VP approval.
+
+SECTION 3: LODGING
+3.1 Hotel Caps (Nightly Rates excluding taxes):
+   - Tier 1 Cities (NY, London, Tokyo, SF, Zurich): $350 USD
+   - Tier 2 Cities (Chicago, Paris, Berlin, Austin): $250 USD
+   - All Other Locations: $175 USD
+3.2 Room Type: Standard single rooms only. Suites are prohibited.
+3.3 Laundry: Reasonable laundry expenses reimbursed for trips exceeding 5 consecutive nights.
+
+SECTION 4: MEALS AND ENTERTAINMENT
+4.1 Daily Meal Allowance (Per Diem):
+   - Tier 1 Cities: $100/day
+   - Tier 2 Cities: $75/day
+   - Others: $60/day
+4.2 Client Entertainment:
+   - Must include at least one current or prospective client.
+   - Cap is $150 per person (including employees).
+   - Names and affiliations of all attendees must be documented.
+4.3 Alcohol:
+   - Reimbursable only with dinner.
+   - Moderate consumption allowed (max 2 drinks per person).
+   - "Top Shelf" liquors prohibited.
+
+SECTION 5: GROUND TRANSPORTATION
+5.1 Ride Share/Taxi: Preferred mode for travel between airport and hotel.
+5.2 Car Rentals:
+   - Class: Intermediate/Mid-size or smaller.
+   - Insurance: Decline CDW/LDW (covered by corporate policy).
+   - Fuel: Pre-paid fuel options are prohibited; cars must be returned full.
+5.3 Rail: Economy/Standard class only. Acela Business Class permitted for Northeast Corridor travel.
+
+SECTION 6: MISCELLANEOUS
+6.1 Tipping:
+   - Meals: 15-20%
+   - Taxis: 10-15%
+   - Bellhop: $1-2 per bag
+6.2 Non-Reimbursable Items:
+   - Personal grooming/toiletries.
+   - Fines (parking, speeding).
+   - Airline club memberships.
+   - In-room movies.
+   - Lost luggage/property.
+
+SECTION 7: SUBMISSION PROCESS
+Expenses must be submitted within 30 days of trip completion. Receipts required for all expenses > $25.
+"""
+
+# =============================================================================
+# MOCK DATA: EMPLOYEE REQUESTS (Dynamic Inputs)
+# =============================================================================
+EMPLOYEE_REQUESTS = [
+    # Request 1: Compliant
+    """Subject: Trip to London
+    I booked a flight to London (8.5 hours) in Business Class for the client summit.
+    Hotel is $320/night. Meal expenses were about $90/day.
+    Receipts attached.""",
+    # Request 2: Violation (Booking window & First Class)
+    """Subject: Urgent NY Trip
+    I need to fly to New York tomorrow. Booked First Class because it was the only seat left.
+    Hotel is the Ritz at $500/night.
+    Also expensed $40 for in-flight Wi-Fi to finish the Q3 report.""",
+    # Request 3: Violation (Car Rental & Alcohol)
+    """Subject: Austin Conference
+    Rented a luxury SUV for the team in Austin.
+    Dinner with the team (no clients) came to $800 ($200/person) including 3 bottles of wine.
+    Hotel was $240/night.""",
+    # Request 4: Compliant (Tier 2 City)
+    """Subject: Berlin Site Visit
+    Flew Economy to Berlin. Hotel was $220/night.
+    Took a taxi from TXL ($45 + $5 tip).
+    Daily meals averaged $70.""",
+    # Request 5: Violation (Misc items)
+    """Subject: Tokyo Tech Symposium
+    Trip duration: 4 days.
+    Expensed:
+    - Flight (Premium Econ, 11 hours)
+    - Hotel ($340/night)
+    - Laundry service ($60)
+    - Forgotten toothbrush replacement ($15)
+    - Parking ticket ($50)
+    """,
+]
+
+# Output Schema
+OUTPUT_SCHEMA = [
+    {"name": "status", "type": str, "desc": "One of: 'COMPLIANT', 'PARTIAL_VIOLATION', 'MAJOR_VIOLATION'"},
+    {
+        "name": "violations",
+        "type": str,
+        "desc": "A list of specific policy violations found, referencing the specific section numbers (e.g., 'Violation of Section 2.2'). If compliant, return 'None'.",
+    },
+    {
+        "name": "reimbursable_summary",
+        "type": str,
+        "desc": "A concise summary of what should be reimbursed vs rejected based on the policy text.",
+    },
+    {
+        "name": "flag_for_review",
+        "type": bool,
+        "desc": "True if the request requires manual review by a manager (e.g. for high amounts or ambiguous justifications).",
+    },
+]
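+
+# Illustrative output record for a request like Request 2 above (hypothetical values;
+# the actual model output will vary):
+#   status="MAJOR_VIOLATION",
+#   violations="Violation of Section 2.1 (booking window), Section 2.2 (First Class), Section 3.1 (Tier 1 cap)",
+#   flag_for_review=True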
+
+TASK_DESC = f"""
+You are an AI auditor for Global Corp. Your job is to review employee travel expense descriptions against the Corporate Travel Policy.
+The full policy text is provided below.
+
+{CORPORATE_TRAVEL_POLICY}
+
+Analyze the input email and determine if the expenses adhere to the policy.
+"""
+
+
+class TravelRequestDataset(pz.IterDataset):
+    """Custom dataset that provides travel requests as text records."""
+
+    def __init__(self, requests: List[str]):
+        super().__init__(id="travel_requests", schema=TextFile)
+        self.requests = requests
+
+    def __len__(self):
+        return len(self.requests)
+
+    def __getitem__(self, idx: int):
+        return {
+            "filename": f"request_{idx + 1}.txt",
+            "contents": self.requests[idx],
+        }
+
+
+# Map CLI model names to palimpzest Model enum members
+MODEL_MAPPING = {
+    "gpt-4o": Model.GPT_4o,
+    "gpt-4o-mini": Model.GPT_4o_MINI,
+    "claude-4-0-sonnet": Model.CLAUDE_4_SONNET,
+    # "claude-3-7-sonnet": Model.CLAUDE_3_7_SONNET,  # deprecated model testing
+    "claude-4-5-haiku": Model.CLAUDE_4_5_HAIKU,
+    "gemini-2.5-flash": Model.GOOGLE_GEMINI_2_5_FLASH,
+    # "deepseek-v3": Model.DEEPSEEK_V3,
+}
+
+
+def get_model_from_string(model_str: str) -> Model:
+    if model_str.lower() in MODEL_MAPPING:
+        return MODEL_MAPPING[model_str.lower()]
+    for model in Model:
+        if model.value.lower() == model_str.lower():
+            return model
+    raise ValueError(f"Unknown model: {model_str}")
+
+
+def print_cache_stats(execution_stats):
+    """Print cache-related statistics from execution."""
+    print("\n" + "=" * 60)
+    print(" CACHE STATISTICS & COST ANALYSIS")
+    print("=" * 60)
+
+    # Token counts are now disjoint:
+    # - input_text_tokens: regular (non-cached) input tokens
+    # - cache_read_tokens: tokens read from cache (hits)
+    # - cache_creation_tokens: tokens written to cache
+    regular_input = execution_stats.input_text_tokens
+    cache_read = execution_stats.cache_read_tokens
+    cache_creation = execution_stats.cache_creation_tokens
+    total_output = execution_stats.output_text_tokens
+    total_embedding = execution_stats.embedding_input_tokens
+
+    # Logical total = regular + cache read + cache creation
+    logical_total_input = regular_input + cache_read + cache_creation
+
+    print(f"{'Metric':<35} | {'Count':<15}")
+    print("-" * 55)
+    print(f"{'Logical Total Input Tokens':<35} | {logical_total_input:,}")
+    print(f"{' - Regular Input (full rate)':<35} | {regular_input:,}")
+    print(f"{' - Cache Read (discounted)':<35} | {cache_read:,}")
+    print(f"{' - Cache Creation':<35} | {cache_creation:,}")
+    print("-" * 55)
+    print(f"{'Total Output Tokens':<35} | {total_output:,}")
+    if total_embedding > 0:
+        print(f"{'Total Embedding Input Tokens':<35} | {total_embedding:,}")
+    print("-" * 55)
+    print(f"{'Total Execution Cost':<35} | ${execution_stats.total_execution_cost:.6f}")
+
+    # Calculate and display cache hit rate
+    # Hit rate = cache_read / (regular_input + cache_read)
+    total_cacheable = regular_input + cache_read
+    if total_cacheable > 0:
+        hit_rate = (cache_read / total_cacheable) * 100
+        print(f"\nCache Hit Rate: {hit_rate:.1f}%")
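+
+# Worked example of the accounting above (illustrative numbers): if the requests
+# produce regular_input=1,500, cache_read=10,500, and cache_creation=2,100 tokens,
+# the logical total input is 14,100 tokens and the hit rate is
+# 10,500 / (1,500 + 10,500) = 87.5%.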
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Demo showcasing prompt caching in Palimpzest")
+    parser.add_argument("--model", type=str, default="gpt-4o-mini", help="Model to use")
+    parser.add_argument("--num-records", type=int, default=5, help="Number of requests to process")
+    parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
+    parser.add_argument("--profile", action="store_true", help="Save profiling data")
+
+    args = parser.parse_args()
+    model = get_model_from_string(args.model)
+
+    # Validate env vars (simplified for brevity)
+    if model.is_provider_openai() and not os.getenv("OPENAI_API_KEY"):
+        print("ERROR: OPENAI_API_KEY not set")
+        return
+    if model.is_provider_anthropic() and not os.getenv("ANTHROPIC_API_KEY"):
+        print("ERROR: ANTHROPIC_API_KEY not set")
+        return
+    if (model.is_provider_google_ai_studio() or model.is_provider_vertex_ai()) and not os.getenv("GOOGLE_API_KEY"):
+        print("ERROR: GOOGLE_API_KEY not set")
+        return
+
+    print("=" * 60)
+    print(" PZ CACHING DEMO: CORPORATE AUDIT")
+    print("=" * 60)
+    print(f"Model: {model.value}")
+    print(
+        f"Policy Context Size: ~{len(CORPORATE_TRAVEL_POLICY.split())} words (~{int(len(CORPORATE_TRAVEL_POLICY.split()) * 1.3)} tokens)"
+    )
+
+    # Repeat the request list if the user wants more records than we have mocks
+    base_requests = EMPLOYEE_REQUESTS
+    requests = []
+    while len(requests) < args.num_records:
+        requests.extend(base_requests)
+    requests = requests[: args.num_records]
+
+    print(f"Processing {len(requests)} travel requests...")
+
+    # Build Plan
+    dataset = TravelRequestDataset(requests)
+
+    # The 'desc' field incorporates the huge CORPORATE_TRAVEL_POLICY string.
+    # This ensures the System Prompt is large (>1024 tokens) and identical for all records.
+    plan = dataset.sem_map(OUTPUT_SCHEMA, desc=TASK_DESC)
+
+    config = pz.QueryProcessorConfig(
+        policy=pz.MaxQuality(),
+        verbose=args.verbose,
+        execution_strategy="sequential",  # sequential is often easier for debugging caching behavior initially
+        available_models=[model],
+    )
+
+    start_time = time.time()
+    result = plan.run(config)
+    end_time = time.time()
+
+    # Output Results
+    print("\n" + "=" * 60)
+    print(" AUDIT RESULTS")
+    print("=" * 60)
+    for i, record in enumerate(result.data_records):
+        print(f"\n[Request {i + 1}]")
+        print(f"Status: {record.status}")
+        print(f"Violations: {record.violations}")
+        print(f"Summary: {record.reimbursable_summary}")
+
+    print_cache_stats(result.execution_stats)
+    print(f"\nWall Clock Time: {end_time - start_time:.2f}s")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/demos/vllm-demo.py b/demos/vllm-demo.py
new file mode 100644
index 000000000..e03f75036
--- /dev/null
+++ b/demos/vllm-demo.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+"""
+Minimal demo for running a vLLM model with Palimpzest.
+
+Prerequisites:
+    1. Start a vLLM server serving a small model, e.g.:
+       vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000
+    2. Run this script:
+       python demos/vllm-demo.py \
+           --api-base http://localhost:8000/v1 \
+           --model-id openai/Qwen/Qwen2.5-1.5B-Instruct
+"""
+import argparse
+import os
+
+from pydantic import BaseModel, Field
+
+import palimpzest as pz
+
+
+class SentimentResult(BaseModel):
+    sentiment: str = Field(description="The sentiment of the text: positive, negative, or neutral")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Run a minimal vLLM demo")
+    parser.add_argument("--api-base", type=str, required=True, help="vLLM server base URL (e.g. http://localhost:8000/v1)")
+    parser.add_argument("--model-id", type=str, required=True, help="Model ID for litellm (e.g. openai/Qwen/Qwen2.5-1.5B-Instruct)")
+    parser.add_argument("--max-tokens", type=int, default=128, help="Max tokens for completion")
+    parser.add_argument("--verbose", action="store_true", default=False)
+    args = parser.parse_args()
+
+    # Create the vLLM model with api_base and kwargs on the Model instance
+    vllm_model = pz.Model(args.model_id, api_base=args.api_base, max_tokens=args.max_tokens)
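+    # Note: the "openai/" prefix in the model ID tells litellm to treat the vLLM
+    # server as a generic OpenAI-compatible endpoint at the given api_base.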
+
+    # Load the enron-tiny dataset
+    data_path = os.path.join(os.path.dirname(__file__), "..", "testdata", "enron-tiny")
+    dataset = pz.TextFileDataset(id="test-sentiment", path=data_path)
+    dataset = dataset.sem_map(SentimentResult, desc="Classify the sentiment of the text")
+
+    # Configure with vLLM model
+    config = pz.QueryProcessorConfig(
+        policy=pz.MaxQuality(),
+        available_models=[vllm_model],
+        execution_strategy="sequential",
+        optimizer_strategy="pareto",
+        verbose=args.verbose,
+    )
+
+    output = dataset.run(config)
+    for record in output:
+        print(record)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/evals/quest/eval.py b/evals/quest/eval.py
new file mode 100644
index 000000000..0d7a48339
--- /dev/null
+++ b/evals/quest/eval.py
@@ -0,0 +1,159 @@
+import argparse
+import copy
+import json
+import os
+import random
+import time
+
+import palimpzest as pz
+
+
+def prepare_docs_for_query(items: list, gt_docs: list) -> list:
+    items = copy.copy(items)
+    random.shuffle(items)
+    final_items = [doc for doc in items if doc["title"] in gt_docs]
+    while len(final_items) < 1000 and len(items) > 0:
+        item = items.pop(0)
+        if item not in final_items:
+            final_items.append(item)
+    return final_items
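+
+# prepare_docs_for_query keeps every ground-truth document and pads the pool with
+# shuffled distractors up to 1000 docs; e.g., with 4 ground-truth titles the query
+# runs over those 4 plus ~996 randomly drawn non-relevant documents.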
' + "Return True if the document helps answer the query, False otherwise.", + depends_on=["text"], + ).project(["title"]) + + config = pz.QueryProcessorConfig( + policy=pz.MaxQuality(), + execution_strategy="parallel", + progress=True, + ) + output = plan.run(config) + execution_stats = output.execution_stats + time_secs = execution_stats.total_execution_time if execution_stats else 0.0 + cost = execution_stats.total_execution_cost if execution_stats else 0.0 + return [record["title"] for record in output], time_secs, cost + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate Palimpzest on QUEST") + parser.add_argument( + "--domain", + type=str, + required=True, + choices=["films", "books"], + help="The domain to evaluate.", + ) + parser.add_argument( + "--queries", + type=str, + required=True, + help="Path to the file containing the queries (e.g. test.jsonl).", + ) + parser.add_argument( + "--documents", + type=str, + default="data/documents.jsonl", + help="Path to documents.jsonl (QUEST format: title, text per line).", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Limit number of queries to evaluate (for debugging).", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for document shuffling.", + ) + args = parser.parse_args() + + random.seed(args.seed) + + if not os.path.exists(args.documents): + raise FileNotFoundError( + f"Documents file not found: {args.documents}\n" + ) + with open(args.documents) as f: + documents = [json.loads(line) for line in f] + + queries = [] + with open(args.queries) as f: + for line in f: + d = json.loads(line) + if d["metadata"]["domain"] == args.domain: + queries.append(d) + + if args.limit: + queries = queries[: args.limit] + + results = [] + for i, query in enumerate(queries): + print(f"[{i + 1}/{len(queries)}] Executing query: {query['query']}") + pred_docs, cur_time, cur_cost = palimpzest_run_query(query, documents) + + gt_docs = query["docs"] + preds = set(pred_docs) + labels = set(gt_docs) + + tp = sum(1 for pred in preds if pred in labels) + fp = len(preds) - tp + fn = sum(1 for label in labels if label not in preds) + + precision = tp / (tp + fp) if (tp + fp) > 0 else 0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0 + f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + + result = { + "query": query["query"], + "predicted_docs": pred_docs, + "ground_truth_docs": gt_docs, + "precision": precision, + "recall": recall, + "f1_score": f1, + "time": cur_time, + "cost": cur_cost + } + results.append(result) + + ts = int(time.time()) + out_path = f"results_{args.domain}_{ts}.json" + with open(out_path, "w") as f: + json.dump(results, f, indent=4) + print(f"\nResults saved to {out_path}") + + n = len(results) + avg_precision = sum(r["precision"] for r in results) / n + avg_recall = sum(r["recall"] for r in results) / n + avg_f1 = sum(r["f1_score"] for r in results) / n + avg_time = sum(r["time"] for r in results) / n + avg_cost = sum(r["cost"] for r in results) / n + + print(f"Average Precision: {avg_precision:.4f}") + print(f"Average Recall: {avg_recall:.4f}") + print(f"Average F1 Score: {avg_f1:.4f}") + print(f"Average Time: {avg_time:.4f}s") + print(f"Average Cost: {avg_cost:.4f}$") + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 3ccd07de8..aa8c3cdbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "palimpzest" -version = "1.3.4" +version = "1.5.3" 
+
+        result = {
+            "query": query["query"],
+            "predicted_docs": pred_docs,
+            "ground_truth_docs": gt_docs,
+            "precision": precision,
+            "recall": recall,
+            "f1_score": f1,
+            "time": cur_time,
+            "cost": cur_cost,
+        }
+        results.append(result)
+
+    ts = int(time.time())
+    out_path = f"results_{args.domain}_{ts}.json"
+    with open(out_path, "w") as f:
+        json.dump(results, f, indent=4)
+    print(f"\nResults saved to {out_path}")
+
+    n = len(results)
+    avg_precision = sum(r["precision"] for r in results) / n
+    avg_recall = sum(r["recall"] for r in results) / n
+    avg_f1 = sum(r["f1_score"] for r in results) / n
+    avg_time = sum(r["time"] for r in results) / n
+    avg_cost = sum(r["cost"] for r in results) / n
+
+    print(f"Average Precision: {avg_precision:.4f}")
+    print(f"Average Recall: {avg_recall:.4f}")
+    print(f"Average F1 Score: {avg_f1:.4f}")
+    print(f"Average Time: {avg_time:.4f}s")
+    print(f"Average Cost: ${avg_cost:.4f}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
index 3ccd07de8..aa8c3cdbd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "palimpzest"
-version = "1.3.4"
+version = "1.5.3"
 description = "Palimpzest is a system which enables anyone to process AI-powered analytical queries simply by defining them in a declarative language"
 readme = "README.md"
 requires-python = ">=3.12"
@@ -9,14 +9,15 @@ authors = [
     {name="MIT DSG Semantic Management Lab", email="michjc@csail.mit.edu"},
 ]
 dependencies = [
-    "anthropic>=0.46.0",
+    "anthropic>=0.79.0",
     "beautifulsoup4>=4.13.4",
     "chromadb>=1.0.15",
     "colorama>=0.4.6",
     "datasets>=4.0.0",
-    "fastapi~=0.115.0",
+    "fastapi>=0.115.0",
+    "google-genai>=1.0.0",
     "gradio>=5.26.0",
-    "litellm>=1.76.1",
+    "litellm>=1.81.11, <1.82.7",
     "numpy==2.0.2",
     "openai>=1.0",
     "pandas>=2.1.1",
@@ -63,8 +64,7 @@ where = ["src"]
 namespaces = false
 
 [tool.setuptools.package-data]
-"*" = ["*.txt", "*.rst", "*.md"]
-
+"*" = ["*.txt", "*.rst", "*.md", "*.json"]
 
 [tool.pytest.ini_options]
 testpaths = ["tests/pytest"]
diff --git a/scripts/capture_litellm_stats.py b/scripts/capture_litellm_stats.py
new file mode 100755
index 000000000..4a4506ecc
--- /dev/null
+++ b/scripts/capture_litellm_stats.py
@@ -0,0 +1,532 @@
+#!/usr/bin/env python3
+"""
+Script to invoke LLM providers through LiteLLM and capture token/cost statistics.
+
+This script:
+1. Loads messages from JSON files generated by generate_test_messages.py
+2. Sends requests through LiteLLM (the same path palimpzest uses)
+3. Saves all usage metadata and response stats returned by LiteLLM
+4. Waits 20 seconds for the provider cache to become available
+5. Sends the request again and saves the second set of stats
+
+This allows us to compare LiteLLM's reported statistics with:
+- Direct provider API calls (from capture_provider_stats.py)
+- Palimpzest's generator stats tracking
+
+Supported providers:
+- Anthropic: claude-sonnet-4-5-20250929 (text, image, text+image)
+- Google/Gemini: gemini-2.5-flash (all seven modality combinations)
+- OpenAI: gpt-4o-2024-08-06 (text, image, text+image)
+- OpenAI: gpt-4o-audio-preview (text+audio, audio)
+- Azure: gpt-4o via Azure OpenAI (text, image, text+image)
+
+Output files are saved to: scripts/litellm_stats/
+"""
+
+import argparse
+import json
+import os
+import sys
+import time
+import uuid
+from typing import Any
+
+import litellm
+from litellm.integrations.custom_logger import CustomLogger
+
+# Add project root to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+import contextlib
+
+from palimpzest.constants import Model
+
+
+# =============================================================================
+# RAW RESPONSE CAPTURE CALLBACK
+# =============================================================================
+class RawProviderStatsCapture(CustomLogger):
+    """
+    Custom LiteLLM callback to capture raw provider usage stats before normalization.
+
+    LiteLLM normalizes all responses to OpenAI format, which loses provider-specific
+    details like Gemini's per-modality token breakdowns. This callback captures the
+    original provider response data.
+    """
+
+    def __init__(self):
+        self.last_raw_response = None
+        self.last_raw_usage = None
+        self.last_provider = None
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        """Called after a successful LLM API call."""
+        try:
+            # Store the provider info
+            self.last_provider = kwargs.get("custom_llm_provider") or kwargs.get("model", "").split("/")[0]
+
+            # Try to get the original response from hidden params
+            if hasattr(response_obj, "_hidden_params") and response_obj._hidden_params:
+                hidden = response_obj._hidden_params
+                self.last_raw_response = hidden.get("original_response")
+
+                # For some providers, the raw response might be in different locations
+                if self.last_raw_response is None:
+                    self.last_raw_response = hidden.get("raw_response")
+
+            # Try to extract raw usage from the response object itself
+            # Some providers have additional attributes that aren't in model_dump()
+            if hasattr(response_obj, "_response_ms"):
+                if self.last_raw_response is None:
+                    self.last_raw_response = {}
+                self.last_raw_response["_response_ms"] = response_obj._response_ms
+
+            # For Vertex AI / Gemini, check for provider-specific usage fields
+            if hasattr(response_obj, "vertex_ai_usage_metadata"):
+                self.last_raw_usage = response_obj.vertex_ai_usage_metadata
+            elif hasattr(response_obj, "_vertex_ai_response"):
+                self.last_raw_response = response_obj._vertex_ai_response
+
+        except Exception as e:
+            # Don't let callback errors break the main flow
+            print(f"  [Callback] Error capturing raw response: {e}")
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        """Called after a failed LLM API call."""
+        self.last_raw_response = None
+        self.last_raw_usage = None
+        self.last_provider = None
+
+    def reset(self):
+        """Reset captured data for next request."""
+        self.last_raw_response = None
+        self.last_raw_usage = None
+        self.last_provider = None
+
+    def get_captured_data(self) -> dict[str, Any]:
+        """Return captured raw data and reset for next request."""
+        data = {
+            "raw_provider_response": self.last_raw_response,
+            "raw_provider_usage": self.last_raw_usage,
+            "detected_provider": self.last_provider,
+        }
+        self.reset()
+        return data
+
+
+# Global callback instance
+raw_stats_capture = RawProviderStatsCapture()
+
+# Register the callback with LiteLLM
+litellm.callbacks = [raw_stats_capture]
+
+# Enable return of response headers (helps with some providers)
+litellm.return_response_headers = True
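+
+# For reference, LiteLLM's normalized usage follows the OpenAI shape, e.g.
+# (illustrative values): {"prompt_tokens": 1234, "completion_tokens": 56,
+# "total_tokens": 1290, "prompt_tokens_details": {"cached_tokens": 1024}},
+# whereas Gemini's raw usage_metadata reports prompt_token_count,
+# candidates_token_count, and cached_content_token_count instead.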
+ """ + + def __init__(self): + self.last_raw_response = None + self.last_raw_usage = None + self.last_provider = None + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + """Called after a successful LLM API call.""" + try: + # Store the provider info + self.last_provider = kwargs.get("custom_llm_provider") or kwargs.get("model", "").split("/")[0] + + # Try to get the original response from hidden params + if hasattr(response_obj, "_hidden_params") and response_obj._hidden_params: + hidden = response_obj._hidden_params + self.last_raw_response = hidden.get("original_response") + + # For some providers, the raw response might be in different locations + if self.last_raw_response is None: + self.last_raw_response = hidden.get("raw_response") + + # Try to extract raw usage from the response object itself + # Some providers have additional attributes that aren't in model_dump() + if hasattr(response_obj, "_response_ms"): + if self.last_raw_response is None: + self.last_raw_response = {} + self.last_raw_response["_response_ms"] = response_obj._response_ms + + # For Vertex AI / Gemini, check for provider-specific usage fields + if hasattr(response_obj, "vertex_ai_usage_metadata"): + self.last_raw_usage = response_obj.vertex_ai_usage_metadata + elif hasattr(response_obj, "_vertex_ai_response"): + self.last_raw_response = response_obj._vertex_ai_response + + except Exception as e: + # Don't let callback errors break the main flow + print(f" [Callback] Error capturing raw response: {e}") + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + """Called after a failed LLM API call.""" + self.last_raw_response = None + self.last_raw_usage = None + self.last_provider = None + + def reset(self): + """Reset captured data for next request.""" + self.last_raw_response = None + self.last_raw_usage = None + self.last_provider = None + + def get_captured_data(self) -> dict[str, Any]: + """Return captured raw data and reset for next request.""" + data = { + "raw_provider_response": self.last_raw_response, + "raw_provider_usage": self.last_raw_usage, + "detected_provider": self.last_provider, + } + self.reset() + return data + + +# Global callback instance +raw_stats_capture = RawProviderStatsCapture() + +# Register the callback with LiteLLM +litellm.callbacks = [raw_stats_capture] + +# Enable return of response headers (helps with some providers) +litellm.return_response_headers = True + + +# ============================================================================= +# PROVIDER CONFIGURATIONS +# ============================================================================= +# Maps provider name to Model enum and supported modalities +# The Model enum is used for: +# 1. Getting the LiteLLM model name via model.value +# 2. 
+
+
+def transform_messages_for_litellm(messages: list[dict]) -> list[dict]:
+    """
+    Transform palimpzest message format to LiteLLM-compatible format.
+
+    LiteLLM expects:
+    - Messages with role and content
+    - Content can be string or list of content blocks
+    - No 'type' field at the message level (that's palimpzest-specific)
+
+    This function consolidates multiple user messages with different types
+    into single messages with combined content.
+    """
+    litellm_messages = []
+
+    for msg in messages:
+        role = msg.get("role")
+        msg_type = msg.get("type")
+        content = msg.get("content")
+
+        if role == "system":
+            # System messages pass through as-is
+            # Content may be string or list of content blocks (for caching)
+            litellm_messages.append({"role": "system", "content": content})
+
+        elif role == "user":
+            # User messages need consolidation
+            if msg_type == "text":
+                # Text content - string or list of content blocks
+                if litellm_messages and litellm_messages[-1]["role"] == "user":
+                    # Merge with existing user message
+                    existing = litellm_messages[-1]["content"]
+                    if isinstance(existing, str):
+                        if isinstance(content, str):
+                            litellm_messages[-1]["content"] = [
+                                {"type": "text", "text": existing},
+                                {"type": "text", "text": content},
+                            ]
+                        else:
+                            litellm_messages[-1]["content"] = [
+                                {"type": "text", "text": existing}
+                            ] + content
+                    else:
+                        if isinstance(content, str):
+                            existing.append({"type": "text", "text": content})
+                        else:
+                            existing.extend(content)
+                else:
+                    litellm_messages.append({"role": "user", "content": content})
+
+            elif msg_type == "image":
+                # Image content - list of image_url blocks
+                if litellm_messages and litellm_messages[-1]["role"] == "user":
+                    existing = litellm_messages[-1]["content"]
+                    if isinstance(existing, str):
+                        litellm_messages[-1]["content"] = [
+                            {"type": "text", "text": existing}
+                        ] + content
+                    else:
+                        existing.extend(content)
+                else:
+                    litellm_messages.append({"role": "user", "content": content})
+
+            elif msg_type == "input_audio":
+                # Audio content - list of input_audio blocks
+                if litellm_messages and litellm_messages[-1]["role"] == "user":
+                    existing = litellm_messages[-1]["content"]
+                    if isinstance(existing, str):
+                        litellm_messages[-1]["content"] = [
+                            {"type": "text", "text": existing}
+                        ] + content
+                    else:
+                        existing.extend(content)
+                else:
+                    litellm_messages.append({"role": "user", "content": content})
+
+        elif role == "assistant":
+            litellm_messages.append({"role": "assistant", "content": content})
+
+    return litellm_messages
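+
+# Example of the consolidation above: a palimpzest text message followed by an image
+# message, both with role "user", become one user message whose content is a list
+# like [{"type": "text", "text": ...}, {"type": "image_url", "image_url": {...}}].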
"text", "text": existing} + ] + content + else: + existing.extend(content) + else: + litellm_messages.append({"role": "user", "content": content}) + + elif role == "assistant": + litellm_messages.append({"role": "assistant", "content": content}) + + return litellm_messages + + +def call_litellm_api( + messages: list[dict], + model: Model, + provider: str, + cache_key: str | None = None, +) -> dict[str, Any]: + """ + Call LiteLLM completion API and return all usage statistics. + + This function captures both: + - Option A: Raw provider usage via callback (if available) + - Option C: Normalized LiteLLM usage (fallback) + + Args: + messages: List of message dicts (palimpzest format) + model: Model enum (used for model name and provider detection) + provider: Provider name for logging + cache_key: Optional prompt_cache_key for OpenAI sticky routing to same cache shard + + Returns dict with: + - usage: Normalized usage dict from LiteLLM response (Option C fallback) + - usage_raw: Raw provider usage if captured via callback (Option A) + - response_content: First 200 chars of response + - model: Model used + - raw_response: Full response object serialized + """ + # Reset the callback to capture fresh data for this request + raw_stats_capture.reset() + + # Transform messages to LiteLLM format + litellm_messages = transform_messages_for_litellm(messages) + + # Get the LiteLLM model name from the Model enum + model_name = model.value + + # Set up completion kwargs + completion_kwargs = { + "temperature": 0.0, + } + + # Add modalities for audio models + if "audio" in model_name.lower(): + completion_kwargs["modalities"] = ["text"] + + # Apply provider-specific caching configuration + # Messages from generator_messages already have cache_control markers for Anthropic + if (model.is_provider_openai() or model.is_provider_azure()) and cache_key: + # OpenAI: Use prompt_cache_key for sticky routing to the same cache shard + # https://platform.openai.com/docs/guides/prompt-caching + completion_kwargs["extra_body"] = {"prompt_cache_key": cache_key} + + # Make the LiteLLM call + response = litellm.completion( + model=model_name, + messages=litellm_messages, + **completion_kwargs, + ) + + # ========================================================================== + # Option C (Fallback): Extract normalized usage stats from LiteLLM response + # ========================================================================== + usage_normalized = {} + if response.usage: + usage_normalized = response.usage.model_dump() + + # ========================================================================== + # Option A: Get raw provider data captured by callback + # ========================================================================== + callback_data = raw_stats_capture.get_captured_data() + usage_raw = callback_data.get("raw_provider_usage") + + # Also try to extract raw usage from _hidden_params + hidden_params = {} + try: + if hasattr(response, "_hidden_params") and response._hidden_params: + hidden_params = dict(response._hidden_params) + # Some providers store original response here + if "original_response" in hidden_params: + original = hidden_params["original_response"] + if isinstance(original, dict) and "usage_metadata" in original: + usage_raw = original["usage_metadata"] + elif hasattr(original, "usage_metadata"): + with contextlib.suppress(Exception): + usage_raw = original.usage_metadata.model_dump() if hasattr(original.usage_metadata, "model_dump") else dict(original.usage_metadata) + except Exception: + 
+
+
+def capture_stats_for_provider(
+    provider: str,
+    modality: str,
+    messages: list[dict],
+    model: Model,
+) -> dict[str, Any]:
+    """
+    Capture stats for a provider by making two requests with a delay.
+
+    Args:
+        provider: Provider name (for logging and file naming)
+        modality: Modality name
+        messages: List of message dicts
+        model: Model enum
+
+    Returns dict with:
+    - first_request: stats from first request
+    - second_request: stats from second request (should show cache hits)
+    """
+    # Generate a unique cache key for OpenAI (ensures both requests hit the same cache shard)
+    # Reference: capture_provider_stats.py and PromptManager.__init__
+    openai_cache_key = f"pz-test-{uuid.uuid4().hex[:12]}" if provider in ("openai", "openai-audio", "azure") else None
+
+    print("  First request...")
+    first_stats = call_litellm_api(messages, model, provider, cache_key=openai_cache_key)
+    print(f"  Usage: {first_stats['usage']}")
+
+    print("  Waiting 20 seconds for cache to be available...")
+    time.sleep(20)
+
+    print("  Second request (should show cache hits)...")
+    second_stats = call_litellm_api(messages, model, provider, cache_key=openai_cache_key)
+    print(f"  Usage: {second_stats['usage']}")
+
+    return {
+        "provider": provider,
+        "model": model.value,
+        "modality": modality,
+        "first_request": first_stats,
+        "second_request": second_stats,
+    }
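+
+# On the second request, a provider that cached the prompt should report most input
+# tokens as cache hits: OpenAI via prompt_tokens_details.cached_tokens, Anthropic via
+# cache_read_input_tokens. Note that providers only cache sufficiently long prefixes
+# (e.g., OpenAI requires prompts of at least 1024 tokens).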
+
+
+def save_stats(stats: dict[str, Any], output_dir: str, provider: str, modality: str) -> str:
+    """Save stats to a JSON file."""
+    os.makedirs(output_dir, exist_ok=True)
+    output_path = os.path.join(output_dir, f"{provider}_{modality}.json")
+
+    with open(output_path, "w") as f:
+        json.dump(stats, f, indent=2, default=str)
+
+    return output_path
+
+
+def main():
+    """Capture LiteLLM stats for supported provider/modality combinations."""
+    parser = argparse.ArgumentParser(
+        description="Capture token/cost statistics from LLM providers via LiteLLM.",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog=f"""
+Available providers: {', '.join(PROVIDER_MODALITY_SUPPORT.keys())}
+Available modalities: text-only, image-only, audio-only, text-image, text-audio, image-audio, text-image-audio
+
+Examples:
+    python capture_litellm_stats.py                          # Run all providers/modalities
+    python capture_litellm_stats.py -p openai                # Run all modalities for OpenAI
+    python capture_litellm_stats.py -p openai -m text-only   # Run only text-only for OpenAI
+    python capture_litellm_stats.py -p anthropic -m text-image  # Run only text-image for Anthropic
+    """,
+    )
+    parser.add_argument(
+        "-p", "--provider",
+        nargs="+",
+        choices=list(PROVIDER_MODALITY_SUPPORT.keys()),
+        help="Provider(s) to run. If not specified, runs all providers.",
+    )
+    parser.add_argument(
+        "-m", "--modality",
+        nargs="+",
+        choices=["text-only", "image-only", "audio-only", "text-image", "text-audio", "image-audio", "text-image-audio"],
+        help="Modality(ies) to run. If not specified, runs all supported modalities for each provider.",
+    )
+    args = parser.parse_args()
+
+    messages_dir = os.path.join(
+        os.path.dirname(__file__),
+        "..",
+        "tests",
+        "pytest",
+        "data",
+        "generator_messages",
+    )
+    messages_dir = os.path.abspath(messages_dir)
+
+    output_dir = os.path.join(
+        os.path.dirname(__file__),
+        "litellm_stats",
+    )
+    output_dir = os.path.abspath(output_dir)
+
+    print(f"Loading messages from: {messages_dir}")
+    print(f"Saving stats to: {output_dir}\n")
+
+    # Ensure output directory exists
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Determine which providers to run
+    providers_to_run = args.provider if args.provider else list(PROVIDER_MODALITY_SUPPORT.keys())
+    print(f"Providers to run: {providers_to_run}\n")
+
+    for provider in providers_to_run:
+        config = PROVIDER_MODALITY_SUPPORT[provider]
+        model = config["model"]
+        supported_modalities = config["supported_modalities"]
+
+        # Filter modalities if specified
+        if args.modality:
+            modalities_to_run = [m for m in args.modality if m in supported_modalities]
+            if not modalities_to_run:
+                print(f"\nProvider: {provider} - SKIPPED (none of {args.modality} supported)")
+                continue
+        else:
+            modalities_to_run = supported_modalities
+
+        print(f"\nProvider: {provider} (model: {model.value})")
+        print(f"  Modalities to run: {modalities_to_run}")
+
+        for modality in modalities_to_run:
+            print(f"\n  Processing modality: {modality}")
+
+            try:
+                messages = load_messages(modality, provider, messages_dir)
+                print(f"    Loaded {len(messages)} messages from {modality}_{provider}.json")
+
+                stats = capture_stats_for_provider(provider, modality, messages, model)
+
+                output_path = save_stats(stats, output_dir, provider, modality)
+                print(f"    Saved to: {output_path}")
+
+            except FileNotFoundError as e:
+                print(f"    SKIPPED: Message file not found - {e}")
+            except Exception as e:
+                print(f"    ERROR: {e}")
+                import traceback
+                traceback.print_exc()
+
+    print("\nDone!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/capture_provider_stats.py b/scripts/capture_provider_stats.py
new file mode 100644
index 000000000..5b91c84e8
--- /dev/null
+++ b/scripts/capture_provider_stats.py
@@ -0,0 +1,830 @@
+#!/usr/bin/env python3
+"""
+Script to directly invoke LLM providers and capture token/cost statistics.
+
+This script:
+1. Loads messages from JSON files generated by generate_test_messages.py
+2. Sends requests directly to each provider's API (not through litellm)
+3. Saves the token/cost related stats returned by the provider
+4. Waits 20 seconds for the provider cache to become available
+5. Sends the request again and saves the second set of stats
+
+This allows us to establish baseline expectations for what the providers return,
+which can then be used to validate the palimpzest generator's stats tracking.
+
+Supported providers:
+- Anthropic: claude-sonnet-4-5-20250929 (text, image, text+image)
+- Google/Vertex AI: gemini-2.5-flash (all seven modality combinations)
+- OpenAI: gpt-4o-2024-08-06 (text, image, text+image)
+- OpenAI: gpt-4o-audio-preview (text+audio, audio)
+- Azure: gpt-4o-2024-08-06 via Azure OpenAI (text, image, text+image)
+
+Output files are saved to: scripts/provider_stats/
+"""
+
+import argparse
+import base64
+import json
+import os
+import sys
+import time
+import uuid
+from typing import Any
+
+
+def detect_image_media_type(base64_data: str) -> str:
+    """
+    Detect the actual image format from base64 data by examining the magic bytes.
+
+    Args:
+        base64_data: Base64-encoded image data.
+
+    Returns:
+        The detected media type (e.g., 'image/png', 'image/jpeg').
+        Defaults to 'image/jpeg' if format cannot be determined.
+    """
+    try:
+        # Decode first few bytes to check magic numbers
+        header = base64.b64decode(base64_data[:32])
+
+        # PNG: 89 50 4E 47 0D 0A 1A 0A
+        if header[:8] == b"\x89PNG\r\n\x1a\n":
+            return "image/png"
+        # JPEG: FF D8 FF
+        if header[:3] == b"\xff\xd8\xff":
+            return "image/jpeg"
+        # GIF: GIF87a or GIF89a
+        if header[:6] in (b"GIF87a", b"GIF89a"):
+            return "image/gif"
+        # WebP: RIFF....WEBP
+        if header[:4] == b"RIFF" and header[8:12] == b"WEBP":
+            return "image/webp"
+    except Exception:
+        pass
+
+    return "image/jpeg"
+
+
+# Add project root to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
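+
+# Example (illustrative): for a PNG file,
+#   detect_image_media_type(base64.b64encode(open("img.png", "rb").read()).decode())
+# returns "image/png"; unknown formats fall back to "image/jpeg".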
+ """ + openai_messages = [] + + for msg in messages: + role = msg.get("role") + msg_type = msg.get("type") + content = msg.get("content") + + if role == "system": + # System content may be string or list of content blocks + if isinstance(content, list): + # Extract text from content blocks + text_parts = [block.get("text", "") for block in content if block.get("type") == "text"] + openai_messages.append({"role": "system", "content": "".join(text_parts)}) + else: + openai_messages.append({"role": "system", "content": content}) + + elif role == "user": + if msg_type == "text": + # Content may be string or list of content blocks + if isinstance(content, list): + # Already content blocks - add them directly + content_parts = [] + for block in content: + # Convert to OpenAI format (remove cache_control if present) + openai_block = {"type": block.get("type", "text")} + if block.get("type") == "text": + openai_block["text"] = block.get("text", "") + content_parts.append(openai_block) + + if openai_messages and openai_messages[-1]["role"] == "user": + existing_content = openai_messages[-1]["content"] + if isinstance(existing_content, str): + openai_messages[-1]["content"] = [ + {"type": "text", "text": existing_content} + ] + content_parts + else: + existing_content.extend(content_parts) + else: + openai_messages.append({"role": "user", "content": content_parts}) + else: + # String content + if openai_messages and openai_messages[-1]["role"] == "user": + existing_content = openai_messages[-1]["content"] + if isinstance(existing_content, str): + openai_messages[-1]["content"] = [ + {"type": "text", "text": existing_content}, + {"type": "text", "text": content}, + ] + else: + existing_content.append({"type": "text", "text": content}) + else: + openai_messages.append({"role": "user", "content": content}) + + elif msg_type == "image": + # Image content + image_parts = [] + for img in content: + if img.get("type") == "image_url": + image_parts.append(img) + + if openai_messages and openai_messages[-1]["role"] == "user": + existing_content = openai_messages[-1]["content"] + if isinstance(existing_content, str): + openai_messages[-1]["content"] = [ + {"type": "text", "text": existing_content} + ] + image_parts + else: + existing_content.extend(image_parts) + else: + openai_messages.append({"role": "user", "content": image_parts}) + + elif msg_type == "input_audio": + # Audio content + audio_parts = [] + for audio in content: + if audio.get("type") == "input_audio": + audio_parts.append(audio) + + if openai_messages and openai_messages[-1]["role"] == "user": + existing_content = openai_messages[-1]["content"] + if isinstance(existing_content, str): + openai_messages[-1]["content"] = [ + {"type": "text", "text": existing_content} + ] + audio_parts + else: + existing_content.extend(audio_parts) + else: + openai_messages.append({"role": "user", "content": audio_parts}) + + return openai_messages + + +def transform_messages_for_anthropic(messages: list[dict]) -> tuple[str | None, list[dict]]: + """ + Transform palimpzest/litellm message format to Anthropic API format. + + Input messages may already have cache_control markers from PromptManager. + This function preserves those markers while converting to Anthropic's native format. 
+
+
+def transform_messages_for_anthropic(messages: list[dict]) -> tuple[str | None, list[dict]]:
+    """
+    Transform palimpzest/litellm message format to Anthropic API format.
+
+    Input messages may already have cache_control markers from PromptManager.
+    This function preserves those markers while converting to Anthropic's native format.
+
+    Anthropic expects:
+    - system as a separate parameter (not in messages)
+    - user/assistant messages with content as array of content blocks
+    - cache_control markers for caching (preserved from input)
+    """
+    system_prompt = None
+    anthropic_messages = []
+
+    for msg in messages:
+        role = msg.get("role")
+        msg_type = msg.get("type")
+        content = msg.get("content")
+
+        if role == "system":
+            # Anthropic uses system as a separate parameter
+            # Content may already be a list of content blocks with cache_control
+            if isinstance(content, list):
+                # Already in content block format (from PromptManager)
+                system_prompt = content
+            else:
+                # String content - wrap in content block with cache_control
+                system_prompt = [
+                    {
+                        "type": "text",
+                        "text": content,
+                        "cache_control": {"type": "ephemeral"},
+                    }
+                ]
+
+        elif role == "user":
+            if msg_type == "text":
+                # Content may be string or list of content blocks
+                if isinstance(content, list):
+                    # Already content blocks (may have cache_control) - preserve them
+                    for block in content:
+                        if anthropic_messages and anthropic_messages[-1]["role"] == "user":
+                            anthropic_messages[-1]["content"].append(block)
+                        else:
+                            anthropic_messages.append({"role": "user", "content": [block]})
+                else:
+                    # String content
+                    content_block = {"type": "text", "text": content}
+                    if anthropic_messages and anthropic_messages[-1]["role"] == "user":
+                        anthropic_messages[-1]["content"].append(content_block)
+                    else:
+                        anthropic_messages.append({"role": "user", "content": [content_block]})
+
+            elif msg_type == "image":
+                # Image content - Anthropic uses base64 format
+                for img in content:
+                    if img.get("type") == "image_url":
+                        url = img["image_url"]["url"]
+                        if url.startswith("data:"):
+                            # Extract base64 data
+                            _, data = url.split(";base64,")
+                            # Detect actual media type from image data (in case URL has wrong type)
+                            media_type = detect_image_media_type(data)
+                            image_block = {
+                                "type": "image",
+                                "source": {
+                                    "type": "base64",
+                                    "media_type": media_type,
+                                    "data": data,
+                                },
+                            }
+                            # Preserve cache_control if present on the original block
+                            if "cache_control" in img:
+                                image_block["cache_control"] = img["cache_control"]
+                            if anthropic_messages and anthropic_messages[-1]["role"] == "user":
+                                anthropic_messages[-1]["content"].append(image_block)
+                            else:
+                                anthropic_messages.append({"role": "user", "content": [image_block]})
+
+    return system_prompt, anthropic_messages
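+
+# Example: a plain-string system message becomes the Anthropic system parameter
+#   [{"type": "text", "text": "<system prompt>", "cache_control": {"type": "ephemeral"}}]
+# so the static prefix is eligible for Anthropic's prompt cache.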
+
+
+def transform_messages_for_gemini(messages: list[dict]) -> tuple[str | None, list[dict]]:
+    """
+    Transform palimpzest/litellm message format to Gemini API format.
+
+    Gemini expects:
+    - role: "user" or "model"
+    - parts: list of content parts
+
+    Input messages may have content as string or list of content blocks.
+    """
+    gemini_contents = []
+    system_instruction = None
+
+    for msg in messages:
+        role = msg.get("role")
+        msg_type = msg.get("type")
+        content = msg.get("content")
+
+        if role == "system":
+            # Gemini uses system_instruction
+            # Content may be string or list of content blocks
+            if isinstance(content, list):
+                # Extract text from content blocks
+                text_parts = [block.get("text", "") for block in content if block.get("type") == "text"]
+                system_instruction = "".join(text_parts)
+            else:
+                system_instruction = content
+
+        elif role == "user":
+            parts = []
+
+            if msg_type == "text":
+                # Content may be string or list of content blocks
+                if isinstance(content, list):
+                    for block in content:
+                        if block.get("type") == "text":
+                            parts.append({"text": block.get("text", "")})
+                else:
+                    parts.append({"text": content})
+
+            elif msg_type == "image":
+                for img in content:
+                    if img.get("type") == "image_url":
+                        url = img["image_url"]["url"]
+                        if url.startswith("data:"):
+                            _, data = url.split(";base64,")
+                            # Detect actual media type from image data
+                            media_type = detect_image_media_type(data)
+                            parts.append({
+                                "inline_data": {
+                                    "mime_type": media_type,
+                                    "data": data,
+                                }
+                            })
+
+            elif msg_type == "input_audio":
+                for audio in content:
+                    if audio.get("type") == "input_audio":
+                        audio_data = audio["input_audio"]
+                        parts.append({
+                            "inline_data": {
+                                "mime_type": f"audio/{audio_data.get('format', 'wav')}",
+                                "data": audio_data["data"],
+                            }
+                        })
+
+            if parts:
+                if gemini_contents and gemini_contents[-1]["role"] == "user":
+                    gemini_contents[-1]["parts"].extend(parts)
+                else:
+                    gemini_contents.append({"role": "user", "parts": parts})
+
+    return system_instruction, gemini_contents
+
+
+def call_openai_api(messages: list[dict], model: str, cache_key: str | None = None) -> dict[str, Any]:
+    """
+    Call OpenAI API directly and return usage statistics.
+
+    Args:
+        messages: List of message dicts
+        model: Model name
+        cache_key: Optional prompt_cache_key for sticky routing to same cache shard
+
+    Returns dict with:
+    - completion_tokens
+    - prompt_tokens
+    - prompt_tokens_details (cached_tokens, text_tokens, image_tokens, audio_tokens)
+    - total_tokens
+    """
+    import openai
+
+    client = openai.OpenAI()
+
+    openai_messages = transform_messages_for_openai(messages)
+
+    kwargs = {"model": model, "messages": openai_messages, "temperature": 0.0}
+
+    # Check if this is an audio model
+    if "audio" in model:
+        kwargs["modalities"] = ["text"]
+
+    # Add prompt_cache_key for caching (ensures requests route to same cache shard)
+    if cache_key:
+        kwargs["extra_body"] = {"prompt_cache_key": cache_key}
+
+    response = client.chat.completions.create(**kwargs)
+
+    # Extract complete usage stats
+    usage_dict = {}
+    if response.usage:
+        usage_dict = response.usage.model_dump()
+
+    # Get response text safely
+    try:
+        response_text = response.choices[0].message.content[:200] if response.choices and response.choices[0].message.content else None
+    except Exception:
+        response_text = None
+
+    # Serialize the full response
+    try:
+        raw_response = response.model_dump()
+    except Exception:
+        raw_response = str(response)
+
+    return {
+        "provider": "openai",
+        "model": model,
+        "usage": usage_dict,
+        "response_content": response_text,
+        "raw_response": raw_response,
+    }
+
+
+# NOTE: this function was generated speculatively and has not been tested, so it may have errors
+def call_azure_api(messages: list[dict], model: str, cache_key: str | None = None) -> dict[str, Any]:
+    """
+    Call Azure OpenAI API directly and return usage statistics.
+
+    Uses the same message format as OpenAI, but routes through Azure endpoints.
+
+    Args:
+        messages: List of message dicts
+        model: Model name (deployment name)
+        cache_key: Optional prompt_cache_key for sticky routing to same cache shard
+
+    Returns dict with:
+    - completion_tokens
+    - prompt_tokens
+    - prompt_tokens_details (cached_tokens, text_tokens, image_tokens, audio_tokens)
+    - total_tokens
+    """
+    import openai
+
+    api_key = os.environ.get("AZURE_API_KEY") or os.environ.get("AZURE_OPENAI_API_KEY")
+    azure_endpoint = os.environ.get("AZURE_API_BASE")
+    api_version = os.environ.get("AZURE_API_VERSION", "2024-12-01-preview")
+
+    if not api_key:
+        raise ValueError("AZURE_API_KEY or AZURE_OPENAI_API_KEY must be set")
+    if not azure_endpoint:
+        raise ValueError("AZURE_API_BASE must be set")
+
+    client = openai.AzureOpenAI(
+        api_key=api_key,
+        azure_endpoint=azure_endpoint,
+        api_version=api_version,
+    )
+
+    openai_messages = transform_messages_for_openai(messages)
+
+    kwargs = {"model": model, "messages": openai_messages, "temperature": 0.0}
+
+    # Add prompt_cache_key for caching (ensures requests route to same cache shard)
+    if cache_key:
+        kwargs["extra_body"] = {"prompt_cache_key": cache_key}
+
+    response = client.chat.completions.create(**kwargs)
+
+    # Extract complete usage stats
+    usage_dict = {}
+    if response.usage:
+        usage_dict = response.usage.model_dump()
+
+    # Get response text safely
+    try:
+        response_text = response.choices[0].message.content[:200] if response.choices and response.choices[0].message.content else None
+    except Exception:
+        response_text = None
+
+    # Serialize the full response
+    try:
+        raw_response = response.model_dump()
+    except Exception:
+        raw_response = str(response)
+
+    return {
+        "provider": "azure",
+        "model": model,
+        "usage": usage_dict,
+        "response_content": response_text,
+        "raw_response": raw_response,
+    }
+
+
+def call_anthropic_api(messages: list[dict], model: str) -> dict[str, Any]:
+    """
+    Call Anthropic API directly and return usage statistics.
+
+    Returns dict with:
+    - input_tokens
+    - output_tokens
+    - cache_creation_input_tokens
+    - cache_read_input_tokens
+    """
+    import anthropic
+
+    client = anthropic.Anthropic()
+
+    system_prompt, anthropic_messages = transform_messages_for_anthropic(messages)
+
+    response = client.messages.create(
+        model=model,
+        max_tokens=1024,
+        system=system_prompt,
+        messages=anthropic_messages,
+    )
+
+    # Extract complete usage stats
+    usage_dict = {}
+    if response.usage:
+        usage_dict = response.usage.model_dump()
+
+    # Get response text safely
+    try:
+        response_text = response.content[0].text[:200] if response.content and response.content[0].text else None
+    except Exception:
+        response_text = None
+
+    # Serialize the full response
+    try:
+        raw_response = response.model_dump()
+    except Exception:
+        raw_response = str(response)
+
+    return {
+        "provider": "anthropic",
+        "model": model,
+        "usage": usage_dict,
+        "response_content": response_text,
+        "raw_response": raw_response,
+    }
+
+
+def call_gemini_api(messages: list[dict], model: str, use_vertex: bool = False) -> dict[str, Any]:
+    """
+    Call Gemini API directly and return usage statistics.
+
+    Args:
+        messages: List of message dicts
+        model: Model name
+        use_vertex: If True, use Vertex AI; otherwise use Google AI Studio
+
+    Returns dict with usage statistics.
+    """
+    from google import genai
+    from google.genai import types
+
+    system_instruction, gemini_contents = transform_messages_for_gemini(messages)
+
+    # Create client for Google AI Studio or Vertex AI
+    if use_vertex:
+        # Vertex AI requires project and location
+        client = genai.Client(
+            vertexai=True,
+            project=os.environ.get("GOOGLE_CLOUD_PROJECT", os.environ.get("VERTEXAI_PROJECT")),
+            location=os.environ.get("GOOGLE_CLOUD_LOCATION", os.environ.get("VERTEXAI_LOCATION", "us-central1")),
+        )
+    else:
+        # Google AI Studio uses API key from environment
+        client = genai.Client()
+
+    # Build the config
+    config = types.GenerateContentConfig(
+        temperature=0.0,
+        system_instruction=system_instruction if system_instruction else None,
+    )
+
+    response = client.models.generate_content(
+        model=model,
+        contents=gemini_contents,
+        config=config,
+    )
+
+    # Extract complete usage stats from usage_metadata
+    usage_metadata = response.usage_metadata
+    usage_dict = {}
+    if usage_metadata:
+        # Try model_dump() first (Pydantic models), then to_dict(), then manual extraction
+        try:
+            usage_dict = usage_metadata.model_dump()
+        except AttributeError:
+            try:
+                usage_dict = usage_metadata.to_dict()
+            except AttributeError:
+                # Manual extraction of known Gemini usage fields
+                usage_dict = {
+                    "prompt_token_count": getattr(usage_metadata, "prompt_token_count", None),
+                    "candidates_token_count": getattr(usage_metadata, "candidates_token_count", None),
+                    "total_token_count": getattr(usage_metadata, "total_token_count", None),
+                    "cached_content_token_count": getattr(usage_metadata, "cached_content_token_count", None),
+                }
+
+    # Get response text safely
+    try:
+        response_text = response.text[:200] if response.text else None
+    except Exception:
+        response_text = None
+
+    # Serialize the full response
+    try:
+        # Try model_dump() first (Pydantic models)
+        raw_response = response.model_dump()
+    except AttributeError:
+        try:
+            raw_response = response.to_dict()
+        except AttributeError:
+            # Manual serialization
+            try:
+                raw_response = {
+                    "text": response.text if hasattr(response, "text") else None,
+                    "candidates": [
+                        {
+                            "content": {
+                                "parts": [{"text": getattr(part, "text", str(part))} for part in c.content.parts] if c.content and c.content.parts else [],
+                                "role": c.content.role if c.content else None,
+                            },
+                            "finish_reason": str(c.finish_reason) if hasattr(c, "finish_reason") else None,
+                        }
+                        for c in (response.candidates or [])
+                    ],
+                    "usage_metadata": usage_dict,
+                    "model_version": getattr(response, "model_version", None),
+                }
+            except Exception as e:
+                raw_response = {"error": str(e), "response_str": str(response)}
+
+    return {
+        "provider": "vertex_ai" if use_vertex else "gemini",
+        "model": model,
+        "usage": usage_dict,
+        "response_content": response_text,
+        "raw_response": raw_response,
+    }
+
+
+def capture_stats_for_provider(
+    provider: str,
+    modality: str,
+    messages: list[dict],
+    model: str,
+) -> dict[str, Any]:
+    """
+    Capture stats for a provider by making two requests with a 20-second delay.
+
+    Returns dict with:
+    - first_request: stats from first request
+    - second_request: stats from second request (should show cache hits)
+    """
+    # Generate a unique cache key for OpenAI/Azure (ensures both requests hit the same cache shard)
+    openai_cache_key = f"pz-test-{uuid.uuid4().hex[:12]}" if provider in ("openai", "openai-audio", "azure") else None
+
+    print("  First request...")
+    if provider == "openai" or provider == "openai-audio":
+        first_stats = call_openai_api(messages, model, cache_key=openai_cache_key)
+    elif provider == "azure":
+        first_stats = call_azure_api(messages, model, cache_key=openai_cache_key)
+    elif provider == "anthropic":
+        first_stats = call_anthropic_api(messages, model)
+    elif provider == "gemini":
+        first_stats = call_gemini_api(messages, model, use_vertex=False)
+    elif provider == "vertex_ai":
+        first_stats = call_gemini_api(messages, model, use_vertex=True)
+    else:
+        raise ValueError(f"Unknown provider: {provider}")
+
+    print(f"  Usage: {first_stats['usage']}")
+
+    print("  Waiting 20 seconds for cache to be available...")
+    time.sleep(20)
+
+    print("  Second request (should show cache hits)...")
+    if provider == "openai" or provider == "openai-audio":
+        second_stats = call_openai_api(messages, model, cache_key=openai_cache_key)
+    elif provider == "azure":
+        second_stats = call_azure_api(messages, model, cache_key=openai_cache_key)
+    elif provider == "anthropic":
+        second_stats = call_anthropic_api(messages, model)
+    elif provider == "gemini":
+        second_stats = call_gemini_api(messages, model, use_vertex=False)
+    elif provider == "vertex_ai":
+        second_stats = call_gemini_api(messages, model, use_vertex=True)
+
+    print(f"  Usage: {second_stats['usage']}")
+
+    return {
+        "provider": provider,
+        "model": model,
+        "modality": modality,
+        "first_request": first_stats,
+        "second_request": second_stats,
+    }
+
+
+def save_stats(stats: dict[str, Any], output_dir: str, provider: str, modality: str) -> str:
+    """Save stats to a JSON file."""
+    os.makedirs(output_dir, exist_ok=True)
+    output_path = os.path.join(output_dir, f"{provider}_{modality}.json")
+
+    with open(output_path, "w") as f:
+        json.dump(stats, f, indent=2)
+
+    return output_path
+ + Returns dict with: + - first_request: stats from first request + - second_request: stats from second request (should show cache hits) + """ + # Generate a unique cache key for OpenAI/Azure (ensures both requests hit the same cache shard) + openai_cache_key = f"pz-test-{uuid.uuid4().hex[:12]}" if provider in ("openai", "openai-audio", "azure") else None + + print(" First request...") + if provider == "openai" or provider == "openai-audio": + first_stats = call_openai_api(messages, model, cache_key=openai_cache_key) + elif provider == "azure": + first_stats = call_azure_api(messages, model, cache_key=openai_cache_key) + elif provider == "anthropic": + first_stats = call_anthropic_api(messages, model) + elif provider == "gemini": + first_stats = call_gemini_api(messages, model, use_vertex=False) + elif provider == "vertex_ai": + first_stats = call_gemini_api(messages, model, use_vertex=True) + else: + raise ValueError(f"Unknown provider: {provider}") + + print(f" Usage: {first_stats['usage']}") + + print(" Waiting 20 seconds for cache to be available...") + time.sleep(20) + + print(" Second request (should show cache hits)...") + if provider == "openai" or provider == "openai-audio": + second_stats = call_openai_api(messages, model, cache_key=openai_cache_key) + elif provider == "azure": + second_stats = call_azure_api(messages, model, cache_key=openai_cache_key) + elif provider == "anthropic": + second_stats = call_anthropic_api(messages, model) + elif provider == "gemini": + second_stats = call_gemini_api(messages, model, use_vertex=False) + elif provider == "vertex_ai": + second_stats = call_gemini_api(messages, model, use_vertex=True) + + print(f" Usage: {second_stats['usage']}") + + return { + "provider": provider, + "model": model, + "modality": modality, + "first_request": first_stats, + "second_request": second_stats, + } + + +def save_stats(stats: dict[str, Any], output_dir: str, provider: str, modality: str) -> str: + """Save stats to a JSON file.""" + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{provider}_{modality}.json") + + with open(output_path, "w") as f: + json.dump(stats, f, indent=2) + + return output_path + + +def main(): + """Capture provider stats for supported provider/modality combinations.""" + parser = argparse.ArgumentParser( + description="Capture token/cost statistics from LLM providers.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=f""" +Available providers: {', '.join(PROVIDER_MODALITY_SUPPORT.keys())} +Available modalities: text-only, image-only, audio-only, text-image, text-audio, image-audio, text-image-audio + +Examples: + python capture_provider_stats.py # Run all providers/modalities + python capture_provider_stats.py -p openai # Run all modalities for OpenAI + python capture_provider_stats.py -p openai -m text-only # Run only text-only for OpenAI + python capture_provider_stats.py -p anthropic -m text-image # Run only text-image for Anthropic + """ + ) + parser.add_argument( + "-p", "--provider", + nargs="+", + choices=list(PROVIDER_MODALITY_SUPPORT.keys()), + help="Provider(s) to run. If not specified, runs all providers.", + ) + parser.add_argument( + "-m", "--modality", + nargs="+", + choices=["text-only", "image-only", "audio-only", "text-image", "text-audio", "image-audio", "text-image-audio"], + help="Modality(ies) to run. 
If not specified, runs all supported modalities for each provider.", + ) + args = parser.parse_args() + + messages_dir = os.path.join( + os.path.dirname(__file__), + "..", + "tests", + "pytest", + "data", + "generator_messages", + ) + messages_dir = os.path.abspath(messages_dir) + + output_dir = os.path.join( + os.path.dirname(__file__), + "provider_stats", + ) + output_dir = os.path.abspath(output_dir) + + print(f"Loading messages from: {messages_dir}") + print(f"Saving stats to: {output_dir}\n") + + # Determine which providers to run + providers_to_run = args.provider if args.provider else list(PROVIDER_MODALITY_SUPPORT.keys()) + print(f"Providers to run: {providers_to_run}\n") + + for provider in providers_to_run: + config = PROVIDER_MODALITY_SUPPORT[provider] + model = config["model"] + supported_modalities = config["supported_modalities"] + + # Filter modalities if specified + if args.modality: + modalities_to_run = [m for m in args.modality if m in supported_modalities] + if not modalities_to_run: + print(f"\nProvider: {provider} - SKIPPED (none of {args.modality} supported)") + continue + else: + modalities_to_run = supported_modalities + + print(f"\nProvider: {provider} (model: {model})") + print(f" Modalities to run: {modalities_to_run}") + + for modality in modalities_to_run: + print(f"\n Processing modality: {modality}") + + try: + messages = load_messages(modality, provider, messages_dir) + print(f" Loaded {len(messages)} messages from {modality}_{provider}.json") + + stats = capture_stats_for_provider(provider, modality, messages, model) + + output_path = save_stats(stats, output_dir, provider, modality) + print(f" Saved to: {output_path}") + + except FileNotFoundError as e: + print(f" SKIPPED: Message file not found - {e}") + except Exception as e: + print(f" ERROR: {e}") + import traceback + traceback.print_exc() + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_test_messages.py b/scripts/generate_test_messages.py new file mode 100644 index 000000000..e6d68c94b --- /dev/null +++ b/scripts/generate_test_messages.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +""" +Script to generate test messages for each provider/modality combination. + +This script uses the Generator class directly to create message payloads. +It uses the 'generating_messages_only' flag to retrieve the exact messages +that would be sent to the provider without making an actual API call. + +Supported provider/modality combinations: +- Anthropic: text-only, image-only, text-image (no audio support) +- OpenAI: text-only, image-only, text-image +- OpenAI-Audio: audio-only, text-audio +- Gemini: all 7 modality combinations +- Vertex AI: all 7 modality combinations +- Azure: text-only, image-only, text-image + +Output files are saved to: tests/pytest/data/generator_messages/ +Format: {modality}_{provider}.json (e.g., text-only_anthropic.json) +""" + +import json +import os +import sys + +from pydantic import BaseModel, Field + +# Add project root to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from palimpzest.constants import Model, PromptStrategy +from palimpzest.core.elements.records import DataRecord +from palimpzest.core.lib.schemas import AudioFilepath, ImageFilepath, union_schemas +from palimpzest.query.generators.generators import Generator + + +def generate_session_id(provider: str, modality: str) -> str: + """ + Generate a unique 12-character session ID for a provider/modality combination. 
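+    For example, ("openai", "text-only") maps to the first 12 hex characters
+    of md5("openai_text-only"), uppercased.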
+ This ensures each modality test has a unique prompt prefix, preventing cross-modality cache hits. + The ID is deterministic based on provider+modality so regenerating produces consistent results. + """ + import hashlib + hash_input = f"{provider}_{modality}" + hash_hex = hashlib.md5(hash_input.encode()).hexdigest() + return hash_hex[:12].upper() + +STATIC_CONTEXT = """ +WILDLIFE CONSERVATION & RESEARCH CENTER: SPECIES IDENTIFICATION MANUAL (v2025.1) + +SECTION 1: INTRODUCTION AND MISSION +The Wildlife Conservation & Research Center (WCRC) is dedicated to the preservation, study, and rehabilitation of diverse wildlife species. +All staff members, researchers, and volunteers must adhere to these protocols for accurate species identification and data collection. +Our mission combines advanced biological sciences with conservation efforts to protect endangered and threatened populations worldwide. + +SECTION 2: MAMMAL IDENTIFICATION PROTOCOLS + +2.1 ELEPHANTS (Family Elephantidae): + - African Savanna Elephant: Larger ears (shaped like Africa), concave back, two fingers on trunk tip. Weight: 5,000-14,000 lbs. + - African Forest Elephant: Smaller stature, oval-shaped ears, straighter tusks pointing downward. + - Asian Elephant: Smaller ears, convex back, one finger on trunk tip, twin domes on head. Weight: 4,000-11,000 lbs. + - Vocalizations: Trumpeting (alarm/excitement), rumbling (long-distance communication), roaring (distress). + +2.2 BIG CATS (Family Felidae): + - Lion (Panthera leo): Tawny coat, males have distinctive mane. Social, live in prides. Height: 3.5-4 ft at shoulder. + - Tiger (Panthera tigris): Orange coat with black stripes, white underbelly. Solitary hunters. Largest cat species. + - Leopard (Panthera pardus): Golden-yellow coat with rosette patterns. Excellent climbers, often cache prey in trees. + - Cheetah (Acinonyx jubatus): Spotted coat, black "tear marks" from eyes to mouth. Fastest land animal (70 mph). + - Vocalizations: Roaring (lions, tigers, leopards), chirping/purring (cheetahs cannot roar). + +2.3 BEARS (Family Ursidae): + - Brown Bear (Ursus arctos): Large shoulder hump, dish-shaped face, long claws. Includes grizzly subspecies. + - Black Bear (Ursus americanus): Straight facial profile, no shoulder hump, shorter claws. Most common North American bear. + - Polar Bear (Ursus maritimus): White fur, longer neck, smaller ears. Marine mammal adapted to Arctic conditions. + - Giant Panda (Ailuropoda melanoleuca): Black and white coloring, feeds almost exclusively on bamboo. + - Vocalizations: Roaring, growling, huffing, jaw-popping (threat displays). + +2.4 PRIMATES (Order Primates): + - Gorilla: Largest primate, silver-back males, knuckle-walking locomotion. Vocalizations include chest-beating, hooting. + - Chimpanzee: Highly intelligent, uses tools, complex social structures. Vocalizations: pant-hoots, screams. + - Orangutan: Red-orange fur, arboreal lifestyle, solitary. Long calls can travel over 1 km. + - Gibbon: Smaller apes, brachiation locomotion, distinctive whooping songs for territorial marking. + +SECTION 3: BIRD IDENTIFICATION PROTOCOLS + +3.1 RAPTORS (Order Accipitriformes/Falconiformes): + - Bald Eagle: White head and tail, yellow beak. Wingspan: 6-7.5 ft. Call: high-pitched chattering. + - Golden Eagle: Dark brown plumage, golden nape. Powerful hunters of small mammals. + - Peregrine Falcon: Blue-gray back, barred underparts. Fastest bird in dive (240+ mph). + - Red-tailed Hawk: Brown back, pale underparts, distinctive red tail. 
Most common North American hawk. + +3.2 PARROTS (Order Psittaciformes): + - Macaw: Large, colorful, long tail feathers. Powerful curved beaks. Highly social and vocal. + - African Grey: Gray plumage, red tail. Exceptional mimicry and cognitive abilities. + - Cockatoo: White or pink plumage, distinctive crest. Loud screeching vocalizations. + +SECTION 4: REPTILE IDENTIFICATION PROTOCOLS + +4.1 CROCODILIANS (Order Crocodilia): + - American Alligator: Broad, U-shaped snout, dark coloration. Freshwater habitats. + - Nile Crocodile: V-shaped snout, aggressive. Can reach 16-18 ft in length. + - Gharial: Extremely narrow snout, fish-eating specialist. Critically endangered. + +4.2 LARGE SNAKES (Families Pythonidae/Boidae): + - Reticulated Python: Longest snake species (up to 23 ft), complex geometric patterns. + - Green Anaconda: Heaviest snake species, olive-green with black spots. Semi-aquatic. + - King Cobra: Longest venomous snake (up to 18 ft), distinctive hood when threatened. + +SECTION 5: DATA COLLECTION AND ANALYSIS + +5.1 Visual Identification: + - Document body shape, size, coloration, and distinctive markings. + - Note behavioral characteristics and habitat context. + - Use standardized photography protocols for pattern matching. + +5.2 Audio Identification: + - Record vocalizations with frequency analysis equipment. + - Tag recordings with behavioral context (territorial, mating, alarm, social). + - Cross-reference with vocalization databases for species confirmation. + +5.3 Biometric Data: + - Record body measurements according to species-specific protocols. + - Document age indicators (teeth wear, plumage, etc.). + - Collect genetic samples when possible for lineage verification. + +You are an AI Research Assistant for the WCRC. Your job is to analyze data inputs (text descriptions, images, and/or audio recordings) and identify the species based on the characteristics described in this manual. +Analyze all provided inputs and determine the most likely species identification. +""" + +class TextInputSchema(BaseModel): + """Schema for text-only input.""" + text: str = Field(description="Description of an animal") + age: int = Field(description="The age of the animal in years") + + +class ImageInputSchema(BaseModel): + """Schema for image-only input.""" + image_file: ImageFilepath = Field(description="File path to an image of an animal") + height: float = Field(description="The estimated height of the animal in cm") + + +class AudioInputSchema(BaseModel): + """Schema for audio-only input.""" + audio_file: AudioFilepath = Field(description="File path to an audio recording of an animal") + year: float = Field(description="The year the recording was made") + + +# Union schemas for multi-modal inputs +TextImageInputSchema = union_schemas([TextInputSchema, ImageInputSchema]) +TextAudioInputSchema = union_schemas([TextInputSchema, AudioInputSchema]) +ImageAudioInputSchema = union_schemas([ImageInputSchema, AudioInputSchema]) +TextImageAudioInputSchema = union_schemas([TextInputSchema, ImageInputSchema, AudioInputSchema]) + + +class OutputSchema(BaseModel): + """Output schema for animal identification.""" + animal: str = Field(description="The animal in the input") + +MODALITY_CONFIGS = { + "text-only": { + "input_schema": TextInputSchema, + "data_item": { + "text": "An elephant is a large gray animal with a trunk and big ears. 
It makes a trumpeting sound.", + "age": 15, + }, + }, + "image-only": { + "input_schema": ImageInputSchema, + "data_item": { + "image_file": "tests/pytest/data/elephant.png", + "height": 304.5, + }, + }, + "audio-only": { + "input_schema": AudioInputSchema, + "data_item": { + "audio_file": "tests/pytest/data/elephant.wav", + "year": 2020, + }, + }, + "text-image": { + "input_schema": TextImageInputSchema, + "data_item": { + "text": "An elephant is a large gray animal with a trunk and big ears. It makes a trumpeting sound.", + "age": 15, + "image_file": "tests/pytest/data/elephant.png", + "height": 304.5, + }, + }, + "text-audio": { + "input_schema": TextAudioInputSchema, + "data_item": { + "text": "An elephant is a large gray animal with a trunk and big ears. It makes a trumpeting sound.", + "age": 15, + "audio_file": "tests/pytest/data/elephant.wav", + "year": 2020, + }, + }, + "image-audio": { + "input_schema": ImageAudioInputSchema, + "data_item": { + "image_file": "tests/pytest/data/elephant.png", + "height": 304.5, + "audio_file": "tests/pytest/data/elephant.wav", + "year": 2020, + }, + }, + "text-image-audio": { + "input_schema": TextImageAudioInputSchema, + "data_item": { + "text": "An elephant is a large gray animal with a trunk and big ears. It makes a trumpeting sound.", + "age": 15, + "image_file": "tests/pytest/data/elephant.png", + "height": 304.5, + "audio_file": "tests/pytest/data/elephant.wav", + "year": 2020, + }, + }, +} + +# Maps provider name to (Model enum, supported modalities) +PROVIDER_CONFIGS = { + "anthropic": { + "model": Model.CLAUDE_4_5_SONNET, + "supported_modalities": ["text-only", "image-only", "text-image"], + }, + "openai": { + "model": Model.GPT_4o, + "supported_modalities": ["text-only", "image-only", "text-image"], + }, + "openai-audio": { + "model": Model.GPT_4o_AUDIO_PREVIEW, + "supported_modalities": ["audio-only", "text-audio"], + }, + "gemini": { + "model": Model.GOOGLE_GEMINI_2_5_FLASH, + "supported_modalities": [ + "text-only", "image-only", "audio-only", + "text-image", "text-audio", "image-audio", "text-image-audio", + ], + }, + "vertex_ai": { + "model": Model.GEMINI_2_5_FLASH, + "supported_modalities": [ + "text-only", "image-only", "audio-only", + "text-image", "text-audio", "image-audio", "text-image-audio", + ], + }, + "azure": { + "model": Model.AZURE_GPT_4o, + "supported_modalities": ["text-only", "image-only", "text-image"], + }, +} + + +def save_messages(modality: str, provider: str, messages: list[dict], output_dir: str) -> str: + """ + Save messages to a JSON file. 
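+    The file name follows the {modality}_{provider}.json convention that
+    scripts/capture_provider_stats.py later loads.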
+ + Args: + modality: Modality name + provider: Provider name + messages: List of message dicts + output_dir: Directory to save files + + Returns: + Path to the saved file + """ + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"{modality}_{provider}.json") + + # Convert messages to JSON-serializable format + serializable_messages = [] + for msg in messages: + serializable_msg = msg.copy() + serializable_messages.append(serializable_msg) + + with open(output_path, "w") as f: + json.dump(serializable_messages, f, indent=2, default=str) + + return output_path + + +def main(): + """Generate and save messages for all provider/modality combinations.""" + # Ensure the output directory follows the repository structure + output_dir = os.path.join( + os.path.dirname(__file__), + "..", + "tests", + "pytest", + "data", + "generator_messages", + ) + output_dir = os.path.abspath(output_dir) + + # Count total combinations + total_combinations = sum( + len(provider_config["supported_modalities"]) + for provider_config in PROVIDER_CONFIGS.values() + ) + + print(f"Generating test messages for {total_combinations} provider/modality combinations...") + print(f"Output directory: {output_dir}") + print(f"Static context length: ~{len(STATIC_CONTEXT.split())} words\n") + + generated_count = 0 + + for provider, provider_config in PROVIDER_CONFIGS.items(): + model = provider_config["model"] + supported_modalities = provider_config["supported_modalities"] + + print(f"Provider: {provider} (model: {model.value})") + print(f" Supported modalities: {supported_modalities}") + + for modality in supported_modalities: + config = MODALITY_CONFIGS[modality] + print(f" Generating: {modality}_{provider}") + + try: + # Prepare input record + input_schema = config["input_schema"] + data_item = config["data_item"] + input_record = DataRecord(input_schema(**data_item), source_indices=[0]) + + # Instantiate Generator + generator = Generator( + model=model, + prompt_strategy=PromptStrategy.MAP, + reasoning_effort=None, + desc=STATIC_CONTEXT, + ) + + # Generate unique session ID for this provider/modality to prevent cross-modality cache hits + session_id = generate_session_id(provider, modality) + + # Call the generator with the new flag + # Pass cache_isolation_id to inject session ID at start of system/user prompts + messages = generator( + candidate=input_record, + fields=OutputSchema.model_fields, + output_schema=OutputSchema, + generating_messages_only=True, + cache_isolation_id=session_id, + ) + + # Manually save the messages using local helper + output_path = save_messages(modality, provider, messages, output_dir) + + print(f" Session ID: {session_id}") + print(f" Saved to: {output_path}") + print(f" Messages: {len(messages)}") + + # Print message summary + for i, msg in enumerate(messages): + role = msg.get("role", "unknown") + msg_type = msg.get("type", "unknown") + content = msg.get("content", "") + content_len = len(content) if isinstance(content, str) else len(str(content)) + print(f" [{i}] role={role}, type={msg_type}, len={content_len}") + + generated_count += 1 + + except Exception as e: + print(f" ERROR: {e}") + import traceback + traceback.print_exc() + + print() + + print(f"Done! 
Generated {generated_count}/{total_combinations} message files.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/update_model_info.py b/scripts/update_model_info.py new file mode 100644 index 000000000..9a2830b4e --- /dev/null +++ b/scripts/update_model_info.py @@ -0,0 +1,668 @@ +#!/usr/bin/env python3 +""" +Script to automatically update pz_models_information.json with data from external sources. + +Data Sources: +- LiteLLM proxy /model/info endpoint: Dynamic model info (100% accuracy, prioritized) +- LiteLLM model_prices_and_context_window.json: Cost and capability data (fallback) +- MMLU-Pro leaderboard: Quality scores (fuzzy matching acceptable) +- Artificial Analysis: Latency data (fuzzy matching acceptable) + +Usage: + python scripts/update_model_info.py MODEL_ID [MODEL_ID ...] [--use-endpoint] +""" + +import argparse +import json +import os +import socket +import subprocess +import sys +import time +from typing import Any + +import requests +import yaml + +# Add src to path to import from palimpzest +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + +from palimpzest.utils.model_info_helpers import ( + LATENCY_TPS_DATA, + MMLU_PRO_SCORES, + derive_model_flags, + fuzzy_match_score, +) + +# Constants +LITELLM_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" +PZ_MODELS_PATH = os.path.join( + os.path.dirname(__file__), + "..", + "src", + "palimpzest", + "utils", + "pz_models_information.json", +) + +# Provider mapping from LiteLLM prefixes to our provider strings +PROVIDER_MAPPING = { + "openai": "openai", + "anthropic": "anthropic", + "claude": "anthropic", + "vertex_ai": "vertex_ai", + "gemini": "gemini", + "together_ai": "together_ai", + "together": "together_ai", + "hosted_vllm": "hosted_vllm", + "groq": "groq", + "mistral": "mistral", + "cohere": "cohere", + "bedrock": "bedrock", + "azure": "azure", + "deepseek": "deepseek", + "fireworks_ai": "fireworks_ai", + "xai": "xai", +} + +# API key environment variable mapping +API_KEY_MAPPING = { + "openai": "OPENAI_API_KEY", + "azure": "AZURE_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", + "vertex_ai": "GOOGLE_APPLICATION_CREDENTIALS", + "gemini": "GEMINI_API_KEY", + "together_ai": "TOGETHER_API_KEY", + "hosted_vllm": "VLLM_API_KEY", + "groq": "GROQ_API_KEY", + "mistral": "MISTRAL_API_KEY", + "cohere": "COHERE_API_KEY", + "deepseek": "DEEPSEEK_API_KEY", + "fireworks_ai": "FIREWORKS_API_KEY", + "xai": "XAI_API_KEY", +} + +# Field mapping from LiteLLM endpoint to PZ format +FIELD_MAPPING = [ + ("usd_per_input_token", "input_cost_per_token", None), + ("usd_per_output_token", "output_cost_per_token", None), + ("usd_per_audio_input_token", "input_cost_per_audio_token", None), + ("usd_per_audio_output_token", "output_cost_per_audio_token", None), + ("usd_per_image_output_token", "output_cost_per_image_token", None), + ("usd_per_cache_read_token", "cache_read_input_token_cost", None), + ("usd_per_cache_creation_token", "cache_creation_input_token_cost", None), + ("supports_prompt_caching", "supports_prompt_caching", False), +] + +# Boolean capability fields derived from endpoint +CAPABILITY_MAPPING = [ + ("is_vision_model", "supports_vision", False), + ("is_audio_model", "supports_audio_input", False), + ("is_reasoning_model", "supports_reasoning", False), +] + +# MMLU_PRO_SCORES, LATENCY_TPS_DATA, and fuzzy_match_score are imported from +# palimpzest.utils.model_info_helpers + +# Alias for backwards compatibility in this script 
+
+LATENCY_DATA = LATENCY_TPS_DATA
+
+
+# =============================================================================
+# LiteLLM Proxy Endpoint Functions
+# =============================================================================
+
+def get_free_port() -> int:
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        return s.getsockname()[1]
+
+
+def extract_provider(model_id: str) -> str:
+    """Extract provider from model ID."""
+    if "/" in model_id:
+        prefix = model_id.split("/")[0].lower()
+        return PROVIDER_MAPPING.get(prefix, prefix)
+
+    model_lower = model_id.lower()
+
+    # OpenAI
+    if any(x in model_lower for x in ["gpt", "o1-", "o3-", "o4-", "dall-e", "whisper"]):
+        return "openai"
+
+    # Anthropic
+    if "claude" in model_lower:
+        return "anthropic"
+
+    # Google (Vertex AI / Gemini)
+    if "gemini" in model_lower or "bison" in model_lower:
+        return "vertex_ai"
+
+    # Meta / Together / Llama
+    if "llama" in model_lower:
+        return "together_ai"
+
+    # Mistral
+    if "mistral" in model_lower or "mixtral" in model_lower:
+        return "mistral"
+
+    # DeepSeek
+    if "deepseek" in model_lower:
+        return "deepseek"
+
+    return "unknown"
+
+
+def get_api_key_env_var(provider: str) -> str | None:
+    return API_KEY_MAPPING.get(provider)
+
+
+def generate_config_yaml(model_ids: list[str]) -> str:
+    # Pick the first config filename that does not already exist, so we never
+    # overwrite a config file left behind by another run.
+    config_id = 0
+    config_filename = f"litellm_config_{config_id}.yaml"
+    while os.path.exists(config_filename):
+        config_id += 1
+        config_filename = f"litellm_config_{config_id}.yaml"
+
+    config_list = []
+    for model_id in model_ids:
+        provider = extract_provider(model_id)
+        env_var_name = get_api_key_env_var(provider)
+        api_key_val = f"os.environ/{env_var_name}" if env_var_name else None
+
+        entry = {
+            "model_name": model_id,
+            "litellm_params": {
+                "model": model_id,
+                "api_key": api_key_val,
+            },
+        }
+        config_list.append(entry)
+
+    yaml_structure = {"model_list": config_list}
+    with open(config_filename, "w") as f:
+        yaml.dump(yaml_structure, f, default_flow_style=False, sort_keys=False)
+
+    return config_filename
+
+
+def fetch_dynamic_model_info(model_ids: list[str]) -> dict[str, Any]:
+    if not model_ids:
+        return {}
+
+    port = get_free_port()
+    proxy_url = f"http://127.0.0.1:{port}"
+    config_filename = generate_config_yaml(model_ids)
+    server_env = os.environ.copy()
+    process = None
+    dynamic_model_info = {}
+
+    print(f"Starting LiteLLM proxy on port {port} for {len(model_ids)} models...")
+
+    try:
+        process = subprocess.Popen(
+            ["litellm", "--config", config_filename, "--port", str(port)],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            env=server_env,
+        )
+
+        server_ready = False
+        max_retries = 30
+        for i in range(max_retries):
+            if process.poll() is not None:
+                _, stderr = process.communicate()
+                print(f" LiteLLM process died unexpectedly: {stderr.decode()}")
+                break
+            try:
+                requests.get(f"{proxy_url}/health/readiness", timeout=1)
+                server_ready = True
+                print(f" Server ready after {i + 1} attempts")
+                break
+            except (requests.exceptions.ConnectionError, requests.exceptions.ReadTimeout):
+                time.sleep(0.5)
+
+        if not server_ready:
+            print(" Timeout: LiteLLM server failed to start within the limit.")
+            return {}
+
+        try:
+            response = requests.get(f"{proxy_url}/model/info", timeout=10)
+            response.raise_for_status()
+            model_data = response.json()
+
+            if "data" in model_data and len(model_data["data"]) > 0:
+                for item in model_data["data"]:
+                    model_name = item.get("model_name")
+                    model_info = item.get("model_info", {})
+                    dynamic_model_info[model_name] = model_info
+                    print(f" Retrieved info for: {model_name}")
+            else:
+                print("
WARNING: No model data returned from endpoint") + except Exception as e: + print(f" Error fetching model info: {e}") + + finally: + if process: + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + + if os.path.exists(config_filename): + os.remove(config_filename) + + return dynamic_model_info + + +# ============================================================================= +# Data Fetching Functions +# ============================================================================= + +def fetch_litellm_data() -> dict[str, Any]: + print(f"Fetching LiteLLM data from {LITELLM_URL}...") + try: + response = requests.get(LITELLM_URL, timeout=30) + response.raise_for_status() + data = response.json() + print(f" Found {len(data)} models in LiteLLM database") + return data + except Exception as e: + print(f" Error fetching LiteLLM data: {e}") + return {} + + +def load_existing_data() -> dict[str, Any]: + if os.path.exists(PZ_MODELS_PATH): + with open(PZ_MODELS_PATH) as f: + return json.load(f) + return {} + + +def save_data(data: dict[str, Any]) -> None: + with open(PZ_MODELS_PATH, "w") as f: + json.dump(data, f, indent=4) + print(f" [System] Successfully saved to {PZ_MODELS_PATH}") + + +# ============================================================================= +# Matching and Conversion Functions +# ============================================================================= + +# fuzzy_match_score is imported from palimpzest.utils.model_info_helpers + + +def derive_model_flags_with_provider(model_id: str, provider: str) -> dict[str, bool]: + """Wrapper around derive_model_flags that also adds provider-specific flags.""" + flags = derive_model_flags(model_id) + if provider == "hosted_vllm": + flags["is_vllm_model"] = True + return flags + + +# ============================================================================= +# Interactive Review Functions +# ============================================================================= + +def prompt_for_value(field_name: str, current_value: Any, value_type: str = "any") -> Any: + while True: + user_input = input(f" > Enter new value for '{field_name}' (or press Enter to keep current): ").strip() + if user_input == "": + return current_value + + try: + if user_input.lower() == "none": + return None + if value_type == "float": + return float(user_input) + elif value_type == "int": + return int(user_input) + elif value_type == "bool": + return user_input.lower() in ("true", "yes", "1", "y") + else: + try: + return json.loads(user_input) + except json.JSONDecodeError: + return user_input + except ValueError as e: + print(f" Invalid input: {e}. Try again.") + + +def review_field( + field_name: str, + value: Any, + from_endpoint: bool, + interactive: bool = True, + value_type: str = "any" +) -> tuple[Any, bool]: + """ + Review a single field. + Logic: + 1. If from_endpoint is True and value not None -> VERIFIED (return immediately) + 2. If interactive -> Ask User (1. Correct, 2. Incorrect) + """ + if from_endpoint and value is not None: + # Verified automatically by endpoint + return value, False + + if not interactive: + return value, False + + print(f"\n [Review] {field_name}: {value}") + if from_endpoint and value is None: + print(" (Source: Endpoint returned Null)") + else: + print(" (Source: Derived/Static/Fallback)") + + while True: + choice = input(" 1. Yes, information is correct\n 2. 
No, enter different value\n Choice [1]: ").strip() + if choice == "" or choice == "1": + return value, False + elif choice == "2": + new_value = prompt_for_value(field_name, value, value_type) + return new_value, True + else: + print(" Invalid choice.") + + +def convert_and_review_model( + model_id: str, + litellm_static: dict[str, Any] | None, + litellm_dynamic: dict[str, Any] | None, + existing_entry: dict[str, Any] | None, + interactive: bool = True, +) -> dict[str, Any]: + """ + 1. Aggregates all data into a Draft Entry. + 2. Displays the Draft Entry (User can see Current State). + 3. Iterates fields to Verify (Prioritizing endpoint). + """ + print(f"\n{'='*60}") + print(f"PROCESSING: {model_id}") + print(f"{'='*60}") + + # --- PHASE 1: Build Draft Entry & Source Map --- + + endpoint_fields: set[str] = set() + raw_data: dict[str, Any] = {} + + # 1. Base: Static Data + if litellm_static: + raw_data.update(litellm_static) + + # 2. Overlay: Dynamic Data (Priority) + if litellm_dynamic: + for key, val in litellm_dynamic.items(): + if val is not None: + raw_data[key] = val + endpoint_fields.add(key) + + # 3. Construct Candidate dictionary + candidate = {} + source_map = {} # Map field -> is_from_endpoint + + # Provider + prov = raw_data.get("litellm_provider") or extract_provider(model_id) + candidate["provider"] = prov + source_map["provider"] = "litellm_provider" in endpoint_fields + + # Costs & Caching + for pz_field, litellm_field, default in FIELD_MAPPING: + val = raw_data.get(litellm_field, default) + candidate[pz_field] = val + source_map[pz_field] = litellm_field in endpoint_fields + + # Capabilities + for pz_field, litellm_field, default in CAPABILITY_MAPPING: + val = raw_data.get(litellm_field, default) + # Special logic for audio + if pz_field == "is_audio_model": + audio_in = raw_data.get("supports_audio_input", False) + audio_out = raw_data.get("supports_audio_output", False) + val = audio_in or audio_out + source_map[pz_field] = ("supports_audio_input" in endpoint_fields or + "supports_audio_output" in endpoint_fields) + else: + source_map[pz_field] = litellm_field in endpoint_fields + candidate[pz_field] = val + + # Modes + mode = raw_data.get("mode", "chat") + mode_src = "mode" in endpoint_fields + candidate["is_text_model"] = mode in ["chat", "completion"] + source_map["is_text_model"] = mode_src + candidate["is_embedding_model"] = mode == "embedding" + source_map["is_embedding_model"] = mode_src + + # Flags (Always derived, never endpoint) + flags = derive_model_flags_with_provider(model_id, candidate["provider"]) + for k, v in flags.items(): + candidate[k] = v + source_map[k] = False + + # Scores / Latency (Fuzzy or Existing) + mmlu = fuzzy_match_score(model_id, MMLU_PRO_SCORES) + if mmlu is None and existing_entry: + mmlu = existing_entry.get("MMLU_Pro_score") + candidate["MMLU_Pro_score"] = mmlu + source_map["MMLU_Pro_score"] = False + + tps = fuzzy_match_score(model_id, LATENCY_DATA) + sec_per_tok = round(1.0 / tps, 6) if tps else None + if sec_per_tok is None and existing_entry: + sec_per_tok = existing_entry.get("seconds_per_output_token") + candidate["seconds_per_output_token"] = sec_per_tok + source_map["seconds_per_output_token"] = False + + # Audio Cache Read (check existing) + acr = existing_entry.get("usd_per_audio_cache_read_token") if existing_entry else None + if acr is not None: + candidate["usd_per_audio_cache_read_token"] = acr + source_map["usd_per_audio_cache_read_token"] = False + + # Note + if existing_entry and existing_entry.get("note"): + 
candidate["note"] = existing_entry["note"] + source_map["note"] = False + + # Sources + src_list = [LITELLM_URL] + if existing_entry and existing_entry.get("sources"): + existing_srcs = existing_entry["sources"] + if isinstance(existing_srcs, list): + src_list = list(set(src_list + existing_srcs)) + elif existing_srcs: + src_list = list(set(src_list + [existing_srcs])) + candidate["sources"] = src_list + + # --- PHASE 2: Display Current State --- + + print("\n--- Current State (Draft) ---") + display_dict = {} + for k, v in candidate.items(): + if k == "sources": + continue + src_label = "ENDPOINT" if source_map.get(k, False) and v is not None else "DERIVED/STATIC" + display_dict[k] = f"{v} [{src_label}]" + + print(json.dumps(display_dict, indent=2)) + print("-" * 30) + + # --- PHASE 3: Verification Loop --- + + final_entry = {} + final_entry["sources"] = candidate["sources"] + + # Iterate over specific keys to ensure order and types + + # Provider + final_entry["provider"], _ = review_field( + "provider", candidate["provider"], source_map["provider"], interactive, "str" + ) + + # All cost/cache fields + for k in [f[0] for f in FIELD_MAPPING] + ["usd_per_audio_cache_read_token"]: + if k in candidate: + vtype = "float" if "usd_" in k else "bool" + final_entry[k], _ = review_field( + k, candidate[k], source_map.get(k, False), interactive, vtype + ) + + # Capabilities & Modes + bool_keys = [f[0] for f in CAPABILITY_MAPPING] + ["is_text_model", "is_embedding_model"] + list(flags.keys()) + for k in bool_keys: + if k in candidate: + final_entry[k], _ = review_field( + k, candidate[k], source_map.get(k, False), interactive, "bool" + ) + + # Stats + final_entry["MMLU_Pro_score"], _ = review_field( + "MMLU_Pro_score", candidate["MMLU_Pro_score"], False, interactive, "float" + ) + final_entry["seconds_per_output_token"], _ = review_field( + "seconds_per_output_token", candidate["seconds_per_output_token"], False, interactive, "float" + ) + + # Note + if "note" in candidate: + final_entry["note"], _ = review_field( + "note", candidate["note"], False, interactive, "str" + ) + + # Cleanup Nulls + cleaned_entry = {k: v for k, v in final_entry.items() if v is not None} + + return cleaned_entry + + +def update_model( + model_id: str, + existing_data: dict[str, Any], + litellm_static: dict[str, Any], + litellm_dynamic: dict[str, Any] | None = None, + interactive: bool = True, +) -> dict[str, Any] | None: + static_entry = None + if model_id in litellm_static: + static_entry = litellm_static[model_id] + else: + if "/" in model_id: + model_name = model_id.split("/", 1)[1] + if model_name in litellm_static: + static_entry = litellm_static[model_name] + + dynamic_entry = litellm_dynamic.get(model_id) if litellm_dynamic else None + + if static_entry is None and dynamic_entry is None: + print(f"\n WARNING: No LiteLLM data found for {model_id}") + + existing_entry = existing_data.get(model_id) + + new_entry = convert_and_review_model( + model_id, + static_entry, + dynamic_entry, + existing_entry, + interactive=interactive, + ) + return new_entry + + +def process_models( + model_ids: list[str], + existing_data: dict[str, Any], + litellm_static: dict[str, Any], + use_endpoint: bool = False, + interactive: bool = True, + skip_existing: bool = False, +) -> None: + """ + Process models and (if interactive is True) ask user whether to write each one to file. 
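+    Writes are incremental: each confirmed entry is saved to the JSON file
+    immediately, so an interrupted run keeps everything accepted so far.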
+ """ + litellm_dynamic = None + if use_endpoint: + litellm_dynamic = fetch_dynamic_model_info(model_ids) + + # We work on the existing_data dictionary directly so we can save incrementally + current_data_state = existing_data.copy() + + for model_id in model_ids: + # Check if model exists and if we should skip it + if skip_existing and model_id in current_data_state: + print(f"\n [System] Model '{model_id}' already exists in file. Skipping.") + continue + + new_entry = update_model( + model_id, current_data_state, litellm_static, litellm_dynamic, + interactive=interactive + ) + + if new_entry: + # Display Final Result + print("\n" + "-"*30) + print(f"FINAL JSON FOR: {model_id}") + print(json.dumps(new_entry, indent=2)) + print("-" * 30) + + # Ask user to write to file + should_save = True + if interactive: + confirm = input(f"Write '{model_id}' to json file? [y/N]: ").strip().lower() + should_save = confirm == 'y' + + if should_save: + current_data_state[model_id] = new_entry + save_data(current_data_state) + else: + print(f" [System] Skipped saving {model_id}.") + + +# ============================================================================= +# Main Entry Point +# ============================================================================= + +def main(): + parser = argparse.ArgumentParser( + description="Update pz_models_information.json with external data sources", + formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("model_ids", nargs="*", help="Model IDs to update") + parser.add_argument("--use-endpoint", action="store_true", help="Fetch dynamic info") + parser.add_argument("--non-interactive", action="store_true", help="Skip review and auto-save") + parser.add_argument("--update-all", action="store_true", help="Update all existing") + + args = parser.parse_args() + + litellm_static = fetch_litellm_data() + if not litellm_static: + return + + existing_data = load_existing_data() + + skip_existing = False + if args.update_all: + model_ids = list(existing_data.keys()) + elif args.model_ids: + model_ids = args.model_ids + skip_existing = True + else: + parser.print_help() + return + + interactive = not args.non_interactive + + # Run the main processing loop + process_models( + model_ids, + existing_data, + litellm_static, + use_endpoint=args.use_endpoint, + interactive=interactive, + skip_existing=skip_existing, + ) + + print("\nAll operations complete.") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/palimpzest/constants.py b/src/palimpzest/constants.py index 4b0609b96..b8b31b0fc 100644 --- a/src/palimpzest/constants.py +++ b/src/palimpzest/constants.py @@ -1,133 +1,12 @@ ### This file contains constants used by Palimpzest ### +from __future__ import annotations + import os from enum import Enum +import litellm -# ENUMS -class Model(str, Enum): - """ - Model describes the underlying LLM which should be used to perform some operation - which requires invoking an LLM. It does NOT specify whether the model need be executed - remotely or locally (if applicable). 
- """ - LLAMA3_2_3B = "together_ai/meta-llama/Llama-3.2-3B-Instruct-Turbo" - LLAMA3_1_8B = "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo" - LLAMA3_3_70B = "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo" - LLAMA3_2_90B_V = "together_ai/meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo" - DEEPSEEK_V3 = "together_ai/deepseek-ai/DeepSeek-V3" - DEEPSEEK_R1_DISTILL_QWEN_1_5B = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" - GPT_4o = "openai/gpt-4o-2024-08-06" - GPT_4o_MINI = "openai/gpt-4o-mini-2024-07-18" - GPT_4_1 = "openai/gpt-4.1-2025-04-14" - GPT_4_1_MINI = "openai/gpt-4.1-mini-2025-04-14" - GPT_4_1_NANO = "openai/gpt-4.1-nano-2025-04-14" - GPT_5 = "openai/gpt-5-2025-08-07" - GPT_5_MINI = "openai/gpt-5-mini-2025-08-07" - GPT_5_NANO = "openai/gpt-5-nano-2025-08-07" - o4_MINI = "openai/o4-mini-2025-04-16" # noqa: N815 - # CLAUDE_3_5_SONNET = "anthropic/claude-3-5-sonnet-20241022" - CLAUDE_3_7_SONNET = "anthropic/claude-3-7-sonnet-20250219" - CLAUDE_3_5_HAIKU = "anthropic/claude-3-5-haiku-20241022" - GEMINI_2_0_FLASH = "vertex_ai/gemini-2.0-flash" - GEMINI_2_5_FLASH = "vertex_ai/gemini-2.5-flash" - GEMINI_2_5_PRO = "vertex_ai/gemini-2.5-pro" - GOOGLE_GEMINI_2_5_FLASH = "gemini/gemini-2.5-flash" - GOOGLE_GEMINI_2_5_FLASH_LITE = "gemini/gemini-2.5-flash-lite" - GOOGLE_GEMINI_2_5_PRO = "gemini/gemini-2.5-pro" - LLAMA_4_MAVERICK = "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas" - GPT_4o_AUDIO_PREVIEW = "openai/gpt-4o-audio-preview" - GPT_4o_MINI_AUDIO_PREVIEW = "openai/gpt-4o-mini-audio-preview" - VLLM_QWEN_1_5_0_5B_CHAT = "hosted_vllm/qwen/Qwen1.5-0.5B-Chat" - # o1 = "o1-2024-12-17" - TEXT_EMBEDDING_3_SMALL = "text-embedding-3-small" - CLIP_VIT_B_32 = "clip-ViT-B-32" - - def __repr__(self): - return f"{self.name}" - - def is_llama_model(self): - return "llama" in self.value.lower() - - def is_clip_model(self): - return "clip" in self.value.lower() - - def is_together_model(self): - return "together_ai" in self.value.lower() or self.is_clip_model() - - def is_text_embedding_model(self): - return "text-embedding" in self.value.lower() - - def is_o_model(self): - return self in [Model.o4_MINI] - - def is_gpt_5_model(self): - return self in [Model.GPT_5, Model.GPT_5_MINI, Model.GPT_5_NANO] - - def is_openai_model(self): - return "openai" in self.value.lower() or self.is_text_embedding_model() - - def is_anthropic_model(self): - return "anthropic" in self.value.lower() - - def is_vertex_model(self): - return "vertex_ai" in self.value.lower() - - def is_google_ai_studio_model(self): - return "gemini/" in self.value.lower() - - def is_vllm_model(self): - return "hosted_vllm" in self.value.lower() - - def is_reasoning_model(self): - reasoning_models = [ - Model.GPT_5, Model.GPT_5_MINI, Model.GPT_5_NANO, Model.o4_MINI, - Model.GEMINI_2_5_PRO, Model.GEMINI_2_5_FLASH, - Model.GOOGLE_GEMINI_2_5_PRO, Model.GOOGLE_GEMINI_2_5_FLASH, Model.GOOGLE_GEMINI_2_5_FLASH_LITE, - Model.CLAUDE_3_7_SONNET, - ] - return self in reasoning_models - - def is_text_model(self): - non_text_models = [ - Model.LLAMA3_2_90B_V, - Model.CLIP_VIT_B_32, Model.TEXT_EMBEDDING_3_SMALL, - Model.GPT_4o_AUDIO_PREVIEW, Model.GPT_4o_MINI_AUDIO_PREVIEW, - ] - return self not in non_text_models - - # TODO: I think SONNET and HAIKU are vision-capable too - def is_vision_model(self): - return self in [ - Model.LLAMA3_2_90B_V, Model.LLAMA_4_MAVERICK, - Model.GPT_4o, Model.GPT_4o_MINI, Model.GPT_4_1, Model.GPT_4_1_MINI, Model.GPT_4_1_NANO, Model.o4_MINI, Model.GPT_5, Model.GPT_5_MINI, Model.GPT_5_NANO, - 
Model.GEMINI_2_0_FLASH, Model.GEMINI_2_5_FLASH, Model.GEMINI_2_5_PRO, - Model.GOOGLE_GEMINI_2_5_PRO, Model.GOOGLE_GEMINI_2_5_FLASH, Model.GOOGLE_GEMINI_2_5_FLASH_LITE, - ] - - def is_audio_model(self): - return self in [ - Model.GPT_4o_AUDIO_PREVIEW, Model.GPT_4o_MINI_AUDIO_PREVIEW, - Model.GEMINI_2_0_FLASH, Model.GEMINI_2_5_FLASH, Model.GEMINI_2_5_PRO, - Model.GOOGLE_GEMINI_2_5_PRO, Model.GOOGLE_GEMINI_2_5_FLASH, Model.GOOGLE_GEMINI_2_5_FLASH_LITE, - ] - - def is_text_image_multimodal_model(self): - return self in [ - Model.LLAMA_4_MAVERICK, - Model.GPT_4o, Model.GPT_4o_MINI, Model.GPT_4_1, Model.GPT_4_1_MINI, Model.GPT_4_1_NANO, Model.o4_MINI, Model.GPT_5, Model.GPT_5_MINI, Model.GPT_5_NANO, - Model.GEMINI_2_0_FLASH, Model.GEMINI_2_5_FLASH, Model.GEMINI_2_5_PRO, - Model.GOOGLE_GEMINI_2_5_PRO, Model.GOOGLE_GEMINI_2_5_FLASH, Model.GOOGLE_GEMINI_2_5_FLASH_LITE, - ] - - def is_text_audio_multimodal_model(self): - return self in [ - Model.GPT_4o_AUDIO_PREVIEW, Model.GPT_4o_MINI_AUDIO_PREVIEW, - Model.GEMINI_2_0_FLASH, Model.GEMINI_2_5_FLASH, Model.GEMINI_2_5_PRO, - Model.GOOGLE_GEMINI_2_5_PRO, Model.GOOGLE_GEMINI_2_5_FLASH, Model.GOOGLE_GEMINI_2_5_FLASH_LITE, - ] - - def is_embedding_model(self): - return self in [Model.CLIP_VIT_B_32, Model.TEXT_EMBEDDING_3_SMALL] +from palimpzest.utils.model_info_helpers import ModelMetricsManager, predict_local_model_metrics class PromptStrategy(str, Enum): @@ -261,19 +140,19 @@ def log_attempt_number(retry_state): # Palimpzest root directory PZ_DIR = os.path.join(os.path.expanduser("~"), ".palimpzest") -# Assume 500 MB/sec for local SSD scan time +# assume 500 MB/sec for local SSD scan time LOCAL_SCAN_TIME_PER_KB = 1 / (float(500) * 1024) -# Assume 30 GB/sec for sequential access of memory +# assume 30 GB/sec for sequential access of memory MEMORY_SCAN_TIME_PER_KB = 1 / (float(30) * 1024 * 1024) -# Assume 1 KB per record +# assume 1 KB per record NAIVE_BYTES_PER_RECORD = 1024 -# Rough conversion from # of characters --> # of tokens; assumes 1 token ~= 4 chars +# rough conversion from # of characters --> # of tokens; assumes 1 token ~= 4 chars TOKENS_PER_CHARACTER = 0.25 -# Rough estimate of the number of tokens the context is allowed to take up for LLAMA3 models +# rough estimate of the number of tokens the context is allowed to take up for LLAMA3 models LLAMA_CONTEXT_TOKENS_LIMIT = 6000 # a naive estimate for the input record size @@ -303,9 +182,292 @@ def log_attempt_number(retry_state): # a naive estimate of the time it takes to extract the text from a PDF using a PDF processor NAIVE_PDF_PROCESSOR_TIME_PER_RECORD = 10.0 -# Whether or not to log LLM outputs +# whether or not to log LLM outputs LOG_LLM_OUTPUT = False +# maximum number of models to use when user does not narrow optimization space +MAX_AVAILABLE_MODELS = 5 + +class Model: + """ + Model describes the underlying LLM which should be used to perform some operation + which requires invoking an LLM. 
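+
+    Well-known models are attached to the class as attributes after the class
+    definition (e.g. Model.GPT_4o), and every constructed instance is recorded
+    in an internal registry. A locally served model can be constructed directly,
+    e.g. (sketch; the api_base URL is illustrative):
+
+        Model("hosted_vllm/qwen/Qwen1.5-0.5B-Chat", api_base="http://localhost:8000/v1")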
+ """ + # Registry of known models (maps value string to Model instance) + _registry: dict[str, Model] = {} + + def __init__(self, model_id: str, api_base: str | None = None, **vllm_kwargs): + self.metrics_manager = ModelMetricsManager() + self.model_id = model_id + self.api_base = api_base + self.vllm_kwargs = vllm_kwargs if vllm_kwargs else {} + + # For vLLM models (api_base is set), try to get model info from litellm's local data + if api_base is not None: + self.model_specs = self._get_litellm_model_specs(model_id) + else: + self.model_specs = self.metrics_manager.get_model_metrics(model_id) + if not self.model_specs: + raise ValueError("Palimpzest currently does not contain information about this model.") + + Model._registry[model_id] = self + + def _get_litellm_model_specs(self, model_id: str) -> dict: + """Get model specs from litellm's local model_cost data for vLLM models.""" + # Use predict function to get quality, latency metrics, and capability flags + predicted_metrics = predict_local_model_metrics(model_id) + + # Start with defaults, then overlay predicted values + specs = { + "is_text_model": True, + "is_vision_model": False, + "is_llama_model": False, + "is_clip_model": False, + "is_audio_model": False, + "is_reasoning_model": False, + "is_embedding_model": False, + "is_text_image_multimodal_embedding_model": False, + "is_vllm_model": True, # Mark as vLLM model + "usd_per_input_token": 0.0, # Cost always 0 for local model + "usd_per_output_token": 0.0, + "seconds_per_output_token": predicted_metrics["seconds_per_output_token"], + "MMLU_Pro_score": predicted_metrics["MMLU_Pro_score"], + } + + # Overlay all flags detected from model name (including False values like is_text_model for embeddings) + for key, value in predicted_metrics.items(): + if key.startswith("is_"): + specs[key] = value + + # Try litellm for additional capability detection (may not work for local models) + try: + if litellm.supports_vision(model=model_id): + specs["is_vision_model"] = True + except Exception: + pass + + try: + if litellm.supports_audio_input(model=model_id): + specs["is_audio_model"] = True + except Exception: + pass + + return specs + + def __lt__(self, other): + if isinstance(other, Model): + return self.value < other.value + if isinstance(other, str): + return self.value < other + return NotImplemented + + @classmethod + def get_all_models(cls) -> list[Model]: + return list(cls._registry.values()) + + @property + def value(self) -> str: + return self.model_id + + @property + def provider(self) -> str | None: + """Returns the provider string for this model.""" + return self.model_specs.get("provider") + + @property + def api_key_env_var(self) -> str | None: + """ + Returns the standard environment variable name for this provider's API key. 
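+        For gemini and azure, two variable names are in common use, so whichever
+        one is actually set in the environment is preferred.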
+        """
+        if self.provider == "gemini":
+            return "GEMINI_API_KEY" if os.getenv("GEMINI_API_KEY") else "GOOGLE_API_KEY"
+        if self.provider == "azure":
+            return "AZURE_API_KEY" if os.getenv("AZURE_API_KEY") else "AZURE_OPENAI_API_KEY"
+        mapping = {
+            "openai": "OPENAI_API_KEY",
+            "vertex_ai": "GOOGLE_APPLICATION_CREDENTIALS",
+            "anthropic": "ANTHROPIC_API_KEY",
+            "together_ai": "TOGETHER_API_KEY",
+            "hosted_vllm": "VLLM_API_KEY",
+        }
+        return mapping.get(self.provider)
+
+    def __repr__(self) -> str:
+        return self.value
+
+    def __str__(self) -> str:
+        return self.value
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, Model):
+            return self.value == other.value
+        if isinstance(other, str):
+            return self.value == other
+        return NotImplemented
+
+    def __hash__(self) -> int:
+        return hash(self.value)
+
+    def is_llama_model(self) -> bool:
+        return self.model_specs.get("is_llama_model", False)
+
+    def is_vllm_model(self) -> bool:
+        return self.model_specs.get("is_vllm_model", False) and self.api_base is not None
+
+    def is_embedding_model(self) -> bool:
+        return self.model_specs.get("is_embedding_model", False)
+
+    def is_text_image_multimodal_embedding_model(self) -> bool:
+        return self.model_specs.get("is_text_image_multimodal_embedding_model", False)
+
+    def is_provider_vertex_ai(self) -> bool:
+        return self.provider == "vertex_ai"
+
+    def is_provider_anthropic(self) -> bool:
+        return self.provider == "anthropic"
+
+    def is_provider_google_ai_studio(self) -> bool:
+        return self.provider == "gemini"
+
+    def is_provider_openai(self) -> bool:
+        return self.provider == "openai"
+
+    def is_provider_azure(self) -> bool:
+        return self.provider == "azure"
+
+    def is_provider_together_ai(self) -> bool:
+        return self.provider == "together_ai"
+
+    def is_provider_deepseek(self) -> bool:
+        return self.provider == "deepseek"
+
+    def is_provider_ollama(self) -> bool:
+        return self.provider == "ollama"
+
+    def is_model_gemini(self) -> bool:
+        return "gemini" in self.value.lower()
+
+    def get_model_name(self) -> str:
+        return self.value.split("/")[-1] if "/" in self.value else self.value
+
+    def is_o_model(self) -> bool:
+        return self.model_specs.get("is_o_model", False)
+
+    def is_gpt_5_model(self) -> bool:
+        return self.model_specs.get("is_gpt_5_model", False)
+
+    def is_reasoning_model(self) -> bool:
+        return self.model_specs.get("is_reasoning_model", False)
+
+    def is_text_model(self) -> bool:
+        return self.model_specs.get("is_text_model", False)
+
+    def is_vision_model(self) -> bool:
+        return self.model_specs.get("is_vision_model", False)
+
+    def is_audio_model(self) -> bool:
+        return self.model_specs.get("is_audio_model", False)
+
+    def is_text_image_multimodal_model(self) -> bool:
+        return self.is_text_model() and self.is_vision_model()
+
+    def is_text_audio_multimodal_model(self) -> bool:
+        return self.is_audio_model() and self.is_text_model()
+
+    def supports_prompt_caching(self) -> bool:
+        return (self.is_provider_anthropic() or self.is_provider_google_ai_studio() or self.is_provider_vertex_ai() or self.is_provider_openai() or self.is_provider_azure()) \
+            and self.model_specs.get("supports_prompt_caching", False)
+
+    def get_usd_per_input_token(self) -> float:
+        return self.model_specs.get("usd_per_input_token", 0.0)
+
+    def get_usd_per_audio_input_token(self) -> float:
+        return self.model_specs.get("usd_per_audio_input_token", self.get_usd_per_input_token())
+
+    # forward-looking; TODO: decide on the right default value
+    def get_usd_per_image_input_token(self) -> float:
+        return self.model_specs.get("usd_per_image_input_token", self.get_usd_per_input_token())
+
+    def get_usd_per_cache_read_token(self) -> float:
+        return self.model_specs.get("usd_per_cache_read_token", self.get_usd_per_input_token())
+
+    def get_usd_per_audio_cache_read_token(self) -> float:
+        return self.model_specs.get("usd_per_audio_cache_read_token", self.get_usd_per_cache_read_token())
+
+    def get_usd_per_image_cache_read_token(self) -> float:
+        return self.model_specs.get("usd_per_image_cache_read_token", self.get_usd_per_cache_read_token())
+
+    # forward-looking; used for Gemini explicit caching
+    def get_usd_per_cached_token_per_hour(self) -> float:
+        return self.model_specs.get("usd_per_cached_token_per_hour", 0.0)
+
+    def get_usd_per_cache_creation_token(self) -> float:
+        return self.model_specs.get("usd_per_cache_creation_token", 0.0)
+
+    def get_usd_per_output_token(self) -> float:
+        return self.model_specs.get("usd_per_output_token", 0.0)
+
+    # forward-looking
+    def get_usd_per_audio_cache_creation_token(self) -> float:
+        return self.model_specs.get("usd_per_audio_cache_creation_token", 0.0)
+
+    # forward-looking
+    def get_usd_per_image_cache_creation_token(self) -> float:
+        return self.model_specs.get("usd_per_image_cache_creation_token", 0.0)
+
+    def get_seconds_per_output_token(self) -> float:
+        return self.model_specs.get("seconds_per_output_token", 0.0)
+
+    def get_overall_score(self) -> float:
+        return self.model_specs.get("MMLU_Pro_score", 0.0)
+
+# TODO: investigate which (if any) llama3 models are still supported by TogetherAI
+# Model.LLAMA3_2_3B = Model("together_ai/meta-llama/Llama-3.2-3B-Instruct-Turbo") - seems to be deprecated
+Model.LLAMA3_1_8B = Model("together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo")
+Model.LLAMA3_3_70B = Model("together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo")
+Model.LLAMA3_2_90B_V = Model("together_ai/meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo")
+Model.DEEPSEEK_V3 = Model("together_ai/deepseek-ai/DeepSeek-V3")
+Model.DEEPSEEK_R1_DISTILL_QWEN_1_5B = Model("together_ai/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
+Model.GPT_4o = Model("openai/gpt-4o-2024-08-06")
+Model.GPT_4o_MINI = Model("openai/gpt-4o-mini-2024-07-18")
+Model.GPT_4_1 = Model("openai/gpt-4.1-2025-04-14")
+Model.GPT_4_1_MINI = Model("openai/gpt-4.1-mini-2025-04-14")
+Model.GPT_4_1_NANO = Model("openai/gpt-4.1-nano-2025-04-14")
+Model.GPT_5 = Model("openai/gpt-5-2025-08-07")
+Model.GPT_5_MINI = Model("openai/gpt-5-mini-2025-08-07")
+Model.GPT_5_NANO = Model("openai/gpt-5-nano-2025-08-07")
+Model.GPT_5_2 = Model("openai/gpt-5.2-2025-12-11")
+Model.o4_MINI = Model("openai/o4-mini-2025-04-16")  # noqa: N815
+# Model.CLAUDE_3_5_SONNET = Model("anthropic/claude-3-5-sonnet-20241022") - retired 10/28/2025
+Model.CLAUDE_3_7_SONNET = Model("anthropic/claude-3-7-sonnet-20250219")
+Model.CLAUDE_4_SONNET = Model("anthropic/claude-sonnet-4-20250514")
+Model.CLAUDE_4_5_SONNET = Model("anthropic/claude-sonnet-4-5-20250929")
+Model.CLAUDE_3_5_HAIKU = Model("anthropic/claude-3-5-haiku-20241022")
+Model.CLAUDE_4_5_HAIKU = Model("anthropic/claude-haiku-4-5-20251001")
+Model.GEMINI_3_0_PRO = Model("vertex_ai/gemini-3-pro-preview")  # image
+Model.GEMINI_3_0_FLASH = Model("vertex_ai/gemini-3-flash-preview")  # Text, Image, Video, Audio, and PDF
+Model.GEMINI_2_0_FLASH = Model("vertex_ai/gemini-2.0-flash")
+Model.GEMINI_2_5_FLASH = Model("vertex_ai/gemini-2.5-flash")
+Model.GEMINI_2_5_PRO = Model("vertex_ai/gemini-2.5-pro")
+Model.GOOGLE_GEMINI_3_0_PRO = Model("gemini/gemini-3-pro-preview")
+Model.GOOGLE_GEMINI_3_0_FLASH = Model("gemini/gemini-3-flash-preview") +Model.GOOGLE_GEMINI_2_5_FLASH = Model("gemini/gemini-2.5-flash") +Model.GOOGLE_GEMINI_2_5_FLASH_LITE = Model("gemini/gemini-2.5-flash-lite") +Model.GOOGLE_GEMINI_2_5_PRO = Model("gemini/gemini-2.5-pro") +Model.LLAMA_4_MAVERICK = Model("vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas") +Model.GPT_4o_AUDIO_PREVIEW = Model("openai/gpt-4o-audio-preview") +Model.GPT_4o_MINI_AUDIO_PREVIEW = Model("openai/gpt-4o-mini-audio-preview") +Model.AZURE_GPT_4o = Model("azure/gpt-4o-2024-08-06") +Model.AZURE_GPT_4o_MINI = Model("azure/gpt-4o-mini-2024-07-18") +Model.AZURE_GPT_4_1 = Model("azure/gpt-4.1-2025-04-14") +Model.AZURE_GPT_4_1_MINI = Model("azure/gpt-4.1-mini-2025-04-14") +Model.AZURE_GPT_4_1_NANO = Model("azure/gpt-4.1-nano-2025-04-14") +Model.AZURE_o4_MINI = Model("azure/o4-mini-2025-04-16") # noqa: N815 +Model.AZURE_GPT_4o_AUDIO_PREVIEW = Model("azure/gpt-4o-audio-preview") +Model.AZURE_GPT_4o_MINI_AUDIO_PREVIEW = Model("azure/gpt-4o-mini-audio-preview") +Model.TEXT_EMBEDDING_3_SMALL = Model("openai/text-embedding-3-small") +Model.CLIP_VIT_B_32 = Model("clip-ViT-B-32") +Model.NOMIC_EMBED_TEXT = Model("ollama/nomic-embed-text") #### MODEL PERFORMANCE & COST METRICS #### # Overall model quality is computed using MMLU-Pro; multi-modal models currently use the same score for vision @@ -321,315 +483,5 @@ def log_attempt_number(retry_state): # from the internet for this quick POC, but we can and should do more to model these # values more precisely: # - https://artificialanalysis.ai/models/llama-3-1-instruct-8b -# -LLAMA3_2_3B_INSTRUCT_MODEL_CARD = { - ##### Cost in USD ##### - "usd_per_input_token": 0.06 / 1e6, - "usd_per_output_token": 0.06 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0079, - ##### Agg. Benchmark ##### - "overall": 36.50, # https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct/discussions/13 -} -LLAMA3_1_8B_INSTRUCT_MODEL_CARD = { - ##### Cost in USD ##### - "usd_per_input_token": 0.18 / 1e6, - "usd_per_output_token": 0.18 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0050, - ##### Agg. Benchmark ##### - "overall": 44.25, -} -LLAMA3_3_70B_INSTRUCT_MODEL_CARD = { - ##### Cost in USD ##### - "usd_per_input_token": 0.88 / 1e6, - "usd_per_output_token": 0.88 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0122, - ##### Agg. Benchmark ##### - "overall": 69.9, -} -LLAMA3_2_90B_V_MODEL_CARD = { - ##### Cost in USD ##### - "usd_per_input_token": 1.2 / 1e6, - "usd_per_output_token": 1.2 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0303, - ##### Agg. Benchmark ##### - "overall": 65.00, # set to be slightly higher than gpt-4o-mini -} -DEEPSEEK_V3_MODEL_CARD = { - ##### Cost in USD ##### - "usd_per_input_token": 1.25 / 1E6, - "usd_per_output_token": 1.25 / 1E6, - ##### Time ##### - "seconds_per_output_token": 0.0114, - ##### Agg. Benchmark ##### - "overall": 73.8, -} -DEEPSEEK_R1_DISTILL_QWEN_1_5B_MODEL_CARD = { - ##### Cost in USD ##### - "usd_per_input_token": 0.18 / 1E6, - "usd_per_output_token": 0.18 / 1E6, - ##### Time ##### - "seconds_per_output_token": 0.0050, # NOTE: copied to be same as LLAMA3_1_8B_INSTRUCT_MODEL_CARD; need to update when we have data - ##### Agg. 
Benchmark ##### - "overall": 39.90, # https://www.reddit.com/r/LocalLLaMA/comments/1iserf9/deepseek_r1_distilled_models_mmlu_pro_benchmarks/ -} -GPT_4o_AUDIO_PREVIEW_MODEL_CARD = { - # NOTE: COPYING OVERALL AND SECONDS_PER_OUTPUT_TOKEN FROM GPT_4o; need to update when we have audio-specific benchmarks - ##### Cost in USD ##### - "usd_per_audio_input_token": 2.5 / 1e6, - "usd_per_output_token": 10.0 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0080, - ##### Agg. Benchmark ##### - "overall": 74.1, -} -GPT_4o_MINI_AUDIO_PREVIEW_MODEL_CARD = { - # NOTE: COPYING OVERALL AND SECONDS_PER_OUTPUT_TOKEN FROM GPT_4o; need to update when we have audio-specific benchmarks - ##### Cost in USD ##### - "usd_per_audio_input_token": 0.15 / 1e6, - "usd_per_output_token": 0.6 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0159, - ##### Agg. Benchmark ##### - "overall": 62.7, -} -GPT_4o_MODEL_CARD = { - # NOTE: it is unclear if the same ($ / token) costs can be applied for vision, or if we have to calculate this ourselves - ##### Cost in USD ##### - "usd_per_input_token": 2.5 / 1e6, - "usd_per_output_token": 10.0 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0080, - ##### Agg. Benchmark ##### - "overall": 74.1, -} -GPT_4o_MINI_MODEL_CARD = { - # NOTE: it is unclear if the same ($ / token) costs can be applied for vision, or if we have to calculate this ourselves - ##### Cost in USD ##### - "usd_per_input_token": 0.15 / 1e6, - "usd_per_output_token": 0.6 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0159, - ##### Agg. Benchmark ##### - "overall": 62.7, -} -GPT_4_1_MODEL_CARD = { - # NOTE: it is unclear if the same ($ / token) costs can be applied for vision, or if we have to calculate this ourselves - ##### Cost in USD ##### - "usd_per_input_token": 2.0 / 1e6, - "usd_per_output_token": 8.0 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0076, - ##### Agg. Benchmark ##### - "overall": 80.5, -} -GPT_4_1_MINI_MODEL_CARD = { - # NOTE: it is unclear if the same ($ / token) costs can be applied for vision, or if we have to calculate this ourselves - ##### Cost in USD ##### - "usd_per_input_token": 0.4 / 1e6, - "usd_per_output_token": 1.6 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0161, - ##### Agg. Benchmark ##### - "overall": 77.2, -} -GPT_4_1_NANO_MODEL_CARD = { - # NOTE: it is unclear if the same ($ / token) costs can be applied for vision, or if we have to calculate this ourselves - ##### Cost in USD ##### - "usd_per_input_token": 0.1 / 1e6, - "usd_per_output_token": 0.4 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0060, - ##### Agg. Benchmark ##### - "overall": 62.3, -} -GPT_5_MODEL_CARD = { - # NOTE: it is unclear if the same ($ / token) costs can be applied for vision, or if we have to calculate this ourselves - ##### Cost in USD ##### - "usd_per_input_token": 1.25 / 1e6, - "usd_per_output_token": 10.0 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0060, - ##### Agg. Benchmark ##### - "overall": 87.00, -} -GPT_5_MINI_MODEL_CARD = { - # NOTE: it is unclear if the same ($ / token) costs can be applied for vision, or if we have to calculate this ourselves - ##### Cost in USD ##### - "usd_per_input_token": 0.25 / 1e6, - "usd_per_output_token": 2.0 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0135, - ##### Agg. 
Benchmark ##### - "overall": 82.50, -} -GPT_5_NANO_MODEL_CARD = { - # NOTE: it is unclear if the same ($ / token) costs can be applied for vision, or if we have to calculate this ourselves - ##### Cost in USD ##### - "usd_per_input_token": 0.05 / 1e6, - "usd_per_output_token": 0.4 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0055, - ##### Agg. Benchmark ##### - "overall": 77.9, -} -o4_MINI_MODEL_CARD = { # noqa: N816 - # NOTE: it is unclear if the same ($ / token) costs can be applied for vision, or if we have to calculate this ourselves - ##### Cost in USD ##### - "usd_per_input_token": 1.1 / 1e6, - "usd_per_output_token": 4.4 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0092, - ##### Agg. Benchmark ##### - "overall": 80.6, # using number reported for o3-mini; true number is likely higher -} -# o1_MODEL_CARD = { # noqa: N816 -# # NOTE: it is unclear if the same ($ / token) costs can be applied for vision, or if we have to calculate this ourselves -# ##### Cost in USD ##### -# "usd_per_input_token": 15 / 1e6, -# "usd_per_output_token": 60 / 1e6, -# ##### Time ##### -# "seconds_per_output_token": 0.0110, -# ##### Agg. Benchmark ##### -# "overall": 83.50, -# } -TEXT_EMBEDDING_3_SMALL_MODEL_CARD = { - ##### Cost in USD ##### - "usd_per_input_token": 0.02 / 1e6, - "usd_per_output_token": None, - ##### Time ##### - "seconds_per_output_token": 0.0098, # NOTE: just copying GPT_4o_MINI_MODEL_CARD for now - ##### Agg. Benchmark ##### - "overall": 63.09, # NOTE: just copying GPT_4o_MINI_MODEL_CARD for now -} -CLIP_VIT_B_32_MODEL_CARD = { - ##### Cost in USD ##### - "usd_per_input_token": 0.00, - "usd_per_output_token": None, - ##### Time ##### - "seconds_per_output_token": 0.0098, # NOTE: just copying TEXT_EMBEDDING_3_SMALL_MODEL_CARD for now - ##### Agg. Benchmark ##### - "overall": 63.3, # NOTE: imageNet top-1 accuracy -} -CLAUDE_3_5_SONNET_MODEL_CARD = { - ##### Cost in USD ##### - "usd_per_input_token": 3.0 / 1e6, - "usd_per_output_token": 15.0 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0154, - ##### Agg. Benchmark ##### - "overall": 78.4, -} -CLAUDE_3_7_SONNET_MODEL_CARD = { - ##### Cost in USD ##### - "usd_per_input_token": 3.0 / 1e6, - "usd_per_output_token": 15.0 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0156, - ##### Agg. Benchmark ##### - "overall": 80.7, -} -CLAUDE_3_5_HAIKU_MODEL_CARD = { - ##### Cost in USD ##### - "usd_per_input_token": 0.8 / 1e6, - "usd_per_output_token": 4.0 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0189, - ##### Agg. Benchmark ##### - "overall": 64.1, -} -GEMINI_2_0_FLASH_MODEL_CARD = { - ##### Cost in USD ##### - "usd_per_input_token": 0.15 / 1e6, - "usd_per_output_token": 0.6 / 1e6, - "usd_per_audio_input_token": 1.0 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0054, - ##### Agg. Benchmark ##### - "overall": 77.40, -} -GEMINI_2_5_FLASH_LITE_MODEL_CARD = { - ##### Cost in USD ##### - "usd_per_input_token": 0.1 / 1e6, - "usd_per_output_token": 0.4 / 1e6, - "usd_per_audio_input_token": 0.3 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0034, - ##### Agg. Benchmark ##### - "overall": 79.1, # NOTE: interpolated between gemini 2.5 flash and gemini 2.0 flash -} -GEMINI_2_5_FLASH_MODEL_CARD = { - ##### Cost in USD ##### - "usd_per_input_token": 0.30 / 1e6, - "usd_per_output_token": 2.5 / 1e6, - "usd_per_audio_input_token": 1.0 / 1e6, - ##### Time ##### - "seconds_per_output_token": 0.0044, - ##### Agg. 
Benchmark #####
-    "overall": 80.75,  # NOTE: interpolated between gemini 2.0 flash and gemini 2.5 pro
-}
-GEMINI_2_5_PRO_MODEL_CARD = {
-    ##### Cost in USD #####
-    "usd_per_input_token": 1.25 / 1e6,
-    "usd_per_output_token": 10.0 / 1e6,
-    "usd_per_audio_input_token": 1.25 / 1e6,
-    ##### Time #####
-    "seconds_per_output_token": 0.0072,
-    ##### Agg. Benchmark #####
-    "overall": 84.10,
-}
-LLAMA_4_MAVERICK_MODEL_CARD = {
-    ##### Cost in USD #####
-    "usd_per_input_token": 0.35 / 1e6,
-    "usd_per_output_token": 1.15 / 1e6,
-    ##### Time #####
-    "seconds_per_output_token": 0.0122,
-    ##### Agg. Benchmark #####
-    "overall": 79.4,
-}
-VLLM_QWEN_1_5_0_5B_CHAT_MODEL_CARD = {
-    ##### Cost in USD #####
-    "usd_per_input_token": 0.0 / 1e6,
-    "usd_per_output_token": 0.0 / 1e6,
-    ##### Time #####
-    "seconds_per_output_token": 0.1000,  # TODO: fill-in with a better estimate
-    ##### Agg. Benchmark #####
-    "overall": 30.0,  # TODO: fill-in with a better estimate
-}
-
-MODEL_CARDS = {
-    Model.LLAMA3_2_3B.value: LLAMA3_2_3B_INSTRUCT_MODEL_CARD,
-    Model.LLAMA3_1_8B.value: LLAMA3_1_8B_INSTRUCT_MODEL_CARD,
-    Model.LLAMA3_3_70B.value: LLAMA3_3_70B_INSTRUCT_MODEL_CARD,
-    Model.LLAMA3_2_90B_V.value: LLAMA3_2_90B_V_MODEL_CARD,
-    Model.DEEPSEEK_V3.value: DEEPSEEK_V3_MODEL_CARD,
-    Model.DEEPSEEK_R1_DISTILL_QWEN_1_5B.value: DEEPSEEK_R1_DISTILL_QWEN_1_5B_MODEL_CARD,
-    Model.GPT_4o.value: GPT_4o_MODEL_CARD,
-    Model.GPT_4o_MINI.value: GPT_4o_MINI_MODEL_CARD,
-    Model.GPT_4o_AUDIO_PREVIEW.value: GPT_4o_AUDIO_PREVIEW_MODEL_CARD,
-    Model.GPT_4o_MINI_AUDIO_PREVIEW.value: GPT_4o_MINI_AUDIO_PREVIEW_MODEL_CARD,
-    Model.GPT_4_1.value: GPT_4_1_MODEL_CARD,
-    Model.GPT_4_1_MINI.value: GPT_4_1_MINI_MODEL_CARD,
-    Model.GPT_4_1_NANO.value: GPT_4_1_NANO_MODEL_CARD,
-    Model.GPT_5.value: GPT_5_MODEL_CARD,
-    Model.GPT_5_MINI.value: GPT_5_MINI_MODEL_CARD,
-    Model.GPT_5_NANO.value: GPT_5_NANO_MODEL_CARD,
-    Model.o4_MINI.value: o4_MINI_MODEL_CARD,
-    # Model.o1.value: o1_MODEL_CARD,
-    Model.TEXT_EMBEDDING_3_SMALL.value: TEXT_EMBEDDING_3_SMALL_MODEL_CARD,
-    Model.CLIP_VIT_B_32.value: CLIP_VIT_B_32_MODEL_CARD,
-    # Model.CLAUDE_3_5_SONNET.value: CLAUDE_3_5_SONNET_MODEL_CARD,
-    Model.CLAUDE_3_7_SONNET.value: CLAUDE_3_7_SONNET_MODEL_CARD,
-    Model.CLAUDE_3_5_HAIKU.value: CLAUDE_3_5_HAIKU_MODEL_CARD,
-    Model.GEMINI_2_0_FLASH.value: GEMINI_2_0_FLASH_MODEL_CARD,
-    Model.GEMINI_2_5_FLASH.value: GEMINI_2_5_FLASH_MODEL_CARD,
-    Model.GEMINI_2_5_PRO.value: GEMINI_2_5_PRO_MODEL_CARD,
-    Model.GOOGLE_GEMINI_2_5_FLASH.value: GEMINI_2_5_FLASH_MODEL_CARD,
-    Model.GOOGLE_GEMINI_2_5_FLASH_LITE.value: GEMINI_2_5_FLASH_LITE_MODEL_CARD,
-    Model.GOOGLE_GEMINI_2_5_PRO.value: GEMINI_2_5_PRO_MODEL_CARD,
-    Model.LLAMA_4_MAVERICK.value: LLAMA_4_MAVERICK_MODEL_CARD,
-    Model.VLLM_QWEN_1_5_0_5B_CHAT.value: VLLM_QWEN_1_5_0_5B_CHAT_MODEL_CARD,
-}
+#
+# Model metrics are now fetched from a single JSON file: curated_model_info.json
diff --git a/src/palimpzest/core/data/dataset.py b/src/palimpzest/core/data/dataset.py
index 25cff1d02..d6e1853db 100644
--- a/src/palimpzest/core/data/dataset.py
+++ b/src/palimpzest/core/data/dataset.py
@@ -592,11 +592,10 @@ def sem_agg(self, col: dict | type[BaseModel], agg: str, depends_on: str | list[
         # construct new output schema
         new_output_schema = None
         if isinstance(col, dict):
-            col_schema = create_schema_from_fields([col])
-            new_output_schema = union_schemas([self.schema, col_schema])
+            new_output_schema = create_schema_from_fields([col])
         elif issubclass(col, BaseModel):
             assert len(col.model_fields) == 1, "For semantic aggregation, when passing a BaseModel to `col` it must have exactly one field."
-            new_output_schema = union_schemas([self.schema, col])
+            new_output_schema = col
         else:
             raise ValueError("`col` must be a dictionary or a single-field BaseModel.")
@@ -717,6 +716,7 @@ def optimize_and_run(self, config: QueryProcessorConfig | None = None, train_dat
         policy = construct_policy_from_kwargs(**kwargs)
         if policy is not None:
             kwargs["policy"] = policy
+            config.policy = policy
 
         # construct unique logical op ids for all operators in this dataset
         self._generate_unique_logical_op_ids()
diff --git a/src/palimpzest/core/models.py b/src/palimpzest/core/models.py
index a54ba484b..d5d7ff34b 100644
--- a/src/palimpzest/core/models.py
+++ b/src/palimpzest/core/models.py
@@ -18,36 +18,30 @@ class GenerationStats(BaseModel):
     # The raw answer as output from the generator (a list of strings, possibly of len 1)
     # raw_answers: Optional[List[str]] = field(default_factory=list)
 
-    # the number of input audio tokens
-    input_audio_tokens: int = 0
+    # the total number of input text tokens processed by this operator; None if this operation did not use any LLM
+    # typed as a float because GenerationStats may be amortized (i.e. divided) across a number of output records
+    input_text_tokens: float = 0.0
 
-    # the number of input text tokens
-    input_text_tokens: int = 0
+    # the total number of input audio tokens processed by this operation.
+    input_audio_tokens: float = 0.0
 
-    # the number of input image tokens
-    input_image_tokens: int = 0
+    # the total number of input image tokens processed by this operation.
+    input_image_tokens: float = 0.0
 
-    # the total number of input tokens processed by this operator; None if this operation did not use an LLM
-    # typed as a float because GenerationStats may be amortized (i.e. divided) across a number of output records
-    total_input_tokens: float = 0.0
-
-    # the total number of output tokens processed by this operator; None if this operation did not use an LLM
-    # typed as a float because GenerationStats may be amortized (i.e.
divided) across a number of output records - total_output_tokens: float = 0.0 + # the total number of cache read tokens processed by this operation (charged at a discount, typically 0.1x input rate) + cache_read_tokens: float = 0.0 - # the total number of input tokens processed by embedding models - total_embedding_input_tokens: float = 0.0 - - # the total cost of processing the input tokens; None if this operation did not use an LLM - total_input_cost: float = 0.0 + # the total number of tokens written to the cache in this operation (Anthropic only) (charged at creation rate, typically 1.25x input rate) + cache_creation_tokens: float = 0.0 - # the total cost of processing the output tokens; None if this operation did not use an LLM - total_output_cost: float = 0.0 + # the number of output text tokens generated by the model + output_text_tokens: float = 0.0 - # the total cost of processing input tokens for embedding models - total_embedding_cost: float = 0.0 + # the total number of input tokens processed by embedding models + embedding_input_tokens: float = 0.0 # the total cost of processing the input and output tokens; None if this operation did not use an LLM + # TODO: future PR: cost_per_record --> total_cost cost_per_record: float = 0.0 # (if applicable) the time (in seconds) spent executing a call to an LLM @@ -57,48 +51,24 @@ class GenerationStats(BaseModel): fn_call_duration_secs: float = 0.0 # (if applicable) the total number of LLM calls made by this operator - total_llm_calls: float = 0 + total_llm_calls: float = 0.0 # (if applicable) the total number of embedding LLM calls made by this operator - total_embedding_llm_calls: float = 0 + total_embedding_llm_calls: float = 0.0 def __iadd__(self, other: GenerationStats) -> GenerationStats: - # self.raw_answers.extend(other.raw_answers) - for model_field in [ - "total_input_tokens", - "total_output_tokens", - "total_input_cost", - "total_output_cost", - "cost_per_record", - "llm_call_duration_secs", - "fn_call_duration_secs", - "total_llm_calls", - "total_embedding_llm_calls", - "total_embedding_input_tokens", - "total_embedding_cost" - - ]: - setattr(self, model_field, getattr(self, model_field) + getattr(other, model_field)) + for field in type(self).model_fields: + if field == "model_name": + continue + setattr(self, field, getattr(self, field) + getattr(other, field)) return self def __add__(self, other: GenerationStats) -> GenerationStats: dct = { field: getattr(self, field) + getattr(other, field) - for field in [ - "total_input_tokens", - "total_output_tokens", - "total_input_cost", - "total_output_cost", - "llm_call_duration_secs", - "fn_call_duration_secs", - "cost_per_record", - "total_llm_calls", - "total_embedding_llm_calls", - "total_embedding_input_tokens", - "total_embedding_cost" - ] + for field in type(self).model_fields + if field != "model_name" } - # dct['raw_answers'] = self.raw_answers + other.raw_answers dct["model_name"] = self.model_name return GenerationStats(**dct) @@ -108,20 +78,10 @@ def __itruediv__(self, quotient: float) -> GenerationStats: raise ZeroDivisionError("Cannot divide by zero") if isinstance(quotient, int): quotient = float(quotient) - for model_field in [ - "total_input_tokens", - "total_output_tokens", - "total_input_cost", - "total_output_cost", - "cost_per_record", - "llm_call_duration_secs", - "fn_call_duration_secs", - "total_llm_calls", - "total_embedding_llm_calls", - "total_embedding_input_tokens", - "total_embedding_cost" - ]: - setattr(self, model_field, getattr(self, model_field) / 
quotient) + for field in type(self).model_fields: + if field == "model_name": + continue + setattr(self, field, getattr(self, field) / quotient) return self def __truediv__(self, quotient: float) -> GenerationStats: @@ -131,19 +91,8 @@ def __truediv__(self, quotient: float) -> GenerationStats: quotient = float(quotient) dct = { field: getattr(self, field) / quotient - for field in [ - "total_input_tokens", - "total_output_tokens", - "total_input_cost", - "total_output_cost", - "llm_call_duration_secs", - "fn_call_duration_secs", - "total_llm_calls", - "total_embedding_llm_calls", - "cost_per_record", - "total_embedding_input_tokens", - "total_embedding_cost" - ] + for field in type(self).model_fields + if field != "model_name" } dct["model_name"] = self.model_name return GenerationStats(**dct) @@ -224,26 +173,27 @@ class RecordOpStats(BaseModel): # (if applicable) the list of generated fields for this record generated_fields: list[str] | None = None - # the total number of input tokens processed by this operator; None if this operation did not use an LLM + # the number of input text tokens processed by this operation # typed as a float because GenerationStats may be amortized (i.e. divided) across a number of output records - total_input_tokens: float = 0.0 + input_text_tokens: float = 0.0 - # the total number of output tokens processed by this operator; None if this operation did not use an LLM - # typed as a float because GenerationStats may be amortized (i.e. divided) across a number of output records - total_output_tokens: float = 0.0 + # the number of input audio tokens processed by this operation + input_audio_tokens: float = 0.0 - # the total number of input tokens processed by embedding models - # typed as a float because GenerationStats may be amortized (i.e. 
divided) across a number of output records - total_embedding_input_tokens: float = 0.0 + # the number of input image tokens processed by this operation + input_image_tokens: float = 0.0 + + # the number of cache read tokens processed by this operation + cache_read_tokens: float = 0.0 - # the total cost of processing the input tokens; None if this operation did not use an LLM - total_input_cost: float = 0.0 + # the number of tokens written to cache by this operation + cache_creation_tokens: float = 0.0 - # the total cost of processing the output tokens; None if this operation did not use an LLM - total_output_cost: float = 0.0 + # the number of output text tokens generated by this operation + output_text_tokens: float = 0.0 - # the (possibly amortized) cost of generating embeddings for this record; None if this operation did not use an embedding LLM - total_embedding_cost: float = 0.0 + # the number of input tokens processed by embedding models + embedding_input_tokens: float = 0.0 # (if applicable) the filter text (or a string representation of the filter function) applied to this record filter_str: str | None = None @@ -262,10 +212,10 @@ class RecordOpStats(BaseModel): fn_call_duration_secs: float = 0.0 # (if applicable) the total number of LLM calls made by this operator - total_llm_calls: float = 0 + total_llm_calls: float = 0.0 # (if applicable) the total number of embedding LLM calls made by this operator - total_embedding_llm_calls: float = 0 + total_embedding_llm_calls: float = 0.0 # (if applicable) a boolean indicating whether this is the statistics captured from a failed convert operation failed_convert: bool | None = None @@ -291,14 +241,26 @@ class OperatorStats(BaseModel): # the total cost of this operation total_op_cost: float = 0.0 - # the total input tokens processed by this operation - total_input_tokens: int = 0 + # the number of input text tokens processed by this operation + input_text_tokens: float = 0.0 + + # the number of input audio tokens processed by this operation + input_audio_tokens: float = 0.0 + + # the number of input image tokens processed by this operation + input_image_tokens: float = 0.0 + + # the number of cache read tokens processed by this operation + cache_read_tokens: float = 0.0 - # the total output tokens processed by this operation - total_output_tokens: int = 0 + # the number of tokens written to cache by this operation + cache_creation_tokens: float = 0.0 - #the total embedding input tokens processed by this operation - total_embedding_input_tokens: int = 0 + # the number of output text tokens generated by this operation + output_text_tokens: float = 0.0 + + # the number of input tokens processed by embedding models + embedding_input_tokens: float = 0.0 # a list of RecordOpStats processed by the operation record_op_stats_lst: list[RecordOpStats] = Field(default_factory=list) @@ -329,9 +291,13 @@ def __iadd__(self, stats: OperatorStats | RecordOpStats) -> OperatorStats: if isinstance(stats, OperatorStats): self.total_op_time += stats.total_op_time self.total_op_cost += stats.total_op_cost - self.total_input_tokens += stats.total_input_tokens - self.total_output_tokens += stats.total_output_tokens - self.total_embedding_input_tokens += stats.total_embedding_input_tokens + self.input_text_tokens += stats.input_text_tokens + self.input_audio_tokens += stats.input_audio_tokens + self.input_image_tokens += stats.input_image_tokens + self.cache_read_tokens += stats.cache_read_tokens + self.cache_creation_tokens += stats.cache_creation_tokens + 
self.output_text_tokens += stats.output_text_tokens + self.embedding_input_tokens += stats.embedding_input_tokens self.record_op_stats_lst.extend(stats.record_op_stats_lst) elif isinstance(stats, RecordOpStats): @@ -340,9 +306,13 @@ def __iadd__(self, stats: OperatorStats | RecordOpStats) -> OperatorStats: self.record_op_stats_lst.append(stats) self.total_op_time += stats.time_per_record self.total_op_cost += stats.cost_per_record - self.total_input_tokens += stats.total_input_tokens - self.total_output_tokens += stats.total_output_tokens - self.total_embedding_input_tokens += stats.total_embedding_input_tokens + self.input_text_tokens += stats.input_text_tokens + self.input_audio_tokens += stats.input_audio_tokens + self.input_image_tokens += stats.input_image_tokens + self.cache_read_tokens += stats.cache_read_tokens + self.cache_creation_tokens += stats.cache_creation_tokens + self.output_text_tokens += stats.output_text_tokens + self.embedding_input_tokens += stats.embedding_input_tokens else: raise TypeError(f"Cannot add {type(stats)} to OperatorStats") @@ -376,7 +346,7 @@ class BasePlanStats(BaseModel): # dictionary whose values are OperatorStats objects; # PlanStats maps {full_op_id -> OperatorStats} # SentinelPlanStats maps {logical_op_id -> {full_op_id -> OperatorStats}} - operator_stats: dict = Field(default_factory=dict) + operator_stats: dict[str, OperatorStats | dict[str, OperatorStats]] = Field(default_factory=dict) # dictionary whose values are GenerationStats objects for validation; # only used by SentinelPlanStats @@ -388,14 +358,26 @@ class BasePlanStats(BaseModel): # total cost for plan total_plan_cost: float = 0.0 - # total input tokens processed by this plan - total_input_tokens: int = 0 + # input text tokens processed by this plan + input_text_tokens: float = 0.0 + + # input audio tokens processed by this plan + input_audio_tokens: float = 0.0 + + # input image tokens processed by this plan + input_image_tokens: float = 0.0 + + # cache read tokens processed by this plan + cache_read_tokens: float = 0.0 + + # tokens written to cache by this plan + cache_creation_tokens: float = 0.0 - # total output tokens processed by this plan - total_output_tokens: int = 0 + # output text tokens generated by this plan + output_text_tokens: float = 0.0 - # total embedding input tokens processed by this plan - total_embedding_input_tokens: int = 0 + # embedding input tokens processed by this plan + embedding_input_tokens: float = 0.0 # start time for the plan execution; should be set by calling PlanStats.start() start_time: float | None = None @@ -409,10 +391,14 @@ def finish(self) -> None: if self.start_time is None: raise RuntimeError("PlanStats.start() must be called before PlanStats.finish()") self.total_plan_time = time.time() - self.start_time - self.total_plan_cost = self.sum_op_costs() + self.sum_validation_costs() - self.total_input_tokens = self.sum_input_tokens() + self.sum_validation_input_tokens() - self.total_output_tokens = self.sum_output_tokens() + self.sum_validation_output_tokens() - self.total_embedding_input_tokens = self.sum_embedding_input_tokens() + self.sum_validation_embedding_input_tokens() + self.total_plan_cost = self.sum_op_stats_field("total_op_cost") + self.sum_validation_stats_field("cost_per_record") + self.input_text_tokens = self.sum_op_stats_field("input_text_tokens") + self.sum_validation_stats_field("input_text_tokens") + self.input_audio_tokens = self.sum_op_stats_field("input_audio_tokens") + self.sum_validation_stats_field("input_audio_tokens") 
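# NOTE (illustrative aside, hypothetical values): the plan-level totals above are
# just field-wise sums of the GenerationStats arithmetic defined earlier in this
# file, which now iterates pydantic's `model_fields` instead of a hand-kept list.
# A hedged sketch of that arithmetic, using real field names from this PR:
#
#     a = GenerationStats(model_name="openai/gpt-4o-2024-08-06",
#                         input_text_tokens=1200.0, cache_read_tokens=800.0,
#                         output_text_tokens=150.0, total_llm_calls=1.0)
#     b = GenerationStats(model_name="openai/gpt-4o-2024-08-06",
#                         input_text_tokens=300.0, output_text_tokens=50.0,
#                         total_llm_calls=1.0)
#     combined = a + b            # every numeric field is summed; model_name is carried over
#     amortized = combined / 2.0  # amortize the stats across two output records
#     assert amortized.cache_read_tokens == 400.0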
+ self.input_image_tokens = self.sum_op_stats_field("input_image_tokens") + self.sum_validation_stats_field("input_image_tokens") + self.cache_read_tokens = self.sum_op_stats_field("cache_read_tokens") + self.sum_validation_stats_field("cache_read_tokens") + self.cache_creation_tokens = self.sum_op_stats_field("cache_creation_tokens") + self.sum_validation_stats_field("cache_creation_tokens") + self.output_text_tokens = self.sum_op_stats_field("output_text_tokens") + self.sum_validation_stats_field("output_text_tokens") + self.embedding_input_tokens = self.sum_op_stats_field("embedding_input_tokens") + self.sum_validation_stats_field("embedding_input_tokens") @staticmethod @abstractmethod @@ -423,32 +409,13 @@ def from_plan(plan) -> BasePlanStats: pass @abstractmethod - def sum_op_costs(self) -> float: - """ - Sum the costs of all operators in this plan. - """ + def sum_op_stats_field(self, field_name: str) -> float | int: + """Sum a given field across all operator stats in this plan.""" pass - @abstractmethod - def sum_input_tokens(self) -> int: - """ - Sum the input tokens processed by all operators in this plan. - """ - pass - - @abstractmethod - def sum_output_tokens(self) -> int: - """ - Sum the output tokens processed by all operators in this plan. - """ - pass - - @abstractmethod - def sum_embedding_input_tokens(self) -> int: - """ - Sum the input embedding tokens processed by all operators in this plan. - """ - pass + def sum_validation_stats_field(self, field_name: str) -> float | int: + """Sum a given field across all validation generation stats in this plan.""" + return sum([getattr(gen_stats, field_name) for _, gen_stats in self.validation_gen_stats.items()]) @abstractmethod def add_record_op_stats(self, unique_full_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None: @@ -471,35 +438,11 @@ def __str__(self) -> str: """ pass - def sum_validation_costs(self) -> float: - """ - Sum the costs of all validation generations in this plan. - """ - return sum([gen_stats.cost_per_record for _, gen_stats in self.validation_gen_stats.items()]) - - def sum_validation_input_tokens(self) -> int: - """ - Sum the input tokens processed by all validation generations in this plan. - """ - return sum([gen_stats.total_input_tokens for _, gen_stats in self.validation_gen_stats.items()]) - - def sum_validation_output_tokens(self) -> int: - """ - Sum the output tokens processed by all validation generations in this plan. - """ - return sum([gen_stats.total_output_tokens for _, gen_stats in self.validation_gen_stats.items()]) - - def sum_validation_embedding_input_tokens(self) -> int: - """ - Sum the input embedding tokens processed by all validation generations in this plan. - """ - return sum([gen_stats.total_embedding_input_tokens for _, gen_stats in self.validation_gen_stats.items()]) - def get_total_cost_so_far(self) -> float: """ Get the total cost incurred so far in this plan execution. """ - return self.sum_op_costs() + self.sum_validation_costs() + return self.sum_op_stats_field("total_op_cost") + self.sum_validation_stats_field("cost_per_record") class PlanStats(BasePlanStats): @@ -525,29 +468,9 @@ def from_plan(plan) -> PlanStats: return PlanStats(plan_id=plan.plan_id, plan_str=str(plan), operator_stats=operator_stats) - def sum_op_costs(self) -> float: - """ - Sum the costs of all operators in this plan. 
- """ - return sum([op_stats.total_op_cost for _, op_stats in self.operator_stats.items()]) - - def sum_input_tokens(self) -> int: - """ - Sum the input tokens processed by all operators in this plan. - """ - return sum([op_stats.total_input_tokens for _, op_stats in self.operator_stats.items()]) - - def sum_output_tokens(self) -> int: - """ - Sum the output tokens processed by all operators in this plan. - """ - return sum([op_stats.total_output_tokens for _, op_stats in self.operator_stats.items()]) - - def sum_embedding_input_tokens(self) -> int: - """ - Sum the input embedding tokens processed by all operators in this plan. - """ - return sum([op_stats.total_embedding_input_tokens for _, op_stats in self.operator_stats.items()]) + def sum_op_stats_field(self, field_name: str) -> float | int: + """Sum a given field across all operator stats in this plan.""" + return sum([getattr(op_stats, field_name) for _, op_stats in self.operator_stats.items()]) def add_record_op_stats(self, unique_full_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None: """ @@ -573,9 +496,13 @@ def __iadd__(self, plan_stats: PlanStats) -> None: """ self.total_plan_time += plan_stats.total_plan_time self.total_plan_cost += plan_stats.total_plan_cost - self.total_input_tokens += plan_stats.total_input_tokens - self.total_output_tokens += plan_stats.total_output_tokens - self.total_embedding_input_tokens += plan_stats.total_embedding_input_tokens + self.input_text_tokens += plan_stats.input_text_tokens + self.input_audio_tokens += plan_stats.input_audio_tokens + self.input_image_tokens += plan_stats.input_image_tokens + self.cache_read_tokens += plan_stats.cache_read_tokens + self.cache_creation_tokens += plan_stats.cache_creation_tokens + self.output_text_tokens += plan_stats.output_text_tokens + self.embedding_input_tokens += plan_stats.embedding_input_tokens for unique_full_op_id, op_stats in plan_stats.operator_stats.items(): if unique_full_op_id in self.operator_stats: self.operator_stats[unique_full_op_id] += op_stats @@ -585,9 +512,13 @@ def __iadd__(self, plan_stats: PlanStats) -> None: def __str__(self) -> str: stats = f"total_plan_time={self.total_plan_time} \n" stats += f"total_plan_cost={self.total_plan_cost} \n" - stats += f"total_input_tokens={self.total_input_tokens} \n" - stats += f"total_output_tokens={self.total_output_tokens} \n" - stats += f"total_embedding_input_tokens={self.total_embedding_input_tokens} \n" + stats += f"input_text_tokens={self.input_text_tokens} \n" + stats += f"input_audio_tokens={self.input_audio_tokens} \n" + stats += f"input_image_tokens={self.input_image_tokens} \n" + stats += f"cache_read_tokens={self.cache_read_tokens} \n" + stats += f"cache_creation_tokens={self.cache_creation_tokens} \n" + stats += f"output_text_tokens={self.output_text_tokens} \n" + stats += f"embedding_input_tokens={self.embedding_input_tokens} \n" for idx, op_stats in enumerate(self.operator_stats.values()): stats += f"{idx}. {op_stats.op_name} time={op_stats.total_op_time} cost={op_stats.total_op_cost} \n" return stats @@ -618,29 +549,9 @@ def from_plan(plan) -> SentinelPlanStats: return SentinelPlanStats(plan_id=plan.plan_id, plan_str=str(plan), operator_stats=operator_stats) - def sum_op_costs(self) -> float: - """ - Sum the costs of all operators in this plan. 
- """ - return sum(sum([op_stats.total_op_cost for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items()) - - def sum_input_tokens(self) -> int: - """ - Sum the input tokens processed by all operators in this plan. - """ - return sum(sum([op_stats.total_input_tokens for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items()) - - def sum_output_tokens(self) -> int: - """ - Sum the output tokens processed by all operators in this plan. - """ - return sum(sum([op_stats.total_output_tokens for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items()) - - def sum_embedding_input_tokens(self) -> int: - """ - Sum the output tokens processed by all operators in this plan. - """ - return sum(sum([op_stats.total_embedding_input_tokens for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items()) + def sum_op_stats_field(self, field_name: str) -> float | int: + """Sum a given field across all operator stats in this plan.""" + return sum(sum([getattr(op_stats, field_name) for _, op_stats in phys_op_stats.items()]) for _, phys_op_stats in self.operator_stats.items()) def add_record_op_stats(self, unique_logical_op_id: str, record_op_stats: RecordOpStats | list[RecordOpStats]) -> None: """ @@ -669,7 +580,6 @@ def add_validation_gen_stats(self, unique_logical_op_id: str, gen_stats: Generat else: self.validation_gen_stats[unique_logical_op_id] = gen_stats - def __iadd__(self, plan_stats: SentinelPlanStats) -> None: """ NOTE: we assume the execution layer guarantees: @@ -680,9 +590,13 @@ def __iadd__(self, plan_stats: SentinelPlanStats) -> None: """ self.total_plan_time += plan_stats.total_plan_time self.total_plan_cost += plan_stats.total_plan_cost - self.total_input_tokens += plan_stats.total_input_tokens - self.total_output_tokens += plan_stats.total_output_tokens - self.total_embedding_input_tokens += plan_stats.total_embedding_input_tokens + self.input_text_tokens += plan_stats.input_text_tokens + self.input_audio_tokens += plan_stats.input_audio_tokens + self.input_image_tokens += plan_stats.input_image_tokens + self.cache_read_tokens += plan_stats.cache_read_tokens + self.cache_creation_tokens += plan_stats.cache_creation_tokens + self.output_text_tokens += plan_stats.output_text_tokens + self.embedding_input_tokens += plan_stats.embedding_input_tokens for unique_logical_op_id, physical_op_stats in plan_stats.operator_stats.items(): for full_op_id, op_stats in physical_op_stats.items(): if unique_logical_op_id in self.operator_stats: @@ -702,9 +616,13 @@ def __iadd__(self, plan_stats: SentinelPlanStats) -> None: def __str__(self) -> str: stats = f"total_plan_time={self.total_plan_time} \n" stats += f"total_plan_cost={self.total_plan_cost} \n" - stats += f"total_input_tokens={self.total_input_tokens} \n" - stats += f"total_output_tokens={self.total_output_tokens} \n" - stats += f"total_embedding_input_tokens={self.total_embedding_input_tokens} \n" + stats += f"input_text_tokens={self.input_text_tokens} \n" + stats += f"input_audio_tokens={self.input_audio_tokens} \n" + stats += f"input_image_tokens={self.input_image_tokens} \n" + stats += f"cache_read_tokens={self.cache_read_tokens} \n" + stats += f"cache_creation_tokens={self.cache_creation_tokens} \n" + stats += f"output_text_tokens={self.output_text_tokens} \n" + stats += f"embedding_input_tokens={self.embedding_input_tokens} \n" for outer_idx, physical_op_stats in enumerate(self.operator_stats.values()): 
total_time = sum([op_stats.total_op_time for op_stats in physical_op_stats.values()]) total_cost = sum([op_stats.total_op_cost for op_stats in physical_op_stats.values()]) @@ -746,17 +664,26 @@ class ExecutionStats(BaseModel): # total cost for the entire execution total_execution_cost: float = 0.0 - # total number of input tokens processed - total_input_tokens: int = 0 + # input text tokens processed + input_text_tokens: float = 0.0 - # total number of output tokens processed - total_output_tokens: int = 0 + # input audio tokens processed + input_audio_tokens: float = 0.0 - # total number of embedding input tokens processed - total_embedding_input_tokens: int = 0 + # input image tokens processed + input_image_tokens: float = 0.0 - # total number of tokens processed - total_tokens: int = 0 + # cache read tokens processed + cache_read_tokens: float = 0.0 + + # tokens written to cache + cache_creation_tokens: float = 0.0 + + # output text tokens generated + output_text_tokens: float = 0.0 + + # embedding input tokens processed + embedding_input_tokens: float = 0.0 # dictionary of sentinel plan strings; useful for printing executed sentinel plans in demos sentinel_plan_strs: dict[str, str] = Field(default_factory=dict) @@ -806,50 +733,39 @@ def finish(self) -> None: self.total_execution_cost = self.optimization_cost + self.plan_execution_cost # compute the tokens for total execution - self.total_input_tokens = self.sum_input_tokens() - self.total_output_tokens = self.sum_output_tokens() - self.total_embedding_input_tokens = self.sum_embedding_input_tokens() - self.total_tokens = self.total_input_tokens + self.total_output_tokens + self.total_embedding_input_tokens + self.input_text_tokens = self.sum_plan_stats_field("input_text_tokens") + self.input_audio_tokens = self.sum_plan_stats_field("input_audio_tokens") + self.input_image_tokens = self.sum_plan_stats_field("input_image_tokens") + self.cache_read_tokens = self.sum_plan_stats_field("cache_read_tokens") + self.cache_creation_tokens = self.sum_plan_stats_field("cache_creation_tokens") + self.output_text_tokens = self.sum_plan_stats_field("output_text_tokens") + self.embedding_input_tokens = self.sum_plan_stats_field("embedding_input_tokens") # compute plan_strs self.plan_strs = {plan_id: plan_stats.plan_str for plan_id, plan_stats in self.plan_stats.items()} + def sum_plan_stats_field(self, field_name: str) -> float | int: + """ + Sum a given field across all PlanStats in this execution. + """ + sentinel_plan_field_sum = sum([plan_stats.sum_op_stats_field(field_name) + plan_stats.sum_validation_stats_field(field_name) for _, plan_stats in self.sentinel_plan_stats.items()]) + plan_field_sum = sum([plan_stats.sum_op_stats_field(field_name) for _, plan_stats in self.plan_stats.items()]) + return plan_field_sum + sentinel_plan_field_sum + def sum_sentinel_plan_costs(self) -> float: """ Sum the costs of all SentinelPlans in this execution. """ - return sum([plan_stats.sum_op_costs() + plan_stats.sum_validation_costs() for _, plan_stats in self.sentinel_plan_stats.items()]) + return sum([ + plan_stats.sum_op_stats_field("total_op_cost") + plan_stats.sum_validation_stats_field("cost_per_record") + for _, plan_stats in self.sentinel_plan_stats.items() + ]) def sum_plan_costs(self) -> float: """ Sum the costs of all PhysicalPlans in this execution. 
""" - return sum([plan_stats.sum_op_costs() for _, plan_stats in self.plan_stats.items()]) - - def sum_input_tokens(self) -> int: - """ - Sum the input tokens processed in this execution - """ - sentinel_plan_input_tokens = sum([plan_stats.sum_input_tokens() for _, plan_stats in self.sentinel_plan_stats.items()]) - plan_input_tokens = sum([plan_stats.sum_input_tokens() for _, plan_stats in self.plan_stats.items()]) - return plan_input_tokens + sentinel_plan_input_tokens - - def sum_output_tokens(self) -> int: - """ - Sum the output tokens processed in this execution - """ - sentinel_plan_output_tokens = sum([plan_stats.sum_output_tokens() for _, plan_stats in self.sentinel_plan_stats.items()]) - plan_output_tokens = sum([plan_stats.sum_output_tokens() for _, plan_stats in self.plan_stats.items()]) - return plan_output_tokens + sentinel_plan_output_tokens - - - def sum_embedding_input_tokens(self) -> int: - """ - Sum the embedding input tokens processed in this execution - """ - sentinel_plan_embedding_input_tokens = sum([plan_stats.sum_embedding_input_tokens() for _, plan_stats in self.sentinel_plan_stats.items()]) - plan_embedding_input_tokens = sum([plan_stats.sum_embedding_input_tokens() for _, plan_stats in self.plan_stats.items()]) - return plan_embedding_input_tokens + sentinel_plan_embedding_input_tokens + return sum([plan_stats.sum_op_stats_field("total_op_cost") for _, plan_stats in self.plan_stats.items()]) def add_plan_stats(self, plan_stats: PlanStats | SentinelPlanStats | list[PlanStats] | list[SentinelPlanStats]) -> None: """ diff --git a/src/palimpzest/prompts/__init__.py b/src/palimpzest/prompts/__init__.py index 122ebaf50..6989d4880 100644 --- a/src/palimpzest/prompts/__init__.py +++ b/src/palimpzest/prompts/__init__.py @@ -10,6 +10,7 @@ ) from palimpzest.prompts.context_search import CONTEXT_SEARCH_PROMPT from palimpzest.prompts.prompt_factory import PromptFactory +from palimpzest.prompts.prompt_manager import PromptManager from palimpzest.prompts.utils import ( ONE_TO_MANY_OUTPUT_FORMAT_INSTRUCTION, ONE_TO_ONE_OUTPUT_FORMAT_INSTRUCTION, @@ -34,6 +35,8 @@ "FINAL_ANSWER_PRE_MESSAGES_PROMPT", # context search "CONTEXT_SEARCH_PROMPT", + # prompt cache + "PromptManager", # prompt factory "PromptFactory", # utils diff --git a/src/palimpzest/prompts/convert_prompts.py b/src/palimpzest/prompts/convert_prompts.py index 48127ca7f..e11b3b76b 100644 --- a/src/palimpzest/prompts/convert_prompts.py +++ b/src/palimpzest/prompts/convert_prompts.py @@ -62,7 +62,7 @@ OUTPUT FIELDS: {output_fields_desc} -CONTEXT: +<>CONTEXT: {context}<> Let's think step-by-step in order to answer the question. @@ -81,7 +81,7 @@ OUTPUT FIELDS: {output_fields_desc} -CONTEXT: +<>CONTEXT: {context}<> ANSWER: """ diff --git a/src/palimpzest/prompts/filter_prompts.py b/src/palimpzest/prompts/filter_prompts.py index 046093250..c06902388 100644 --- a/src/palimpzest/prompts/filter_prompts.py +++ b/src/palimpzest/prompts/filter_prompts.py @@ -11,11 +11,11 @@ INPUT FIELDS: {example_input_fields} +FILTER CONDITION: {example_filter_condition} + CONTEXT: {{{example_context}}}{image_disclaimer}{audio_disclaimer} -FILTER CONDITION: {example_filter_condition} - Let's think step-by-step in order to answer the question. 
REASONING: {example_reasoning} @@ -34,11 +34,11 @@ INPUT FIELDS: {example_input_fields} +FILTER CONDITION: {example_filter_condition} + CONTEXT: {{{example_context}}}{image_disclaimer}{audio_disclaimer} -FILTER CONDITION: {example_filter_condition} - ANSWER: TRUE --- """ @@ -51,11 +51,11 @@ INPUT FIELDS: {input_fields_desc} -CONTEXT: -{context}<> - FILTER CONDITION: {filter_condition} +<>CONTEXT: +{context}<> + Let's think step-by-step in order to answer the question. REASONING: """ @@ -68,9 +68,9 @@ INPUT FIELDS: {input_fields_desc} -CONTEXT: -{context}<> - FILTER CONDITION: {filter_condition} +<>CONTEXT: +{context}<> + ANSWER: """ diff --git a/src/palimpzest/prompts/join_prompts.py b/src/palimpzest/prompts/join_prompts.py index 71c44c041..ffcff4138 100644 --- a/src/palimpzest/prompts/join_prompts.py +++ b/src/palimpzest/prompts/join_prompts.py @@ -11,17 +11,17 @@ LEFT INPUT FIELDS: {example_input_fields} -LEFT CONTEXT: -{{{example_context}}}{image_disclaimer}{audio_disclaimer} - RIGHT INPUT FIELDS: {right_example_input_fields} +JOIN CONDITION: {example_join_condition} + +LEFT CONTEXT: +{{{example_context}}}{image_disclaimer}{audio_disclaimer} + RIGHT CONTEXT: {{{right_example_context}}}{right_image_disclaimer}{right_audio_disclaimer} -JOIN CONDITION: {example_join_condition} - Let's think step-by-step in order to evaluate the join condition. REASONING: {example_reasoning} @@ -40,17 +40,17 @@ LEFT INPUT FIELDS: {example_input_fields} -LEFT CONTEXT: -{{{example_context}}}{image_disclaimer}{audio_disclaimer} - RIGHT INPUT FIELDS: {right_example_input_fields} +JOIN CONDITION: {example_join_condition} + +LEFT CONTEXT: +{{{example_context}}}{image_disclaimer}{audio_disclaimer} + RIGHT CONTEXT: {{{right_example_context}}}{right_image_disclaimer}{right_audio_disclaimer} -JOIN CONDITION: {example_join_condition} - ANSWER: TRUE --- """ @@ -63,17 +63,17 @@ LEFT INPUT FIELDS: {input_fields_desc} -LEFT CONTEXT: -{context}<> - RIGHT INPUT FIELDS: {right_input_fields_desc} +JOIN CONDITION: {join_condition} + +<>LEFT CONTEXT: +{context}<> + RIGHT CONTEXT: {right_context}<> -JOIN CONDITION: {join_condition} - Let's think step-by-step in order to evaluate the join condition. REASONING: """ @@ -86,15 +86,15 @@ LEFT INPUT FIELDS: {input_fields_desc} -LEFT CONTEXT: -{context}<> - RIGHT INPUT FIELDS: {right_input_fields_desc} +JOIN CONDITION: {join_condition} + +<>LEFT CONTEXT: +{context}<> + RIGHT CONTEXT: {right_context}<> -JOIN CONDITION: {join_condition} - ANSWER: """ diff --git a/src/palimpzest/prompts/moa_aggregator_prompts.py b/src/palimpzest/prompts/moa_aggregator_prompts.py index e89b8642b..c051f9d2a 100644 --- a/src/palimpzest/prompts/moa_aggregator_prompts.py +++ b/src/palimpzest/prompts/moa_aggregator_prompts.py @@ -74,14 +74,14 @@ {output_format_instruction} Finish your response with a newline character followed by --- --- -{model_responses} - INPUT FIELDS: {input_fields_desc} OUTPUT FIELDS: {output_fields_desc} +<>{model_responses} + Let's think step-by-step in order to answer the question. REASONING: """ @@ -94,13 +94,13 @@ Remember, your answer must be TRUE or FALSE. Finish your response with a newline character followed by --- --- -{model_responses} - INPUT FIELDS: {input_fields_desc} FILTER CONDITION: {filter_condition} +<>{model_responses} + Let's think step-by-step in order to answer the question. 
REASONING: """ diff --git a/src/palimpzest/prompts/moa_proposer_prompts.py b/src/palimpzest/prompts/moa_proposer_prompts.py index 69e362b05..242087d15 100644 --- a/src/palimpzest/prompts/moa_proposer_prompts.py +++ b/src/palimpzest/prompts/moa_proposer_prompts.py @@ -35,11 +35,11 @@ INPUT FIELDS: {example_input_fields} +FILTER CONDITION: {example_filter_condition} + CONTEXT: {{{example_context}}}{image_disclaimer}{audio_disclaimer} -FILTER CONDITION: {example_filter_condition} - Let's think step-by-step in order to answer the question. ANSWER: {example_answer} @@ -59,7 +59,7 @@ OUTPUT FIELDS: {output_fields_desc} -CONTEXT: +<>CONTEXT: {context}<> Let's think step-by-step in order to answer the question. @@ -77,11 +77,11 @@ INPUT FIELDS: {input_fields_desc} -CONTEXT: -{context}<> - FILTER CONDITION: {filter_condition} +<>CONTEXT: +{context}<> + Let's think step-by-step in order to answer the question. ANSWER: """ diff --git a/src/palimpzest/prompts/prompt_factory.py b/src/palimpzest/prompts/prompt_factory.py index 443d17d39..7cb1204ed 100644 --- a/src/palimpzest/prompts/prompt_factory.py +++ b/src/palimpzest/prompts/prompt_factory.py @@ -2,6 +2,7 @@ import base64 import json +import os from typing import Any from pydantic import BaseModel @@ -140,6 +141,32 @@ ) +def _detect_image_media_type(filepath: str | None = None, base64_data: str | None = None) -> str: + """Detect image media type from file extension or base64 magic bytes.""" + if filepath: + ext = os.path.splitext(filepath)[1].lower() + ext_map = {".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg", + ".gif": "image/gif", ".webp": "image/webp"} + if ext in ext_map: + return ext_map[ext] + + if base64_data: + try: + header = base64.b64decode(base64_data[:32]) + if header[:8] == b"\x89PNG\r\n\x1a\n": + return "image/png" + if header[:3] == b"\xff\xd8\xff": + return "image/jpeg" + if header[:6] in (b"GIF87a", b"GIF89a"): + return "image/gif" + if header[:4] == b"RIFF" and header[8:12] == b"WEBP": + return "image/webp" + except Exception: + pass + + return "image/jpeg" + + class PromptFactory: """Factory class for generating prompts for the Generator given the input(s).""" @@ -889,8 +916,9 @@ def _create_image_messages(self, candidate: DataRecord | list[DataRecord], input if field_type.annotation in [ImageFilepath, ImageFilepath | None, ImageFilepath | Any] and field_value is not None: with open(field_value, "rb") as f: base64_image_str = base64.b64encode(f.read()).decode("utf-8") + media_type = _detect_image_media_type(filepath=field_value, base64_data=base64_image_str) image_content.append( - {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}} + {"type": "image_url", "image_url": {"url": f"data:{media_type};base64,{base64_image_str}"}} ) elif field_type.annotation in [list[ImageFilepath], list[ImageFilepath] | None, list[ImageFilepath] | Any]: @@ -899,8 +927,9 @@ def _create_image_messages(self, candidate: DataRecord | list[DataRecord], input continue with open(image_filepath, "rb") as f: base64_image_str = base64.b64encode(f.read()).decode("utf-8") + media_type = _detect_image_media_type(filepath=image_filepath, base64_data=base64_image_str) image_content.append( - {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image_str}"}} + {"type": "image_url", "image_url": {"url": f"data:{media_type};base64,{base64_image_str}"}} ) # image url (or list of image urls) @@ -915,16 +944,18 @@ def _create_image_messages(self, candidate: DataRecord | list[DataRecord], input 
# pre-encoded images (or list of pre-encoded images) elif field_type.annotation in [ImageBase64, ImageBase64 | None, ImageBase64 | Any] and field_value is not None: + media_type = _detect_image_media_type(base64_data=field_value) image_content.append( - {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{field_value}"}} + {"type": "image_url", "image_url": {"url": f"data:{media_type};base64,{field_value}"}} ) elif field_type.annotation in [list[ImageBase64], list[ImageBase64] | None, list[ImageBase64] | Any]: for base64_image in field_value: if base64_image is None: continue + media_type = _detect_image_media_type(base64_data=base64_image) image_content.append( - {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}} + {"type": "image_url", "image_url": {"url": f"data:{media_type};base64,{base64_image}"}} ) return [{"role": "user", "type": "image", "content": image_content}] if len(image_content) > 0 else [] diff --git a/src/palimpzest/prompts/prompt_manager.py b/src/palimpzest/prompts/prompt_manager.py new file mode 100644 index 000000000..9b1d660d8 --- /dev/null +++ b/src/palimpzest/prompts/prompt_manager.py @@ -0,0 +1,230 @@ +""" +Prompt caching utility for different LLM providers. + +This module provides provider-specific prompt caching configurations: +- OpenAI: Automatic prefix caching with prompt_cache_key for sticky routing +- Gemini (Google AI Studio / Vertex AI): Implicit caching (automatic prefix matching) +- Anthropic: Explicit cache_control with ephemeral type on system and user message content +""" + +import copy +import uuid +from typing import Any + +from palimpzest.constants import Model + + +class PromptManager: + """ + Manages prompt caching configurations and message transformations for LLM providers. + + This class handles: + 1. Session-level state (e.g., OpenAI cache keys). + 2. Provider-specific request arguments (headers, extra_body). + 3. Transformation of messages for providers requiring explicit markers (Anthropic). + 4. Normalization of usage statistics. + """ + + CACHE_BOUNDARY_MARKER = "<>" + + def __init__(self, model: Model): + self.model = model + # Instance-level state ensures thread safety if we use one manager per plan/execution + self.openai_cache_key = f"pz-cache-{uuid.uuid4().hex[:12]}" if (self.model.is_provider_openai() or self.model.is_provider_azure()) else None + + def get_cache_kwargs(self) -> dict[str, Any]: + """ + Get provider-specific cache configuration kwargs for litellm.completion(). + + Returns: + A dictionary of kwargs to pass to litellm.completion() for enabling caching + """ + if not self.model.supports_prompt_caching(): + return {} + # OpenAI and Azure OpenAI: https://platform.openai.com/docs/guides/prompt-caching + # Use prompt_cache_key for sticky routing to the same cache shard + if self.model.is_provider_openai() or self.model.is_provider_azure(): + return {"extra_body": {"prompt_cache_key": self.openai_cache_key}} + else: + return {} + + def inject_cache_isolation_id(self, messages: list[dict], session_id: str) -> list[dict]: + """ + Inject a cache isolation ID into messages for testing cache behavior per-modality. + + This must happen BEFORE update_messages_for_caching so the ID becomes part of cached content. 
+ """ + for msg in messages: + role = msg.get("role") + content = msg.get("content") + if role == "system" and isinstance(content, str) or \ + role == "user" and self.model.is_provider_anthropic() and msg.get("type") == "text" and isinstance(content, str): + msg["content"] = f"[{session_id}] " + content + return messages + + def update_messages_for_caching(self, messages: list[dict]) -> list[dict]: + """ + Transform messages to conform to provider-specific caching requirements. + + - Anthropic: Adds explicit cache_control markers. + - Others: Removes the generic cache boundary markers. + + Returns: + The transformed messages list. + """ + if not self.model.supports_prompt_caching(): + return messages + + # Anthropic: Explicit cache_control with ephemeral type + # https://platform.claude.com/docs/en/build-with-claude/prompt-caching + if self.model.is_provider_anthropic(): + return self._transform_messages_for_anthropic(messages) + # implicit caching for Gemini/OpenAI/Azure models that currently support caching + # OpenAI: https://platform.openai.com/docs/guides/prompt-caching + # Gemini: https://ai.google.dev/gemini-api/docs/caching + elif (self.model.is_provider_openai() or self.model.is_provider_azure() or + self.model.is_provider_google_ai_studio() or self.model.is_provider_vertex_ai()): + return self._remove_cache_boundary_markers(messages) + + return messages + + + def extract_usage_stats(self, usage: dict, is_audio_op: bool) -> dict[str, int]: + """ + Normalize cache statistics from provider-specific response formats. + """ + stats = { + "input_text_tokens": 0, + "input_image_tokens": 0, # forward looking + "input_audio_tokens": 0, + "cache_creation_tokens": 0, + "cache_read_tokens": 0 + } + + details = usage.get("prompt_tokens_details") or {} + + if self.model.is_provider_openai() or self.model.is_provider_azure(): + # only realtime audio models do, but they are not supported by PZ + if self.model.supports_prompt_caching() and not is_audio_op: + stats["cache_read_tokens"] = details.get("cached_tokens") or 0 + stats["input_text_tokens"] = (usage.get("prompt_tokens") or 0) - stats["cache_read_tokens"] + # audio models don't support caching for now + elif is_audio_op: + stats["input_text_tokens"] = details.get("text_tokens") or 0 + stats["input_audio_tokens"] = details.get("audio_tokens") or 0 + else: + stats["input_text_tokens"] = usage.get("prompt_tokens") or 0 + + # Moved to Gemini client class, now usage stats are extracted directly in GeminiClient + # elif self.model.is_provider_vertex_ai(): + # stats["cache_read_tokens"] = usage.get("cache_read_input_tokens") or 0 + # if stats["cache_read_tokens"] == 0: + # stats["cache_read_tokens"] = details.get("cached_tokens") or 0 + # stats["input_text_tokens"] = details.get("text_tokens") or 0 + # stats["input_audio_tokens"] = details.get("audio_tokens") or 0 + # stats["input_image_tokens"] = details.get("image_tokens") or 0 + + elif self.model.is_provider_anthropic(): + stats["cache_creation_tokens"] = usage.get("cache_creation_input_tokens") or 0 + stats["cache_read_tokens"] = usage.get("cache_read_input_tokens") or 0 + stats["input_text_tokens"] = max(0, (usage.get("prompt_tokens") or 0) - stats["cache_read_tokens"] - stats["cache_creation_tokens"]) + + elif self.model.is_vllm_model(): + # vLLM does not seem to provide cache statistics through litellm, so we currently have no way + # to extract cache read/creation tokens for vLLM models. 
+ pass + + # all other models (assume caching not supported) + else: + if is_audio_op: + stats["input_text_tokens"] = details.get("text_tokens") or 0 + stats["input_audio_tokens"] = details.get("audio_tokens") or 0 + else: + stats["input_text_tokens"] = usage.get("prompt_tokens") or 0 + + + return stats + + + def _remove_cache_boundary_markers(self, messages: list[dict]) -> list[dict]: + """ + Remove <> markers from user messages. + + For providers with automatic (implicit) caching (OpenAI, Gemini), we don't need + explicit cache markers. This function cleans up the markers from prompts. + + Args: + messages: The list of messages to transform. + + Returns: + A new list of messages with cache boundary markers removed. + """ + result = [] + for message in messages: + new_message = message.copy() + if new_message.get("role") == "user": + content = new_message.get("content", "") + if isinstance(content, str) and self.CACHE_BOUNDARY_MARKER in content: + new_message["content"] = content.replace(self.CACHE_BOUNDARY_MARKER, "") + result.append(new_message) + return result + + + def _transform_messages_for_anthropic(self, messages: list[dict]) -> list[dict]: + """ + Add cache_control markers to system messages and user prompt prefixes for Anthropic models. + + This transforms messages to: + 1. Add cache_control to system message content blocks + 2. Convert user messages with <> marker into a single message with multiple content blocks: + a. Static prefix block (with cache_control) - cacheable across records + b. Dynamic content block (without cache_control) - changes per record + + Args: + messages: The list of messages to transform. + + Returns: + A new list of messages with cache_control markers added. + """ + result = [] + for message in messages: + new_message = copy.deepcopy(message) + role = new_message.get("role") + content = new_message.get("content", "") + + # 1. Handle System Messages + if role == "system": + if isinstance(content, str) and content: + new_message["content"] = [{ + "type": "text", + "text": content, + "cache_control": {"type": "ephemeral"} + }] + elif isinstance(content, list) and content: + # Apply to last block if it's text + last_block = new_message["content"][-1] + if isinstance(last_block, dict) and last_block.get("type") == "text": + last_block["cache_control"] = {"type": "ephemeral"} + + # 2. Handle User Messages (The Split Logic) + elif role == "user" and isinstance(content, str) and self.CACHE_BOUNDARY_MARKER in content: + static, dynamic = content.split(self.CACHE_BOUNDARY_MARKER, 1) + + new_blocks = [] + if static.strip(): + new_blocks.append({ + "type": "text", + "text": static, + "cache_control": {"type": "ephemeral"} + }) + + if dynamic.strip(): + new_blocks.append({"type": "text", "text": dynamic}) + + if new_blocks: + new_message["content"] = new_blocks + else: + new_message["content"] = "" + + result.append(new_message) + return result diff --git a/src/palimpzest/prompts/split_merge_prompts.py b/src/palimpzest/prompts/split_merge_prompts.py index 34dc51d66..61074d1b4 100644 --- a/src/palimpzest/prompts/split_merge_prompts.py +++ b/src/palimpzest/prompts/split_merge_prompts.py @@ -73,14 +73,14 @@ {output_format_instruction} Finish your response with a newline character followed by --- --- -{chunk_outputs} - INPUT FIELDS: {input_fields_desc} OUTPUT FIELDS: {output_fields_desc} +<>{chunk_outputs} + Let's think step-by-step in order to answer the question. REASONING: """ @@ -93,13 +93,13 @@ Remember, your answer must be TRUE or FALSE. 
Finish your response with a newline character followed by --- --- -{chunk_outputs} - INPUT FIELDS: {input_fields_desc} FILTER CONDITION: {filter_condition} +<>{chunk_outputs} + Let's think step-by-step in order to answer the question. REASONING: """ diff --git a/src/palimpzest/prompts/split_proposer_prompts.py b/src/palimpzest/prompts/split_proposer_prompts.py index 20c5d7f52..cd6ba3547 100644 --- a/src/palimpzest/prompts/split_proposer_prompts.py +++ b/src/palimpzest/prompts/split_proposer_prompts.py @@ -35,11 +35,11 @@ INPUT FIELDS: {example_input_fields} +FILTER CONDITION: {example_filter_condition} + CONTEXT: {{{example_context}}}{image_disclaimer}{audio_disclaimer} -FILTER CONDITION: {example_filter_condition} - Let's think step-by-step in order to answer the question. ANSWER: {example_answer} @@ -59,7 +59,7 @@ OUTPUT FIELDS: {output_fields_desc} -CONTEXT: +<>CONTEXT: {context}<> Let's think step-by-step in order to answer the question. @@ -77,11 +77,11 @@ INPUT FIELDS: {input_fields_desc} -CONTEXT: -{context}<> - FILTER CONDITION: {filter_condition} +<>CONTEXT: +{context}<> + Let's think step-by-step in order to answer the question. ANSWER: """ diff --git a/src/palimpzest/query/execution/mab_execution_strategy.py b/src/palimpzest/query/execution/mab_execution_strategy.py index f32aadb3d..1af1da41a 100644 --- a/src/palimpzest/query/execution/mab_execution_strategy.py +++ b/src/palimpzest/query/execution/mab_execution_strategy.py @@ -350,6 +350,8 @@ def remove_unavailable_root_datasets(source_indices: str | tuple) -> str | tuple right_source_indices = source_indices[1] left_inputs = left_source_indices_to_inputs.get(left_source_indices, []) right_inputs = right_source_indices_to_inputs.get(right_source_indices, []) + left_inputs = [input for input in left_inputs if input is not None] + right_inputs = [input for input in right_inputs if input is not None] if len(left_inputs) > 0 and len(right_inputs) > 0: op_inputs.append((op, (left_source_indices, right_source_indices), (left_inputs, right_inputs))) return op_inputs diff --git a/src/palimpzest/query/generators/__init__.py b/src/palimpzest/query/generators/__init__.py index e69de29bb..f17887548 100644 --- a/src/palimpzest/query/generators/__init__.py +++ b/src/palimpzest/query/generators/__init__.py @@ -0,0 +1,4 @@ +from palimpzest.query.generators.gemini_client import GeminiClient, GeminiResponse +from palimpzest.query.generators.generators import Generator + +__all__ = ["Generator", "GeminiClient", "GeminiResponse"] diff --git a/src/palimpzest/query/generators/gemini_client.py b/src/palimpzest/query/generators/gemini_client.py new file mode 100644 index 000000000..a31e23b89 --- /dev/null +++ b/src/palimpzest/query/generators/gemini_client.py @@ -0,0 +1,297 @@ +""" +Direct client for Gemini (Google AI Studio and Vertex AI) that bypasses litellm. + +This module provides a GeminiClient class that: +1. Calls Gemini API directly via google-genai SDK +2. Converts litellm/palimpzest message format to Gemini format +3. 
Relies on implicit context caching (automatic prefix matching) +""" + +from __future__ import annotations + +import base64 +import logging +from dataclasses import dataclass +from typing import Any + +from google import genai +from google.genai import types + +logger = logging.getLogger(__name__) + + +@dataclass +class GeminiResponse: + """Response object that mimics litellm completion response structure.""" + content: str + usage: dict + raw_response: Any = None + + +class GeminiClient: + """ + Direct client for Gemini (Google AI Studio and Vertex AI) that bypasses litellm. + Uses implicit caching (automatic prefix matching) for prompt caching. + + Uses a singleton pattern per (model, use_vertex) so that client state is shared + across all Generator instances using the same model and provider. + + Args: + model: Model name (e.g., "gemini-2.5-flash") + use_vertex: If True, use Vertex AI; otherwise use Google AI Studio + """ + + _instances: dict[tuple[str, bool], GeminiClient] = {} + + # Maps reasoning_effort to Gemini thinking_budget token counts + # Reference: https://github.com/BerriAI/litellm/blob/620664921902d7a9bfb29897a7b27c1a7ef4ddfb/litellm/constants.py#L88 + REASONING_EFFORT_TO_THINKING_BUDGET = { + "disable": 0, + "minimal": 128, + "low": 1024, + "medium": 2048, + "high": 4096, + } + + @classmethod + def get_instance(cls, model: str, use_vertex: bool = False) -> GeminiClient: + """Get or create a singleton GeminiClient for the given model and provider.""" + key = (model, use_vertex) + if key not in cls._instances: + cls._instances[key] = cls(model, use_vertex) + return cls._instances[key] + + def __init__(self, model: str, use_vertex: bool = False): + self.model = model + self.use_vertex = use_vertex + # Vertex AI: uses GOOGLE_APPLICATION_CREDENTIALS for auth + self.client = genai.Client(vertexai=True) if use_vertex else genai.Client() + + def _detect_image_media_type(self, base64_data: str) -> str: + """Detect image format from base64 data by examining magic bytes.""" + try: + header = base64.b64decode(base64_data[:32]) + if header[:8] == b"\x89PNG\r\n\x1a\n": + return "image/png" + if header[:3] == b"\xff\xd8\xff": + return "image/jpeg" + if header[:6] in (b"GIF87a", b"GIF89a"): + return "image/gif" + if header[:4] == b"RIFF" and header[8:12] == b"WEBP": + return "image/webp" + except Exception: + pass + return "image/jpeg" + + def _transform_messages(self, messages: list[dict]) -> tuple[str | None, list[dict]]: + """ + Transform litellm/palimpzest message format to Gemini API format. 
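+
+        Illustrative sketch (message shapes assumed): an input like
+            [{"role": "system", "content": "You are terse."},
+             {"role": "user", "type": "text", "content": "Hi"}]
+        maps to:
+            ("You are terse.", [{"role": "user", "parts": [{"text": "Hi"}]}])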
+ + Args: + messages: List of messages in litellm/palimpzest format + + Returns: + Tuple of (system_instruction, gemini_contents) + """ + gemini_contents = [] + system_instruction = None + + for msg in messages: + role = msg.get("role") + msg_type = msg.get("type") + content = msg.get("content") + + if role == "system": + if isinstance(content, list): + text_parts = [ + block.get("text", "") + for block in content + if block.get("type") == "text" + ] + system_instruction = "".join(text_parts) + else: + system_instruction = content + + elif role == "user": + parts = [] + + if msg_type == "text" or msg_type is None: + if isinstance(content, list): + for block in content: + if block.get("type") == "text": + parts.append({"text": block.get("text", "")}) + elif isinstance(content, str): + parts.append({"text": content}) + + elif msg_type == "image": + for img in content: + if img.get("type") == "image_url": + url = img["image_url"]["url"] + if url.startswith("data:"): + # Robust parsing: handle "data:[];base64," + base64_marker = ";base64," + marker_idx = url.find(base64_marker) + if marker_idx == -1: + continue + data = url[marker_idx + len(base64_marker):] + media_type = self._detect_image_media_type(data) + parts.append({ + "inline_data": { + "mime_type": media_type, + "data": data, + } + }) + + elif msg_type == "input_audio": + for audio in content: + if audio.get("type") == "input_audio": + audio_data = audio["input_audio"] + parts.append({ + "inline_data": { + "mime_type": f"audio/{audio_data.get('format', 'wav')}", + "data": audio_data["data"], + } + }) + + if parts: + # Merge consecutive user messages + if gemini_contents and gemini_contents[-1]["role"] == "user": + gemini_contents[-1]["parts"].extend(parts) + else: + gemini_contents.append({"role": "user", "parts": parts}) + + elif role == "assistant": + # Convert assistant to model role + parts = [] + if isinstance(content, str): + parts.append({"text": content}) + elif isinstance(content, list): + for block in content: + if block.get("type") == "text": + parts.append({"text": block.get("text", "")}) + + if parts: + # Merge consecutive model messages (Gemini requires strict role alternation) + if gemini_contents and gemini_contents[-1]["role"] == "model": + gemini_contents[-1]["parts"].extend(parts) + else: + gemini_contents.append({"role": "model", "parts": parts}) + + return system_instruction, gemini_contents + + def _extract_usage_stats(self, usage_metadata: Any) -> dict: + """ + Extract and process usage statistics from Gemini response into the + standard format expected by Generator. + + Args: + usage_metadata: The usage_metadata from Gemini response + + Returns: + Dictionary with information needed by GenerationStats. 
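+
+            Illustrative example (numbers assumed): if prompt_tokens_details reports
+            1200 TEXT tokens and cached_content_token_count is 1024 (all of it TEXT per
+            cache_tokens_details), the result contains input_text_tokens=176,
+            text_cache_read_tokens=1024, and cache_read_tokens=1024.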
+        """
+        generation_stats = {
+            "input_text_tokens": 0,
+            "input_image_tokens": 0,
+            "input_audio_tokens": 0,
+            "cache_read_tokens": 0,
+            "text_cache_read_tokens": 0,
+            "image_cache_read_tokens": 0,
+            "audio_cache_read_tokens": 0,
+            "cache_creation_tokens": 0,
+            "output_text_tokens": 0
+        }
+
+        if usage_metadata is None:
+            return generation_stats
+
+        try:
+            raw = usage_metadata.model_dump()
+        except Exception:
+            # Fallback for SDK versions whose usage_metadata lacks model_dump()
+            raw = vars(usage_metadata) if hasattr(usage_metadata, "__dict__") else {}
+            logger.warning("Could not call model_dump() on usage_metadata, using fallback")
+
+        # Parse cache read tokens by modality
+        for detail in (raw.get("cache_tokens_details") or []):
+            modality = (detail.get("modality") or "").upper()
+            token_count = detail.get("token_count") or 0
+            if modality == "TEXT":
+                generation_stats["text_cache_read_tokens"] = token_count
+            elif modality == "IMAGE":
+                generation_stats["image_cache_read_tokens"] = token_count
+            elif modality == "AUDIO":
+                generation_stats["audio_cache_read_tokens"] = token_count
+
+        generation_stats["cache_read_tokens"] = raw.get("cached_content_token_count") or 0
+
+        # Input tokens by modality; prompt totals include cached tokens, so subtract them
+        for detail in (raw.get("prompt_tokens_details") or []):
+            modality = (detail.get("modality") or "").upper()
+            token_count = detail.get("token_count") or 0
+            if modality == "TEXT":
+                generation_stats["input_text_tokens"] = max(0, token_count - generation_stats["text_cache_read_tokens"])
+            elif modality == "IMAGE":
+                generation_stats["input_image_tokens"] = max(0, token_count - generation_stats["image_cache_read_tokens"])
+            elif modality == "AUDIO":
+                generation_stats["input_audio_tokens"] = max(0, token_count - generation_stats["audio_cache_read_tokens"])
+
+        generation_stats["output_text_tokens"] = (raw.get("candidates_token_count") or 0) + (raw.get("thoughts_token_count") or 0)
+
+        return generation_stats
+
+    def generate(
+        self,
+        messages: list[dict],
+        temperature: float = 0.0,
+        reasoning_effort: str | None = None,
+    ) -> GeminiResponse:
+        """
+        Generate content using Gemini API directly.
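+
+        Minimal usage sketch (model name assumed):
+            client = GeminiClient.get_instance("gemini-2.5-flash")
+            response = client.generate([{"role": "user", "type": "text", "content": "Say hi"}])
+            text, usage = response.content, response.usage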
+
+        Args:
+            messages: List of messages in litellm/palimpzest format
+            temperature: Sampling temperature (default: 0.0)
+            reasoning_effort: Optional thinking budget level ("disable", "minimal", "low", "medium", or "high")
+
+        Returns:
+            GeminiResponse with content, usage stats, and raw response
+        """
+        system_instruction, gemini_contents = self._transform_messages(messages)
+
+        # Build config
+        config_kwargs = {"temperature": temperature}
+        if system_instruction:
+            config_kwargs["system_instruction"] = system_instruction
+
+        # Map reasoning_effort to thinking_config
+        if reasoning_effort is not None:
+            budget = self.REASONING_EFFORT_TO_THINKING_BUDGET.get(reasoning_effort)
+            if budget is None:
+                raise ValueError(f"Invalid reasoning effort: {reasoning_effort}")
+            config_kwargs["thinking_config"] = types.ThinkingConfig(thinking_budget=budget)
+
+        response = self.client.models.generate_content(
+            model=self.model,
+            contents=gemini_contents,
+            config=types.GenerateContentConfig(**config_kwargs),
+        )
+
+        # Extract response content
+        content = ""
+        if response.candidates and response.candidates[0].content:
+            parts = response.candidates[0].content.parts
+            if parts:
+                content = "".join(
+                    part.text for part in parts
+                    if hasattr(part, "text") and part.text
+                )
+
+        # Extract and process usage stats
+        usage = self._extract_usage_stats(response.usage_metadata)
+
+        return GeminiResponse(
+            content=content,
+            usage=usage,
+            raw_response=response,
+        )
diff --git a/src/palimpzest/query/generators/generators.py b/src/palimpzest/query/generators/generators.py
index f03add679..34867f862 100644
--- a/src/palimpzest/query/generators/generators.py
+++ b/src/palimpzest/query/generators/generators.py
@@ -17,15 +17,10 @@
 from colorama import Fore, Style
 from pydantic.fields import FieldInfo
 
-from palimpzest.constants import (
-    MODEL_CARDS,
-    Cardinality,
-    Model,
-    PromptStrategy,
-)
+from palimpzest.constants import Cardinality, Model, PromptStrategy
 from palimpzest.core.elements.records import DataRecord
 from palimpzest.core.models import GenerationStats
-from palimpzest.prompts import PromptFactory
+from palimpzest.prompts import PromptFactory, PromptManager
 from palimpzest.utils.model_helpers import resolve_reasoning_effort
 
 # DEFINITIONS
@@ -110,7 +105,6 @@ def __init__(
         model: Model,
         prompt_strategy: PromptStrategy,
         reasoning_effort: str,
-        api_base: str | None = None,
         cardinality: Cardinality = Cardinality.ONE_TO_ONE,
         desc: str | None = None,
         verbose: bool = False,
@@ -120,10 +114,19 @@
         self.cardinality = cardinality
         self.prompt_strategy = prompt_strategy
         self.reasoning_effort = reasoning_effort
-        self.api_base = api_base
         self.desc = desc
         self.verbose = verbose
         self.prompt_factory = PromptFactory(prompt_strategy, model, cardinality, desc)
+        self.prompt_manager = PromptManager(model)
+
+        # Initialize GeminiClient for direct Gemini API calls (Google AI Studio and Vertex AI)
+        self.gemini_client = None
+        if model.is_model_gemini():
+            from palimpzest.query.generators.gemini_client import GeminiClient
+            self.gemini_client = GeminiClient.get_instance(
+                model=model.get_model_name(),
+                use_vertex=model.is_provider_vertex_ai(),
+            )
 
     def _parse_reasoning(self, completion_text: str, **kwargs) -> str:
         """Extract the reasoning for the generated output from the completion object."""
@@ -316,23 +319,52 @@ def __call__(self, candidate: DataRecord | list[DataRecord], fields: dict[str, F
         messages = self.prompt_factory.create_messages(candidate, fields, right_candidate, **kwargs)
         is_audio_op = any(msg.get("type") == "input_audio" for msg in messages)
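+
+        # Illustration (assumed usage): callers pass a stable per-run string, e.g. a UUID, as
+        # kwargs["cache_isolation_id"] so that separate runs do not share provider-side caches.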
+        if "cache_isolation_id" in kwargs:
+            messages = self.prompt_manager.inject_cache_isolation_id(messages, kwargs["cache_isolation_id"])
+
         # generate the text completion
         start_time = time.time()
         completion = None
+        completion_text = None
         try:
-            completion_kwargs = {}
-            if not self.model.is_o_model() and not self.model.is_gpt_5_model():
-                completion_kwargs = {"temperature": kwargs.get("temperature", 0.0), **completion_kwargs}
-            if is_audio_op:
-                completion_kwargs = {"modalities": ["text"], **completion_kwargs}
-            if self.model.is_reasoning_model():
-                reasoning_effort = resolve_reasoning_effort(self.model, self.reasoning_effort)
-                completion_kwargs = {"reasoning_effort": reasoning_effort, **completion_kwargs}
-            if self.model.is_vllm_model():
-                completion_kwargs = {"api_base": self.api_base, "api_key": os.environ.get("VLLM_API_KEY", "fake-api-key"), **completion_kwargs}
-            completion = litellm.completion(model=self.model_name, messages=messages, **completion_kwargs)
-            end_time = time.time()
-            logger.debug(f"Generated completion in {end_time - start_time:.2f} seconds")
+            # added for testing purposes; may be removed if no longer needed
+            if "generating_messages_only" in kwargs and kwargs["generating_messages_only"]:
+                return messages
+
+            messages = self.prompt_manager.update_messages_for_caching(messages)
+
+            # Use GeminiClient directly for Gemini models (Google AI Studio and Vertex AI)
+            if self.gemini_client is not None:
+                gemini_response = self.gemini_client.generate(
+                    messages=messages,
+                    temperature=kwargs.get("temperature", 0.0),
+                    reasoning_effort=resolve_reasoning_effort(self.model, self.reasoning_effort) if self.model.is_reasoning_model() else None,
+                )
+                end_time = time.time()
+                completion_text = gemini_response.content
+                usage_stats = gemini_response.usage
+                logger.debug(f"Generated completion via GeminiClient in {end_time - start_time:.2f} seconds")
+            else:
+                # Use litellm for all other providers
+                completion_kwargs = {}
+                if not self.model.is_o_model() and not self.model.is_gpt_5_model():
+                    completion_kwargs = {"temperature": kwargs.get("temperature", 0.0), **completion_kwargs}
+                if is_audio_op:
+                    completion_kwargs = {"modalities": ["text"], **completion_kwargs}
+                if self.model.is_reasoning_model():
+                    reasoning_effort = resolve_reasoning_effort(self.model, self.reasoning_effort)
+                    completion_kwargs = {"reasoning_effort": reasoning_effort, **completion_kwargs}
+                if self.model.is_vllm_model():
+                    completion_kwargs = {"api_base": self.model.api_base, "api_key": os.environ.get("VLLM_API_KEY", "fake-api-key"), **self.model.vllm_kwargs, **completion_kwargs}
+
+                cache_kwargs = self.prompt_manager.get_cache_kwargs()
+                completion_kwargs = {**completion_kwargs, **cache_kwargs}
+                completion = litellm.completion(model=self.model_name, messages=messages, **completion_kwargs)
+                end_time = time.time()
+                completion_text = completion.choices[0].message.content
+                usage = completion.usage.model_dump()
+                logger.debug(f"Generated completion via litellm in {end_time - start_time:.2f} seconds")
+
         # if there's an error generating the completion, we have to return an empty answer
         # and can only account for the time spent performing the failed generation
         except Exception as e:
@@ -353,60 +385,97 @@ def __call__(self, candidate: DataRecord | list[DataRecord], fields: dict[str, F
         # parse usage statistics and create the GenerationStats
         generation_stats = None
-        if completion is not None:
-            usage = completion.usage.model_dump()
-
+        if completion_text is not None:
             # get cost per input/output token for the model
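+            # Illustrative sanity check (prices assumed; real rates come from the Model getters below):
+            # 176 uncached input tokens at $2.5e-06/token plus 1024 cached reads at
+            # $1.25e-06/token contribute roughly $0.00172 to total_cost.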
-            usd_per_input_token = MODEL_CARDS[self.model_name].get("usd_per_input_token", 0.0)
-            usd_per_audio_input_token = MODEL_CARDS[self.model_name].get("usd_per_audio_input_token", 0.0)
-            usd_per_output_token = MODEL_CARDS[self.model_name]["usd_per_output_token"]
-
-            # TODO: for some models (e.g. GPT-5) we cannot separate text from image prompt tokens yet;
-            #       for now, we only use tokens from prompt_token_details if it's an audio prompt
-            # get output tokens (all text) and input tokens by modality
-            output_tokens = usage["completion_tokens"]
-            if is_audio_op:
-                input_audio_tokens = usage["prompt_tokens_details"].get("audio_tokens", 0)
-                input_text_tokens = usage["prompt_tokens_details"].get("text_tokens", 0)
-                input_image_tokens = 0
+            usd_per_input_token = self.model.get_usd_per_input_token() or 0.0
+            usd_per_audio_input_token = self.model.get_usd_per_audio_input_token() or 0.0
+            usd_per_image_input_token = self.model.get_usd_per_image_input_token() or 0.0
+            usd_per_output_token = self.model.get_usd_per_output_token() or 0.0
+            usd_per_cache_read_token = self.model.get_usd_per_cache_read_token() or 0.0
+            usd_per_audio_cache_read_token = self.model.get_usd_per_audio_cache_read_token() or 0.0
+            usd_per_image_cache_read_token = self.model.get_usd_per_image_cache_read_token() or 0.0
+            usd_per_cache_creation_token = self.model.get_usd_per_cache_creation_token() or 0.0
+
+            # Extract usage stats based on provider
+            if self.gemini_client is not None:
+                # Usage already processed by GeminiClient
+                output_text_tokens = usage_stats.get("output_text_tokens", 0)
             else:
-                input_audio_tokens = 0
-                input_text_tokens = usage["prompt_tokens"]
-                input_image_tokens = 0
-            input_tokens = input_audio_tokens + input_text_tokens + input_image_tokens
-
-            # compute the input and output token costs
-            total_input_cost = (input_text_tokens + input_image_tokens) * usd_per_input_token + input_audio_tokens * usd_per_audio_input_token
-            total_output_cost = output_tokens * usd_per_output_token
+                # litellm response format
+                output_text_tokens = usage.get("completion_tokens") or 0
+                usage_stats = self.prompt_manager.extract_usage_stats(usage, is_audio_op)
+
+            input_text_tokens = usage_stats["input_text_tokens"]
+            input_audio_tokens = usage_stats["input_audio_tokens"]
+            input_image_tokens = usage_stats["input_image_tokens"]
+            cache_read_tokens = usage_stats["cache_read_tokens"]
+            cache_creation_tokens = usage_stats["cache_creation_tokens"]
+
+            # Compute cache cost: use per-modality breakdown if available (Gemini), otherwise aggregate
+            if self.gemini_client is not None:
+                cache_cost = (
+                    usage_stats["text_cache_read_tokens"] * usd_per_cache_read_token +
+                    usage_stats["audio_cache_read_tokens"] * usd_per_audio_cache_read_token +
+                    usage_stats["image_cache_read_tokens"] * usd_per_image_cache_read_token
+                )
+            else:
+                cache_cost = (
+                    cache_read_tokens * usd_per_cache_read_token +
+                    cache_creation_tokens * usd_per_cache_creation_token
+                )
+
+            total_cost = (
+                input_text_tokens * usd_per_input_token +
+                input_audio_tokens * usd_per_audio_input_token +
+                input_image_tokens * usd_per_image_input_token +
+                cache_cost +
+                output_text_tokens * usd_per_output_token
+            )
 
             generation_stats = GenerationStats(
                 model_name=self.model_name,
                 llm_call_duration_secs=end_time - start_time,
                 fn_call_duration_secs=0.0,
-                input_audio_tokens=input_audio_tokens,
+                # Raw token counts by modality
                 input_text_tokens=input_text_tokens,
+                input_audio_tokens=input_audio_tokens,
                 input_image_tokens=input_image_tokens,
-                total_input_tokens=input_tokens,
-                total_output_tokens=output_tokens,
-                total_input_cost=total_input_cost,
-                total_output_cost=total_output_cost,
-                cost_per_record=total_input_cost + total_output_cost,
+                output_text_tokens=output_text_tokens,
+                # Cache token counts
+                cache_read_tokens=cache_read_tokens,
+                cache_creation_tokens=cache_creation_tokens,
+                # Cost
+                cost_per_record=total_cost,
                 total_llm_calls=1,
             )
 
             # pretty print prompt + full completion output for debugging
-            completion_text = completion.choices[0].message.content
             prompt, system_prompt = "", ""
             for message in messages:
                 if message["role"] == "system":
-                    system_prompt += message["content"] + "\n"
+                    content = message["content"]
+                    if isinstance(content, str):
+                        system_prompt += content + "\n"
+                    elif isinstance(content, list):
+                        # Handle Anthropic-style content blocks
+                        for block in content:
+                            if isinstance(block, dict) and block.get("type") == "text":
+                                system_prompt += block.get("text", "") + "\n"
                 if message["role"] == "user":
-                    if message["type"] == "text":
-                        prompt += message["content"] + "\n"
-                    elif message["type"] == "image":
-                        prompt += "\n" * len(message["content"])
-                    elif message["type"] == "input_audio":
-                        prompt += "