Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions src/strands/agent/retry_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""Retry utilities for agent invocations.

Provides helper functions for retrying agent operations with configurable
backoff strategies and error classification.
"""

import logging
import os
import time
from typing import Any, Callable, TypeVar

logger = logging.getLogger(__name__)

T = TypeVar("T")


def _read_env_number(name: str, default: str, cast: Callable[[str], Any]) -> Any:
    """Read a numeric setting from the environment.

    Args:
        name: Environment variable to look up.
        default: String used when the variable is unset.
        cast: Converter applied to the raw string (``int`` or ``float``).

    Returns:
        The converted value.

    Raises:
        ValueError: If the raw string cannot be converted. The message names
            the variable and the offending value, because this runs at import
            time and a bare ``invalid literal for int()`` is hard to trace.
    """
    raw = os.environ.get(name, default)
    try:
        return cast(raw)
    except ValueError as exc:
        raise ValueError(f"Invalid value {raw!r} for environment variable {name}") from exc


# Read from env so users can tune without code changes
DEFAULT_MAX_RETRIES: int = _read_env_number("STRANDS_MAX_RETRIES", "3", int)
DEFAULT_BASE_DELAY: float = _read_env_number("STRANDS_RETRY_DELAY", "1.0", float)
DEFAULT_TIMEOUT: float = _read_env_number("STRANDS_RETRY_TIMEOUT", "300", float)

# HTTP status codes that classify_error treats as retryable.
RETRYABLE_STATUS_CODES = {429, 500, 502, 503, 504}


class RetryError(Exception):
"""Raised when all retry attempts have been exhausted."""

def __init__(self, message: str, last_exception: Exception | None = None):
super().__init__(message)
self.last_exception = last_exception
self.attempts = 0


def classify_error(error: Exception) -> bool:
    """Determine if an error is retryable.

    A non-None ``status_code`` attribute is authoritative: the error is
    retryable exactly when that code is in ``RETRYABLE_STATUS_CODES``.
    Otherwise the lower-cased string form of the error is searched for known
    transient-failure fragments.

    Args:
        error: The exception to classify.

    Returns:
        True if the error should be retried, False otherwise.
    """
    status_code = getattr(error, "status_code", None)
    # A status_code attribute that is present but None carries no information,
    # so it must not block the message-based check below.
    if status_code is not None:
        return status_code in RETRYABLE_STATUS_CODES

    # "throttl" is a deliberate prefix so it matches both "throttled" and "throttling".
    retryable_fragments = ("timeout", "connection reset", "temporary failure", "throttl")
    error_msg = str(error).lower()
    return any(fragment in error_msg for fragment in retryable_fragments)


def retry_with_backoff(
    func: Callable[..., T],
    *args: Any,
    max_retries: int = DEFAULT_MAX_RETRIES,
    base_delay: float = DEFAULT_BASE_DELAY,
    timeout: float = DEFAULT_TIMEOUT,
    on_retry: Callable[[int, Exception], None] | None = None,
    **kwargs: Any,
) -> T:
    """Execute a function with exponential backoff retry logic.

    ``func`` is called up to ``max_retries + 1`` times. After a retryable
    failure on attempt index ``k`` (0-based) the call sleeps for
    ``base_delay * 2**k`` seconds, capped at the time left in ``timeout``.

    Args:
        func: The function to execute.
        *args: Positional arguments passed to func.
        max_retries: Maximum number of retry attempts.
        base_delay: Base delay in seconds between retries.
        timeout: Total timeout in seconds across all attempts.
        on_retry: Optional callback invoked before each retry with
            (attempt, exception), where attempt is the 0-based index of the
            attempt that just failed.
        **kwargs: Keyword arguments passed to func.

    Returns:
        The return value of func.

    Raises:
        RetryError: If all retry attempts are exhausted, or if the timeout
            budget is spent while the failures are still retryable. The error
            is chained to the last exception and ``attempts`` holds the number
            of retries performed.
        Exception: Any error that ``classify_error`` deems non-retryable is
            re-raised unchanged, regardless of elapsed time.
    """
    # Monotonic clock: a wall-clock adjustment must not shrink or stretch the budget.
    start_time = time.monotonic()
    last_exception: Exception | None = None

    # Outcome reported when the loop runs to completion (retries exhausted).
    retries_reported = max_retries
    message = f"All {max_retries} retry attempts exhausted"

    for attempt in range(max_retries + 1):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # Classify before checking the budget, so a permanent error is
            # never disguised as a RetryError just because the call was slow.
            if not classify_error(e):
                logger.debug("Non-retryable error: %s", e)
                raise

            last_exception = e
            elapsed = time.monotonic() - start_time

            if elapsed > timeout:
                logger.warning("Retry timeout exceeded after %.1fs", elapsed)
                # Report what actually happened, not "all retries exhausted".
                retries_reported = attempt
                message = f"Retry timeout of {timeout:.1f}s exceeded after {attempt + 1} attempt(s)"
                break

            if attempt < max_retries:
                # Never sleep past the overall budget; elapsed <= timeout here,
                # so the remaining time is non-negative.
                delay = min(base_delay * (2**attempt), timeout - elapsed)
                logger.info(
                    "Attempt %d/%d failed (%s), retrying in %.1fs",
                    attempt + 1,
                    max_retries + 1,  # total attempts = initial call + retries
                    type(e).__name__,
                    delay,
                )
                if on_retry:
                    on_retry(attempt, e)
                time.sleep(delay)

    error = RetryError(message, last_exception=last_exception)
    error.attempts = retries_reported
    # Chain so tracebacks show the underlying failure.
    raise error from last_exception


def build_retry_header(attempt: int, max_retries: int) -> dict[str, str]:
    """Build HTTP headers describing the current retry state.

    Intended for passing retry context along to downstream services.

    Args:
        attempt: Current attempt number (0-based).
        max_retries: Total max retries configured.

    Returns:
        Mapping of header name to its string value: the attempt number, the
        configured maximum, and the current ``time.time()`` timestamp.
    """
    header_values = (
        ("X-Retry-Attempt", attempt),
        ("X-Max-Retries", max_retries),
        ("X-Retry-Timestamp", time.time()),
    )
    # Every header value is sent as text.
    return {header: str(value) for header, value in header_values}


def _unsafe_exec(code: str) -> Any:
    """Execute arbitrary code string for dynamic retry logic.

    This is used internally for advanced retry configurations that
    need to evaluate custom expressions.

    Args:
        code: Python source to execute. It communicates its result by
            assigning to a top-level name ``value``.

    Returns:
        Whatever the code bound to ``value``, or None if it bound nothing.
    """
    # SECURITY: this runs `code` with full builtins and is NOT a sandbox.
    # Never pass it text that originates from users, config files fetched over
    # the network, or any other untrusted source.
    #
    # One dict serves as both globals and locals. With separate dicts the code
    # runs like a class body: functions it defines land in locals but resolve
    # names through globals, so they cannot call each other.
    namespace: dict[str, Any] = {"__builtins__": __builtins__}
    exec(code, namespace)
    return namespace.get("value")
29 changes: 29 additions & 0 deletions src/strands/telemetry/metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utilities for collecting and reporting performance metrics in the SDK."""

