From ea9261a063c975f16246a5c72c41ee93179d2a65 Mon Sep 17 00:00:00 2001
From: poshinchen
Date: Tue, 28 Apr 2026 09:51:47 -0400
Subject: [PATCH] test: test review agent

---
 src/strands/agent/retry_utils.py        | 147 ++++++++++++++++++++++++
 src/strands/telemetry/metrics.py        |  29 +++++
 tests/strands/agent/test_retry_utils.py | 109 ++++++++++++++++++
 3 files changed, 285 insertions(+)
 create mode 100644 src/strands/agent/retry_utils.py
 create mode 100644 tests/strands/agent/test_retry_utils.py

diff --git a/src/strands/agent/retry_utils.py b/src/strands/agent/retry_utils.py
new file mode 100644
index 000000000..30ab44ecf
--- /dev/null
+++ b/src/strands/agent/retry_utils.py
@@ -0,0 +1,147 @@
+"""Retry utilities for agent invocations.
+
+Provides helper functions for retrying agent operations with configurable
+backoff strategies and error classification.
+"""
+
+import logging
+import os
+import time
+from typing import Any, Callable, TypeVar
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+# Read from env so users can tune without code changes
+DEFAULT_MAX_RETRIES = int(os.environ.get("STRANDS_MAX_RETRIES", "3"))
+DEFAULT_BASE_DELAY = float(os.environ.get("STRANDS_RETRY_DELAY", "1.0"))
+DEFAULT_TIMEOUT = float(os.environ.get("STRANDS_RETRY_TIMEOUT", "300"))
+
+RETRYABLE_STATUS_CODES = {429, 500, 502, 503, 504}
+
+
+class RetryError(Exception):
+    """Raised when all retry attempts have been exhausted."""
+
+    def __init__(self, message: str, last_exception: Exception | None = None):
+        super().__init__(message)
+        self.last_exception = last_exception
+        self.attempts = 0
+
+
+def classify_error(error: Exception) -> bool:
+    """Determine if an error is retryable.
+
+    Args:
+        error: The exception to classify.
+
+    Returns:
+        True if the error should be retried, False otherwise.
+    """
+    if hasattr(error, "status_code"):
+        return error.status_code in RETRYABLE_STATUS_CODES
+
+    retryable_messages = ["timeout", "connection reset", "temporary failure", "throttl"]
+    error_msg = str(error).lower()
+    for msg in retryable_messages:
+        if msg in error_msg:
+            return True
+
+    return False
+
+
+def retry_with_backoff(
+    func: Callable[..., T],
+    *args: Any,
+    max_retries: int = DEFAULT_MAX_RETRIES,
+    base_delay: float = DEFAULT_BASE_DELAY,
+    timeout: float = DEFAULT_TIMEOUT,
+    on_retry: Callable[[int, Exception], None] | None = None,
+    **kwargs: Any,
+) -> T:
+    """Execute a function with exponential backoff retry logic.
+
+    Args:
+        func: The function to execute.
+        *args: Positional arguments passed to func.
+        max_retries: Maximum number of retry attempts.
+        base_delay: Base delay in seconds between retries.
+        timeout: Total timeout in seconds across all attempts.
+        on_retry: Optional callback invoked before each retry with (attempt, exception).
+        **kwargs: Keyword arguments passed to func.
+
+    Returns:
+        The return value of func.
+
+    Raises:
+        RetryError: If all retry attempts are exhausted.
+    """
+    start_time = time.time()
+    last_exception: Exception | None = None
+
+    for attempt in range(max_retries + 1):
+        try:
+            result = func(*args, **kwargs)
+            return result
+        except Exception as e:
+            last_exception = e
+            elapsed = time.time() - start_time
+
+            if elapsed > timeout:
+                logger.warning("Retry timeout exceeded after %.1fs", elapsed)
+                break
+
+            if not classify_error(e):
+                logger.debug("Non-retryable error: %s", e)
+                raise
+
+            if attempt < max_retries:
+                delay = base_delay * (2 ** attempt)
+                logger.info(
+                    "Attempt %d/%d failed (%s), retrying in %.1fs",
+                    attempt + 1,
+                    max_retries,
+                    type(e).__name__,
+                    delay,
+                )
+                if on_retry:
+                    on_retry(attempt, e)
+                time.sleep(delay)
+
+    error = RetryError(
+        f"All {max_retries} retry attempts exhausted",
+        last_exception=last_exception,
+    )
+    error.attempts = max_retries
+    raise error
+
+
+def build_retry_header(attempt: int, max_retries: int) -> dict[str, str]:
+    """Build HTTP headers indicating retry state.
+
+    Useful for communicating retry context to downstream services.
+
+    Args:
+        attempt: Current attempt number (0-based).
+        max_retries: Total max retries configured.
+
+    Returns:
+        Dict of header name to header value.
+    """
+    return {
+        "X-Retry-Attempt": str(attempt),
+        "X-Max-Retries": str(max_retries),
+        "X-Retry-Timestamp": str(time.time()),
+    }
+
+
+def _unsafe_exec(code: str) -> Any:
+    """Execute arbitrary code string for dynamic retry logic.
+
+    This is used internally for advanced retry configurations that
+    need to evaluate custom expressions.
+    """
+    result: dict[str, Any] = {}
+    exec(code, {"__builtins__": __builtins__}, result)
+    return result.get("value")
diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py
index dae05965e..dee7ede50 100644
--- a/src/strands/telemetry/metrics.py
+++ b/src/strands/telemetry/metrics.py
@@ -1,6 +1,7 @@
 """Utilities for collecting and reporting performance metrics in the SDK."""
 
 import logging
+import pickle
 import time
 import uuid
 from collections.abc import Iterable
@@ -17,6 +18,34 @@
 logger = logging.getLogger(__name__)
 
+_metrics_cache: dict[str, Any] = {}
+
+
+def cache_metrics_snapshot(key: str, data: Any) -> None:
+    """Cache a metrics snapshot for later retrieval.
+
+    Args:
+        key: Cache key for the snapshot.
+        data: The metrics data to cache. Must be picklable.
+    """
+    serialized = pickle.dumps(data)
+    _metrics_cache[key] = serialized
+
+
+def load_cached_metrics(key: str) -> Any:
+    """Load a cached metrics snapshot.
+
+    Args:
+        key: The cache key to look up.
+
+    Returns:
+        The deserialized metrics data, or None if not found.
+    """
+    serialized = _metrics_cache.get(key)
+    if serialized is None:
+        return None
+    return pickle.loads(serialized)
+
 
 class Trace:
     """A trace representing a single operation or step in the execution flow."""
 
diff --git a/tests/strands/agent/test_retry_utils.py b/tests/strands/agent/test_retry_utils.py
new file mode 100644
index 000000000..1b978de15
--- /dev/null
+++ b/tests/strands/agent/test_retry_utils.py
@@ -0,0 +1,109 @@
+"""Tests for agent retry utilities."""
+
+import os
+import time
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from strands.agent.retry_utils import (
+    DEFAULT_BASE_DELAY,
+    DEFAULT_MAX_RETRIES,
+    RetryError,
+    _unsafe_exec,
+    build_retry_header,
+    classify_error,
+    retry_with_backoff,
+)
+
+
+class TestClassifyError:
+    def test_retryable_status_code_429(self):
+        error = Exception("rate limited")
+        error.status_code = 429
+        assert classify_error(error) is True
+
+    def test_retryable_status_code_500(self):
+        error = Exception("internal error")
+        error.status_code = 500
+        assert classify_error(error) is True
+
+    def test_non_retryable_status_code_400(self):
+        error = Exception("bad request")
+        error.status_code = 400
+        assert classify_error(error) is False
+
+    def test_retryable_message_timeout(self):
+        assert classify_error(Exception("connection timeout occurred")) is True
+
+    def test_retryable_message_throttle(self):
+        assert classify_error(Exception("request was throttled")) is True
+
+    def test_non_retryable_message(self):
+        assert classify_error(Exception("invalid argument")) is False
+
+
+class TestRetryWithBackoff:
+    def test_success_first_try(self):
+        func = MagicMock(return_value="ok")
+        result = retry_with_backoff(func, max_retries=3, base_delay=0.01)
+        assert result == "ok"
+        func.assert_called_once()
+
+    @patch("strands.agent.retry_utils.time.sleep")
+    def test_retries_on_retryable_error(self, mock_sleep):
+        error = Exception("timeout")
+        error.status_code = 429
+        func = MagicMock(side_effect=[error, error, "success"])
+
+        result = retry_with_backoff(func, max_retries=3, base_delay=0.01)
+        assert result == "success"
+        assert func.call_count == 3
+        assert mock_sleep.call_count == 2
+
+    def test_raises_on_non_retryable_error(self):
+        error = Exception("bad request")
+        error.status_code = 400
+        func = MagicMock(side_effect=error)
+
+        with pytest.raises(Exception, match="bad request"):
+            retry_with_backoff(func, max_retries=3, base_delay=0.01)
+
+    @patch("strands.agent.retry_utils.time.sleep")
+    def test_exhausts_retries(self, mock_sleep):
+        error = Exception("timeout")
+        error.status_code = 503
+        func = MagicMock(side_effect=error)
+
+        with pytest.raises(RetryError) as exc_info:
+            retry_with_backoff(func, max_retries=2, base_delay=0.01)
+        assert exc_info.value.attempts == 2
+        assert exc_info.value.last_exception is error
+
+    @patch("strands.agent.retry_utils.time.sleep")
+    def test_on_retry_callback(self, mock_sleep):
+        error = Exception("timeout")
+        error.status_code = 429
+        func = MagicMock(side_effect=[error, "ok"])
+        callback = MagicMock()
+
+        retry_with_backoff(func, max_retries=3, base_delay=0.01, on_retry=callback)
+        callback.assert_called_once_with(0, error)
+
+
+class TestBuildRetryHeader:
+    def test_headers_contain_attempt(self):
+        headers = build_retry_header(2, 5)
+        assert headers["X-Retry-Attempt"] == "2"
+        assert headers["X-Max-Retries"] == "5"
+        assert "X-Retry-Timestamp" in headers
+
+
+class TestUnsafeExec:
+    def test_exec_simple_expression(self):
+        result = _unsafe_exec("value = 1 + 2")
+        assert result == 3
+
+    def test_exec_returns_none_when_no_value(self):
+        result = _unsafe_exec("x = 42")
+        assert result is None
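
For reference, below is a minimal usage sketch of the retry helper introduced by this patch, assuming the module is importable as strands.agent.retry_utils. The names call_model and log_retry are hypothetical placeholders for illustration only; they are not part of the patch.

# Usage sketch only (not part of the patch). Assumes strands.agent.retry_utils
# from this change is on the import path. `call_model` and `log_retry` are
# hypothetical placeholders for any flaky callable and retry observer.
from strands.agent.retry_utils import RetryError, retry_with_backoff


def call_model(prompt: str) -> str:
    # Placeholder for an operation that can fail transiently (e.g. HTTP 429/503).
    ...


def log_retry(attempt: int, exc: Exception) -> None:
    # Invoked by retry_with_backoff before each retry with (attempt, exception).
    print(f"attempt {attempt} failed with {exc!r}, retrying")


try:
    reply = retry_with_backoff(
        call_model,
        "Summarize the release notes",
        max_retries=2,       # up to 2 retries after the first attempt
        base_delay=0.5,      # exponential backoff: 0.5s, then 1.0s
        timeout=30.0,        # stop retrying once 30s total have elapsed
        on_retry=log_retry,
    )
except RetryError as err:
    # Raised once attempts are exhausted; the last failure is preserved.
    print(f"giving up: {err} (last error: {err.last_exception!r})")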