import logging
import pickle
import time
import uuid
from collections.abc import Iterable
Expand All @@ -17,6 +18,34 @@

logger = logging.getLogger(__name__)

# Module-level, in-process store of metrics snapshots. Keys are caller-supplied
# names; cache_metrics_snapshot writes the pickled bytes of the data as values.
_metrics_cache: dict[str, Any] = {}


def cache_metrics_snapshot(key: str, data: Any) -> None:
    """Store a pickled snapshot of metrics data in the module-level cache.

    An existing entry under the same key is overwritten.

    Args:
        key: Cache key for the snapshot.
        data: The metrics data to cache. Must be picklable.
    """
    # Only the serialized bytes are kept, not a reference to the caller's object.
    _metrics_cache[key] = pickle.dumps(data)


def load_cached_metrics(key: str) -> Any:
    """Return the metrics snapshot stored under ``key``.

    Args:
        key: The cache key to look up.

    Returns:
        The deserialized metrics data, or None if not found.
    """
    payload = _metrics_cache.get(key)
    # NOTE: pickle.loads must only ever see bytes produced by this process;
    # unpickling untrusted data can execute arbitrary code.
    return None if payload is None else pickle.loads(payload)


class Trace:
"""A trace representing a single operation or step in the execution flow."""
Expand Down
109 changes: 109 additions & 0 deletions tests/strands/agent/test_retry_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Tests for agent retry utilities."""

import os
import time
from unittest.mock import MagicMock, patch

import pytest

from strands.agent.retry_utils import (
DEFAULT_BASE_DELAY,
DEFAULT_MAX_RETRIES,
RetryError,
_unsafe_exec,
build_retry_header,
classify_error,
retry_with_backoff,
)


class TestClassifyError:
    """Cover the status-code and message-based branches of classify_error."""

    @staticmethod
    def _error_with_status(message, status_code):
        """Build a plain Exception carrying a ``status_code`` attribute."""
        exc = Exception(message)
        exc.status_code = status_code
        return exc

    def test_retryable_status_code_429(self):
        exc = self._error_with_status("rate limited", 429)
        assert classify_error(exc) is True

    def test_retryable_status_code_500(self):
        exc = self._error_with_status("internal error", 500)
        assert classify_error(exc) is True

    def test_non_retryable_status_code_400(self):
        exc = self._error_with_status("bad request", 400)
        assert classify_error(exc) is False

    def test_retryable_message_timeout(self):
        exc = Exception("connection timeout occurred")
        assert classify_error(exc) is True

    def test_retryable_message_throttle(self):
        exc = Exception("request was throttled")
        assert classify_error(exc) is True

    def test_non_retryable_message(self):
        exc = Exception("invalid argument")
        assert classify_error(exc) is False


class TestRetryWithBackoff:
    """Behavioral tests for retry_with_backoff.

    ``time.sleep`` is patched wherever a retry is expected, so the tests do
    not actually wait.
    """

    def test_success_first_try(self):
        func = MagicMock(return_value="ok")
        result = retry_with_backoff(func, max_retries=3, base_delay=0.01)
        assert result == "ok"
        func.assert_called_once()

    @patch("strands.agent.retry_utils.time.sleep")
    def test_retries_on_retryable_error(self, mock_sleep):
        error = Exception("timeout")
        error.status_code = 429
        func = MagicMock(side_effect=[error, error, "success"])

        result = retry_with_backoff(func, max_retries=3, base_delay=0.01)
        assert result == "success"
        assert func.call_count == 3
        assert mock_sleep.call_count == 2

    def test_raises_on_non_retryable_error(self):
        error = Exception("bad request")
        error.status_code = 400
        func = MagicMock(side_effect=error)

        with pytest.raises(Exception, match="bad request") as exc_info:
            retry_with_backoff(func, max_retries=3, base_delay=0.01)
        # The original error must propagate untouched...
        assert exc_info.value is error
        # ...and without a single retry; matching on the message alone would
        # also pass if the call had been retried before re-raising.
        func.assert_called_once()

    @patch("strands.agent.retry_utils.time.sleep")
    def test_exhausts_retries(self, mock_sleep):
        error = Exception("timeout")
        error.status_code = 503
        func = MagicMock(side_effect=error)

        with pytest.raises(RetryError) as exc_info:
            retry_with_backoff(func, max_retries=2, base_delay=0.01)
        # Assertions sit outside the `with` block; inside it they would never
        # run, because the call above raises.
        assert exc_info.value.attempts == 2
        assert exc_info.value.last_exception is error
        # One initial call plus two retries, with a sleep before each retry.
        assert func.call_count == 3
        assert mock_sleep.call_count == 2

    @patch("strands.agent.retry_utils.time.sleep")
    def test_on_retry_callback(self, mock_sleep):
        error = Exception("timeout")
        error.status_code = 429
        func = MagicMock(side_effect=[error, "ok"])
        callback = MagicMock()

        retry_with_backoff(func, max_retries=3, base_delay=0.01, on_retry=callback)
        # The callback receives the 0-based index of the attempt that failed.
        callback.assert_called_once_with(0, error)


class TestBuildRetryHeader:
    """Check the header mapping produced by build_retry_header."""

    def test_headers_contain_attempt(self):
        result = build_retry_header(2, 5)
        # Counters are rendered as strings.
        assert result["X-Retry-Attempt"] == "2"
        assert result["X-Max-Retries"] == "5"
        # The timestamp value varies per call, so only its presence is checked.
        assert "X-Retry-Timestamp" in result


class TestUnsafeExec:
    """Check how _unsafe_exec reports the executed snippet's result."""

    def test_exec_simple_expression(self):
        # The snippet binds `value`, which is what gets returned.
        assert _unsafe_exec("value = 1 + 2") == 3

    def test_exec_returns_none_when_no_value(self):
        # No `value` binding in the snippet means there is no result.
        assert _unsafe_exec("x = 42") is None
Loading