diff --git a/growthbook/__init__.py b/growthbook/__init__.py index 8424714..b5a5406 100644 --- a/growthbook/__init__.py +++ b/growthbook/__init__.py @@ -7,6 +7,12 @@ BackoffStrategy ) +from .cache_interfaces import ( + AbstractFeatureCache, + AbstractAsyncFeatureCache, + InMemoryAsyncFeatureCache +) + # Plugin support from .plugins import ( GrowthBookTrackingPlugin, diff --git a/growthbook/cache_interfaces.py b/growthbook/cache_interfaces.py new file mode 100644 index 0000000..46ab4aa --- /dev/null +++ b/growthbook/cache_interfaces.py @@ -0,0 +1,109 @@ +import asyncio +from abc import abstractmethod, ABC +from time import time +from typing import Optional, Dict + +class AbstractFeatureCache(ABC): + @abstractmethod + def get(self, key: str) -> Optional[Dict]: + pass + + @abstractmethod + def set(self, key: str, value: Dict, ttl: int) -> None: + pass + + def clear(self) -> None: + pass + +class AbstractAsyncFeatureCache(ABC): + """Abstract base class for async feature caching implementations""" + + @abstractmethod + async def get(self, key: str) -> Optional[Dict]: + """ + Retrieve cached features by key. + + Args: + key: Cache key + + Returns: + Cached dictionary or None if not found/expired + """ + pass + + @abstractmethod + async def set(self, key: str, value: Dict, ttl: int) -> None: + """ + Store features in cache with TTL. + + Args: + key: Cache key + value: Features dictionary to cache + ttl: Time to live in seconds + """ + pass + + async def clear(self) -> None: + """Clear all cached entries (optional to override)""" + pass + +class CacheEntry(object): + def __init__(self, value: Dict, ttl: int) -> None: + self.value = value + self.ttl = ttl + self.expires = time() + ttl + + def update(self, value: Dict): + self.value = value + self.expires = time() + self.ttl + + +class InMemoryFeatureCache(AbstractFeatureCache): + def __init__(self) -> None: + self.cache: Dict[str, CacheEntry] = {} + + def get(self, key: str) -> Optional[Dict]: + if key in self.cache: + entry = self.cache[key] + if entry.expires >= time(): + return entry.value + return None + + def set(self, key: str, value: Dict, ttl: int) -> None: + if key in self.cache: + self.cache[key].update(value) + else: + self.cache[key] = CacheEntry(value, ttl) + + def clear(self) -> None: + self.cache.clear() + + +class InMemoryAsyncFeatureCache(AbstractAsyncFeatureCache): + """ + Async in-memory cache implementation. + Uses the same CacheEntry structure but with async interface. + """ + + def __init__(self) -> None: + self._cache: Dict[str, CacheEntry] = {} + self._lock = asyncio.Lock() + + async def get(self, key: str) -> Optional[Dict]: + async with self._lock: + if key in self._cache: + entry = self._cache[key] + if entry.expires >= time(): + return entry.value + return None + + async def set(self, key: str, value: Dict, ttl: int) -> None: + async with self._lock: + if key in self._cache: + self._cache[key].update(value) + else: + self._cache[key] = CacheEntry(value, ttl) + + async def clear(self) -> None: + async with self._lock: + self._cache.clear() diff --git a/growthbook/common_types.py b/growthbook/common_types.py index c0a187b..2ad8efe 100644 --- a/growthbook/common_types.py +++ b/growthbook/common_types.py @@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, List, Optional, Union, Set, Tuple from enum import Enum from abc import ABC, abstractmethod +from .cache_interfaces import AbstractFeatureCache, AbstractAsyncFeatureCache class VariationMeta(TypedDict): key: str @@ -396,7 +397,7 @@ def get_all_assignments(self, attributes: Dict[str, str]) -> Dict[str, Dict]: return docs @dataclass -class StackContext: +class StackContext: id: Optional[str] = None evaluated_features: Set[str] = field(default_factory=set) @@ -433,6 +434,8 @@ class Options: tracking_plugins: Optional[List[Any]] = None http_connect_timeout: Optional[int] = None http_read_timeout: Optional[int] = None + cache: Optional[AbstractFeatureCache] = None + async_cache: Optional[AbstractAsyncFeatureCache] = None @dataclass diff --git a/growthbook/growthbook.py b/growthbook/growthbook.py index c89bb37..8fcc787 100644 --- a/growthbook/growthbook.py +++ b/growthbook/growthbook.py @@ -13,18 +13,18 @@ from abc import ABC, abstractmethod from typing import Optional, Any, Set, Tuple, List, Dict, Callable - -from .common_types import ( EvaluationContext, - Experiment, - FeatureResult, - Feature, - GlobalContext, - Options, - Result, StackContext, - UserContext, - AbstractStickyBucketService, - FeatureRule -) +from collections import OrderedDict +from .cache_interfaces import AbstractFeatureCache, InMemoryFeatureCache +from .common_types import (EvaluationContext, + Experiment, + FeatureResult, + Feature, + GlobalContext, + Options, + Result, StackContext, + UserContext, + AbstractStickyBucketService + ) # Only require typing_extensions if using Python 3.7 or earlier if sys.version_info >= (3, 8): @@ -33,7 +33,6 @@ from typing_extensions import TypedDict from base64 import b64decode -from time import time import aiohttp import asyncio @@ -63,50 +62,6 @@ def decrypt(encrypted_str: str, key_str: str) -> str: return bytestring.decode("utf-8") -class AbstractFeatureCache(ABC): - @abstractmethod - def get(self, key: str) -> Optional[Dict]: - pass - - @abstractmethod - def set(self, key: str, value: Dict, ttl: int) -> None: - pass - - def clear(self) -> None: - pass - - -class CacheEntry(object): - def __init__(self, value: Dict, ttl: int) -> None: - self.value = value - self.ttl = ttl - self.expires = time() + ttl - - def update(self, value: Dict): - self.value = value - self.expires = time() + self.ttl - - -class InMemoryFeatureCache(AbstractFeatureCache): - def __init__(self) -> None: - self.cache: Dict[str, CacheEntry] = {} - - def get(self, key: str) -> Optional[Dict]: - if key in self.cache: - entry = self.cache[key] - if entry.expires >= time(): - return entry.value - return None - - def set(self, key: str, value: Dict, ttl: int) -> None: - if key in self.cache: - self.cache[key].update(value) - else: - self.cache[key] = CacheEntry(value, ttl) - - def clear(self) -> None: - self.cache.clear() - class InMemoryStickyBucketService(AbstractStickyBucketService): def __init__(self) -> None: self.docs: Dict[str, Dict] = {} @@ -159,7 +114,7 @@ def disconnect(self, timeout=10): """Gracefully disconnect with timeout""" logger.debug("Initiating SSE client disconnect") self.is_running = False - + if self._loop and self._loop.is_running(): future = asyncio.run_coroutine_threadsafe(self._stop_session(timeout), self._loop) try: @@ -190,12 +145,12 @@ def _get_sse_url(self, api_host: str, client_key: str) -> str: async def _init_session(self): url = self._get_sse_url(self.api_host, self.client_key) - + try: while self.is_running: try: - async with aiohttp.ClientSession(headers=self.headers, - timeout=aiohttp.ClientTimeout(connect=self.timeout)) as session: + async with aiohttp.ClientSession(headers=self.headers, + timeout=aiohttp.ClientTimeout(connect=self.timeout)) as session: self._sse_session = session async with session.get(url) as response: @@ -235,7 +190,7 @@ async def _process_response(self, response): if not self.is_running: logger.debug("SSE processing stopped - is_running is False") break - + decoded_line = line.decode('utf-8').strip() if decoded_line.startswith("event:"): event_data['type'] = decoded_line[len("event:"):].strip() @@ -248,7 +203,7 @@ async def _process_response(self, response): except Exception as e: logger.warning(f"Error in event handler: {e}") event_data = {} - + # Process any remaining event data if 'type' in event_data and 'data' in event_data: try: @@ -277,7 +232,7 @@ async def _close_session(self): def _run_sse_channel(self): self._loop = asyncio.new_event_loop() - + try: self._loop.run_until_complete(self._init_session()) except asyncio.CancelledError: @@ -289,7 +244,7 @@ def _run_sse_channel(self): async def _stop_session(self, timeout=10): """Stop the SSE session and cancel all tasks with timeout""" logger.debug("Stopping SSE session") - + # Close the session first if self._sse_session and not self._sse_session.closed: try: @@ -302,15 +257,15 @@ async def _stop_session(self, timeout=10): if self._loop and self._loop.is_running(): try: # Get all tasks for this specific loop - tasks = [task for task in asyncio.all_tasks(self._loop) - if not task.done() and task is not asyncio.current_task(self._loop)] - + tasks = [task for task in asyncio.all_tasks(self._loop) + if not task.done() and task is not asyncio.current_task(self._loop)] + if tasks: logger.debug(f"Cancelling {len(tasks)} SSE tasks") # Cancel all tasks for task in tasks: task.cancel() - + # Wait for tasks to complete with timeout try: await asyncio.wait_for( @@ -325,9 +280,6 @@ async def _stop_session(self, timeout=10): except Exception as e: logger.warning(f"Error during SSE task cleanup: {e}") -from collections import OrderedDict - -# ... (imports) class FeatureRepository(object): def __init__(self) -> None: @@ -337,12 +289,12 @@ def __init__(self) -> None: self.http_read_timeout: Optional[int] = None self.sse_client: Optional[SSEClient] = None self._feature_update_callbacks: List[Callable[[Dict], None]] = [] - + # Background refresh support self._refresh_thread: Optional[threading.Thread] = None self._refresh_stop_event = threading.Event() self._refresh_lock = threading.Lock() - + # ETag cache for bandwidth optimization # Using OrderedDict for LRU cache (max 100 entries) self._etag_cache: OrderedDict[str, Tuple[str, Dict[str, Any]]] = OrderedDict() @@ -382,7 +334,7 @@ def load_features( ) -> Optional[Dict]: if not client_key: raise ValueError("Must specify `client_key` to refresh features") - + key = api_host + "::" + client_key cached = self.cache.get(key) @@ -395,28 +347,33 @@ def load_features( self._notify_feature_update_callbacks(res) return res return cached - - + async def load_features_async( self, api_host: str, client_key: str, decryption_key: str = "", ttl: int = 600 ) -> Optional[Dict]: + if not client_key: + raise ValueError("Must specify `client_key` to refresh features") + key = api_host + "::" + client_key cached = self.cache.get(key) + if not cached: res = await self._fetch_features_async(api_host, client_key, decryption_key) if res is not None: + # save in cache self.cache.set(key, res, ttl) + logger.debug("Fetched features from API, stored in cache") # Notify callbacks about fresh features self._notify_feature_update_callbacks(res) return res return cached - + @property def user_agent_suffix(self) -> Optional[str]: return getattr(self, "_user_agent_suffix", None) - + @user_agent_suffix.setter def user_agent_suffix(self, value: Optional[str]) -> None: self._user_agent_suffix = value @@ -428,23 +385,23 @@ def _get(self, url: str, headers: Optional[Dict[str, str]] = None): timeout = Timeout(connect=self.http_connect_timeout, read=self.http_read_timeout) self.http = self.http or PoolManager(timeout=timeout) return self.http.request("GET", url, headers=headers or {}) - + def _get_headers(self, client_key: str, existing_headers: Dict[str, str] = None) -> Dict[str, str]: headers = existing_headers or {} headers['Accept-Encoding'] = "gzip, deflate" - + # Add User-Agent with optional suffix ua = "Gb-Python" ua += f"-{self.user_agent_suffix}" if self.user_agent_suffix else f"-{client_key[-4:]}" headers['User-Agent'] = ua - + return headers def _fetch_and_decode(self, api_host: str, client_key: str) -> Optional[Dict]: url = self._get_features_url(api_host, client_key) headers = self._get_headers(client_key) logger.debug(f"Fetching features from {url} with headers {headers}") - + # Check if we have a cached ETag for this URL cached_etag = None cached_data = None @@ -457,10 +414,10 @@ def _fetch_and_decode(self, api_host: str, client_key: str) -> Optional[Dict]: logger.debug(f"Using cached ETag for request: {cached_etag[:20]}...") else: logger.debug(f"No ETag cache found for URL: {url}") - + try: r = self._get(url, headers) - + # Handle 304 Not Modified - content hasn't changed if r.status == 304: logger.debug(f"ETag match! Server returned 304 Not Modified - using cached data (saved bandwidth)") @@ -470,15 +427,15 @@ def _fetch_and_decode(self, api_host: str, client_key: str) -> Optional[Dict]: else: logger.warning("Received 304 but no cached data available") return None - + if r.status >= 400: logger.warning( "Failed to fetch features, received status code %d", r.status ) return None - + decoded = json.loads(r.data.decode("utf-8")) - + # Store the new ETag if present response_etag = r.headers.get('ETag') if response_etag: @@ -487,7 +444,7 @@ def _fetch_and_decode(self, api_host: str, client_key: str) -> Optional[Dict]: # Enforce max size if len(self._etag_cache) > self._max_etag_entries: self._etag_cache.popitem(last=False) - + if cached_etag: logger.debug(f"ETag updated: {cached_etag[:20]}... -> {response_etag[:20]}...") else: @@ -495,17 +452,17 @@ def _fetch_and_decode(self, api_host: str, client_key: str) -> Optional[Dict]: logger.debug(f"ETag cache now contains {len(self._etag_cache)} entries") else: logger.debug("No ETag header in response") - + return decoded # type: ignore[no-any-return] except Exception as e: logger.error(f"Failed to decode feature JSON from GrowthBook API: {e}") return None - + async def _fetch_and_decode_async(self, api_host: str, client_key: str) -> Optional[Dict]: url = self._get_features_url(api_host, client_key) headers = self._get_headers(client_key=client_key) logger.debug(f"[Async] Fetching features from {url} with headers {headers}") - + # Check if we have a cached ETag for this URL cached_etag = None cached_data = None @@ -518,26 +475,27 @@ async def _fetch_and_decode_async(self, api_host: str, client_key: str) -> Optio logger.debug(f"[Async] Using cached ETag for request: {cached_etag[:20]}...") else: logger.debug(f"[Async] No ETag cache found for URL: {url}") - + try: async with aiohttp.ClientSession() as session: async with session.get(url, headers=headers) as response: # Handle 304 Not Modified - content hasn't changed if response.status == 304: - logger.debug(f"[Async] ETag match! Server returned 304 Not Modified - using cached data (saved bandwidth)") + logger.debug( + f"[Async] ETag match! Server returned 304 Not Modified - using cached data (saved bandwidth)") if cached_data is not None: logger.debug(f"[Async] Returning cached response ({len(str(cached_data))} bytes)") return cached_data else: logger.warning("[Async] Received 304 but no cached data available") return None - + if response.status >= 400: logger.warning("Failed to fetch features, received status code %d", response.status) return None - + decoded = await response.json() - + # Store the new ETag if present response_etag = response.headers.get('ETag') if response_etag: @@ -546,15 +504,16 @@ async def _fetch_and_decode_async(self, api_host: str, client_key: str) -> Optio # Enforce max size if len(self._etag_cache) > self._max_etag_entries: self._etag_cache.popitem(last=False) - + if cached_etag: logger.debug(f"[Async] ETag updated: {cached_etag[:20]}... -> {response_etag[:20]}...") else: - logger.debug(f"[Async] New ETag cached: {response_etag[:20]}... ({len(str(decoded))} bytes)") + logger.debug( + f"[Async] New ETag cached: {response_etag[:20]}... ({len(str(decoded))} bytes)") logger.debug(f"[Async] ETag cache now contains {len(self._etag_cache)} entries") else: logger.debug("[Async] No ETag header in response") - + return decoded # type: ignore[no-any-return] except aiohttp.ClientError as e: logger.warning(f"HTTP request failed: {e}") @@ -562,7 +521,7 @@ async def _fetch_and_decode_async(self, api_host: str, client_key: str) -> Optio except Exception as e: logger.error(f"Failed to decode feature JSON from GrowthBook API: {e}") return None - + def decrypt_response(self, data, decryption_key: str): if "encryptedFeatures" in data: if not decryption_key: @@ -578,7 +537,7 @@ def decrypt_response(self, data, decryption_key: str): return None elif "features" not in data: logger.warning("GrowthBook API response missing features") - + if "encryptedSavedGroups" in data: if not decryption_key: raise ValueError("Must specify decryption_key") @@ -591,7 +550,7 @@ def decrypt_response(self, data, decryption_key: str): logger.warning( "Failed to decrypt saved groups from GrowthBook API response" ) - + return data # Fetch features from the GrowthBook API @@ -605,7 +564,7 @@ def _fetch_features( data = self.decrypt_response(decoded, decryption_key) return data # type: ignore[no-any-return] - + async def _fetch_features_async( self, api_host: str, client_key: str, decryption_key: str = "" ) -> Optional[Dict]: @@ -617,11 +576,11 @@ async def _fetch_features_async( return data # type: ignore[no-any-return] - def startAutoRefresh(self, api_host, client_key, cb, streaming_timeout=30): if not client_key: raise ValueError("Must specify `client_key` to start features streaming") - self.sse_client = self.sse_client or SSEClient(api_host=api_host, client_key=client_key, on_event=cb, timeout=streaming_timeout) + self.sse_client = self.sse_client or SSEClient(api_host=api_host, client_key=client_key, on_event=cb, + timeout=streaming_timeout) self.sse_client.connect() def stopAutoRefresh(self, timeout=10): @@ -629,8 +588,9 @@ def stopAutoRefresh(self, timeout=10): if self.sse_client: self.sse_client.disconnect(timeout=timeout) self.sse_client = None - - def start_background_refresh(self, api_host: str, client_key: str, decryption_key: str, ttl: int = 600, refresh_interval: int = 300) -> None: + + def start_background_refresh(self, api_host: str, client_key: str, decryption_key: str, ttl: int = 600, + refresh_interval: int = 300) -> None: """Start periodic background refresh task""" if not client_key: @@ -639,7 +599,7 @@ def start_background_refresh(self, api_host: str, client_key: str, decryption_ke with self._refresh_lock: if self._refresh_thread is not None: return # Already running - + self._refresh_stop_event.clear() self._refresh_thread = threading.Thread( target=self._background_refresh_worker, @@ -648,15 +608,16 @@ def start_background_refresh(self, api_host: str, client_key: str, decryption_ke ) self._refresh_thread.start() logger.debug("Started background refresh task") - - def _background_refresh_worker(self, api_host: str, client_key: str, decryption_key: str, ttl: int, refresh_interval: int) -> None: + + def _background_refresh_worker(self, api_host: str, client_key: str, decryption_key: str, ttl: int, + refresh_interval: int) -> None: """Worker method for periodic background refresh""" while not self._refresh_stop_event.is_set(): try: # Wait for the refresh interval or stop event if self._refresh_stop_event.wait(refresh_interval): break # Stop event was set - + logger.debug("Background refresh for Features - started") res = self._fetch_features(api_host, client_key, decryption_key) if res is not None: @@ -669,11 +630,11 @@ def _background_refresh_worker(self, api_host: str, client_key: str, decryption_ logger.warning("Background refresh failed") except Exception as e: logger.warning(f"Background refresh error: {e}") - + def stop_background_refresh(self) -> None: """Stop background refresh task""" self._refresh_stop_event.set() - + with self._refresh_lock: if self._refresh_thread is not None: self._refresh_thread.join(timeout=1.0) # Wait up to 1 second @@ -752,7 +713,8 @@ def __init__( self._user = user self._groups = groups self._overrides = overrides - self._forcedVariations = (forced_variations if forced_variations is not None else forcedVariations) if forced_variations is not None or forcedVariations else {} + self._forcedVariations = ( + forced_variations if forced_variations is not None else forcedVariations) if forced_variations is not None or forcedVariations else {} self._tracked: Dict[str, Any] = {} self._assigned: Dict[str, Any] = {} @@ -777,7 +739,7 @@ def __init__( ), features={}, saved_groups=self._saved_groups - ) + ) # Create a user context for the current user self._user_ctx: UserContext = UserContext( url=self._url, @@ -805,7 +767,7 @@ def __init__( # Start background refresh task for stale-while-revalidate self.load_features() # Initial load feature_repo.start_background_refresh( - self._api_host, self._client_key, self._decryption_key, + self._api_host, self._client_key, self._decryption_key, self._cache_ttl, self._stale_ttl ) @@ -849,7 +811,7 @@ def _features_event_handler(self, features): decoded = json.loads(features) if not decoded: return None - + data = feature_repo.decrypt_response(decoded, self._decryption_key) if data is not None: @@ -867,13 +829,12 @@ def _dispatch_sse_event(self, event_data): elif event_type == 'features': self._features_event_handler(data) - def startAutoRefresh(self): if not self._client_key: raise ValueError("Must specify `client_key` to start features streaming") - + feature_repo.startAutoRefresh( - api_host=self._api_host, + api_host=self._api_host, client_key=self._client_key, cb=self._dispatch_sse_event, streaming_timeout=self._streaming_timeout @@ -938,34 +899,34 @@ def get_attributes(self) -> dict: def destroy(self, timeout=10) -> None: """Gracefully destroy the GrowthBook instance""" logger.debug("Starting GrowthBook destroy process") - + try: # Clean up plugins logger.debug("Cleaning up plugins") self._cleanup_plugins() except Exception as e: logger.warning(f"Error cleaning up plugins: {e}") - + try: logger.debug("Stopping auto refresh during destroy") self.stopAutoRefresh(timeout=timeout) except Exception as e: logger.warning(f"Error stopping auto refresh during destroy: {e}") - + try: # Stop background refresh operations if self._stale_while_revalidate and self._client_key: feature_repo.stop_background_refresh() except Exception as e: logger.warning(f"Error stopping background refresh during destroy: {e}") - + try: # Clean up feature update callback if self._client_key: feature_repo.remove_feature_update_callback(self._on_feature_update) except Exception as e: logger.warning(f"Error removing feature update callback: {e}") - + # Clear all internal state try: self._subscriptions.clear() @@ -1007,14 +968,14 @@ def get_feature_value(self, key: str, fallback): def evalFeature(self, key: str) -> FeatureResult: warnings.warn("evalFeature is deprecated, use eval_feature instead", DeprecationWarning) return self.eval_feature(key) - + def _ensure_fresh_features(self) -> None: """Lazy refresh: Check cache expiry and refresh if needed, but only if client_key is provided""" - + # Prevent infinite recursion when updating features (e.g., during sticky bucket refresh) if self._is_updating_features: return - + if self._streaming or self._stale_while_revalidate or not self._client_key: return # Skip cache checks - SSE or background refresh handles freshness @@ -1026,7 +987,7 @@ def _ensure_fresh_features(self) -> None: def _get_eval_context(self) -> EvaluationContext: # Lazy refresh: ensure features are fresh before evaluation self._ensure_fresh_features() - + # use the latest attributes for every evaluation. self._user_ctx.attributes = self._attributes self._user_ctx.url = self._url @@ -1040,8 +1001,8 @@ def _get_eval_context(self) -> EvaluationContext: ) def eval_feature(self, key: str) -> FeatureResult: - result = core_eval_feature(key=key, - evalContext=self._get_eval_context(), + result = core_eval_feature(key=key, + evalContext=self._get_eval_context(), callback_subscription=self._fireSubscriptions, tracking_cb=self._track ) @@ -1080,7 +1041,7 @@ def _fireSubscriptions(self, experiment: Experiment, result: Result): def run(self, experiment: Experiment) -> Result: # result = self._run(experiment) - result = run_experiment(experiment=experiment, + result = run_experiment(experiment=experiment, evalContext=self._get_eval_context(), tracking_cb=self._track ) @@ -1169,7 +1130,7 @@ def _initialize_plugins(self) -> None: def user_agent_suffix(self) -> Optional[str]: """Get the suffix appended to the User-Agent header""" return feature_repo.user_agent_suffix - + @user_agent_suffix.setter def user_agent_suffix(self, value: Optional[str]) -> None: """Set a suffix to be appended to the User-Agent header""" diff --git a/growthbook/growthbook_client.py b/growthbook/growthbook_client.py index 1ec8883..fe0bd6c 100644 --- a/growthbook/growthbook_client.py +++ b/growthbook/growthbook_client.py @@ -10,6 +10,7 @@ import traceback from datetime import datetime from growthbook import FeatureRepository, feature_repo +from growthbook.cache_interfaces import AbstractAsyncFeatureCache from contextlib import asynccontextmanager from .core import eval_feature as core_eval_feature, run_experiment @@ -43,9 +44,9 @@ def __call__(cls, *args, **kwargs): class BackoffStrategy: """Exponential backoff with jitter for failed requests""" def __init__( - self, - initial_delay: float = 1.0, - max_delay: float = 60.0, + self, + initial_delay: float = 1.0, + max_delay: float = 60.0, multiplier: float = 2.0, jitter: float = 0.1 ): @@ -59,7 +60,7 @@ def __init__( def next_delay(self) -> float: """Calculate next delay with jitter""" delay = min( - self.current_delay * (self.multiplier ** self.attempt), + self.current_delay * (self.multiplier ** self.attempt), self.max_delay ) # Add random jitter @@ -122,6 +123,7 @@ def __init__(self, self._callbacks: List[Callable[[Dict[str, Any]], Awaitable[None]]] = [] self._last_successful_refresh: Optional[datetime] = None self._refresh_in_progress = asyncio.Lock() + self.async_cache: Optional[AbstractAsyncFeatureCache] = None self.http_connect_timeout = http_connect_timeout self.http_read_timeout = http_read_timeout @@ -178,7 +180,7 @@ def remove_callback(self, callback: Callable[[Dict[str, Any]], Awaitable[None]]) self._callbacks.remove(callback) """ - _start_sse_refresh flow mimics a bridge pattern to connect a blocking, synchronous background thread + _start_sse_refresh flow mimics a bridge pattern to connect a blocking, synchronous background thread (the SSEClient) with your non-blocking, async main loop. Bridge - _maintain_sse_connection - runs on the main async loop, calls `startAutoRefresh` (which in turn spawns a thread) @@ -186,7 +188,7 @@ def remove_callback(self, callback: Callable[[Dict[str, Any]], Awaitable[None]]) The SSEClient runs in a separate thread, makes a blocking HTTP request, and invokes `on_event` synchronously. - The Hand off - when the event arrives (we're still on the background thread), sse_handler uses `asyncio.run_coroutine_threadsafe` + The Hand off - when the event arrives (we're still on the background thread), sse_handler uses `asyncio.run_coroutine_threadsafe` to schedule the async processing `_handle_sse_event` onto the main event loop. """ @@ -232,7 +234,7 @@ async def _maintain_sse_connection() -> None: try: # NOTE: `startAutoRefresh` is synchronous and starts a background thread. self.startAutoRefresh(self._api_host, self._client_key, sse_handler) - + # Wait indefinitely until the task is cancelled - basically saying "Keep this service 'active' until someone cancels me." # reconnection logic is handled inside SSEClient's thread await asyncio.Future() @@ -246,7 +248,7 @@ async def _maintain_sse_connection() -> None: # stopAutoRefresh blocks joining a thread, so it needs to be run in executor # to avoid blocking the async event loop await main_loop.run_in_executor( - None, + None, lambda: self.stopAutoRefresh(timeout=10) ) except Exception: @@ -303,7 +305,7 @@ async def refresh_loop() -> None: async def start_feature_refresh(self, strategy: FeatureRefreshStrategy, callback=None): """Initialize feature refresh based on strategy""" self._refresh_callback = callback - + if strategy == FeatureRefreshStrategy.SERVER_SENT_EVENTS: await self._start_sse_refresh() else: @@ -338,7 +340,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self.stop_refresh() - + async def load_features_async( self, api_host: str, client_key: str, decryption_key: str = "", ttl: int = 60 ) -> Optional[Dict]: @@ -346,23 +348,43 @@ async def load_features_async( if api_host == self._api_host and client_key == self._client_key: decryption_key = self._decryption_key ttl = self._cache_ttl - return await super().load_features_async(api_host, client_key, decryption_key, ttl) + + key = api_host + "::" + client_key + + if self.async_cache: + cached = await self.async_cache.get(key) + if cached: + return cached + + res = await super().load_features_async(api_host, client_key, decryption_key, ttl) + + if res is not None and self.async_cache: + await self.async_cache.set(key, res, ttl) + + return res + + def set_async_cache(self, cache: AbstractAsyncFeatureCache) -> None: + """ + Set asynchronous cache implementation. + When set, load_features_async() will use this instead of sync cache. + """ + self.async_cache = cache class GrowthBookClient: def __init__( self, options: Optional[Union[Dict[str, Any], Options]] = None - ): + ): self.options = ( options if isinstance(options, Options) else Options(**options) if options else Options() ) - + # Thread-safe tracking state self._tracked: Dict[str, bool] = {} # Access only within async context self._tracked_lock = threading.Lock() - + # Thread-safe subscription management self._subscriptions: Set[Callable[[Experiment, Result], None]] = set() self._subscriptions_lock = threading.Lock() @@ -373,7 +395,7 @@ def __init__( 'assignments': {} } self._sticky_bucket_cache_lock = False - + # Plugin support self._tracking_plugins: List[Any] = self.options.tracking_plugins or [] self._initialized_plugins: List[Any] = [] @@ -390,10 +412,22 @@ def __init__( if self.options.client_key else None ) - + + # Check if repo was initialized + if self._features_repository is not None: + # 1. set sync cache + if self.options.cache is not None: + self._features_repository.set_cache(self.options.cache) + logger.debug("Custom sync cache set for FeatureRepository.") + + # 2. set async cache + if self.options.async_cache is not None: + self._features_repository.set_async_cache(self.options.async_cache) + logger.debug("Custom async cache set for FeatureRepository.") + self._global_context: Optional[GlobalContext] = None self._context_lock = asyncio.Lock() - + # Initialize plugins self._initialize_plugins() @@ -442,8 +476,8 @@ def _fire_subscriptions(self, experiment: Experiment, result: Result) -> None: async def set_features(self, features: dict) -> None: await self._feature_update_callback({"features": features}) - - + + async def _refresh_sticky_buckets(self, attributes: Dict[str, Any]) -> Dict[str, Any]: """Refresh sticky bucket assignments only if attributes have changed""" if not self.options.sticky_bucket_service: @@ -453,7 +487,7 @@ async def _refresh_sticky_buckets(self, attributes: Dict[str, Any]) -> Dict[str, while not self._sticky_bucket_cache_lock: if attributes == self._sticky_bucket_cache['attributes']: return self._sticky_bucket_cache['assignments'] - + self._sticky_bucket_cache_lock = True try: assignments = self.options.sticky_bucket_service.get_all_assignments(attributes) @@ -462,7 +496,7 @@ async def _refresh_sticky_buckets(self, attributes: Dict[str, Any]) -> Dict[str, return assignments finally: self._sticky_bucket_cache_lock = False - + # Fallback return for edge case where loop condition is never satisfied return {} @@ -475,9 +509,9 @@ async def initialize(self) -> bool: try: # Initial feature load initial_features = await self._features_repository.load_features_async( - self.options.api_host or "https://cdn.growthbook.io", - self.options.client_key or "", - self.options.decryption_key or "", + self.options.api_host or "https://cdn.growthbook.io", + self.options.client_key or "", + self.options.decryption_key or "", self.options.cache_ttl ) if not initial_features: @@ -486,15 +520,15 @@ async def initialize(self) -> bool: # Create global context with initial features await self._feature_update_callback(initial_features) - + # Set up callback for future updates self._features_repository.add_callback(self._feature_update_callback) - + # Start feature refresh refresh_strategy = self.options.refresh_strategy or FeatureRefreshStrategy.STALE_WHILE_REVALIDATE await self._features_repository.start_feature_refresh(refresh_strategy) return True - + except Exception as e: logger.error(f"Initialization failed: {str(e)}", exc_info=True) traceback.print_exc() @@ -541,10 +575,10 @@ async def create_evaluation_context(self, user_context: UserContext) -> Evaluati """Create evaluation context for feature evaluation""" if self._global_context is None: raise RuntimeError("GrowthBook client not properly initialized") - + # Get sticky bucket assignments if needed sticky_assignments = await self._refresh_sticky_buckets(user_context.attributes) - + # update user context with sticky bucket assignments user_context.sticky_bucket_assignment_docs = sticky_assignments @@ -579,7 +613,7 @@ async def is_on(self, key: str, user_context: UserContext) -> bool: except Exception: logger.exception("Error in feature usage callback") return result.on - + async def is_off(self, key: str, user_context: UserContext) -> bool: """Check if a feature is set to off with proper async context management""" async with self._context_lock: @@ -592,7 +626,7 @@ async def is_off(self, key: str, user_context: UserContext) -> bool: except Exception: logger.exception("Error in feature usage callback") return result.off - + async def get_feature_value(self, key: str, fallback: Any, user_context: UserContext) -> Any: async with self._context_lock: context = await self.create_evaluation_context(user_context) @@ -610,14 +644,14 @@ async def run(self, experiment: Experiment, user_context: UserContext) -> Result async with self._context_lock: context = await self.create_evaluation_context(user_context) result = run_experiment( - experiment=experiment, + experiment=experiment, evalContext=context, tracking_cb=self._track ) # Fire subscriptions synchronously self._fire_subscriptions(experiment, result) return result - + async def close(self) -> None: """Clean shutdown with proper cleanup""" if self._features_repository: @@ -631,7 +665,7 @@ async def close(self) -> None: # Clear context async with self._context_lock: self._global_context = None - + # Cleanup plugins self._cleanup_plugins() @@ -639,7 +673,7 @@ async def close(self) -> None: def user_agent_suffix(self) -> Optional[str]: """Get the suffix appended to the User-Agent header""" return feature_repo.user_agent_suffix - + @user_agent_suffix.setter def user_agent_suffix(self, value: Optional[str]) -> None: """Set a suffix to be appended to the User-Agent header""" @@ -673,4 +707,4 @@ def _cleanup_plugins(self) -> None: logger.debug(f"Cleaned up plugin: {plugin.__class__.__name__}") except Exception as e: logger.error(f"Error cleaning up plugin {plugin}: {e}") - self._initialized_plugins.clear() \ No newline at end of file + self._initialized_plugins.clear() diff --git a/tests/test_growthbook.py b/tests/test_growthbook.py index f06bd6a..9ed8f51 100644 --- a/tests/test_growthbook.py +++ b/tests/test_growthbook.py @@ -3,7 +3,6 @@ import json import os from growthbook import ( - FeatureRule, GrowthBook, Experiment, Feature, @@ -12,6 +11,7 @@ feature_repo, logger, ) +from growthbook.common_types import FeatureRule from growthbook.core import ( getBucketRanges, @@ -160,7 +160,7 @@ def test_stickyBucket(stickyBucket_data): gb = GrowthBook(**ctx) res = gb.eval_feature(key) - + if not res.experimentResult: assert None == expected_result else: @@ -218,10 +218,10 @@ def test_tracking(): def test_feature_usage_callback(): """Test that feature usage callback is called correctly""" calls = [] - + def feature_usage_cb(key, result, user_context): calls.append([key, result, user_context]) - + gb = GrowthBook( attributes={"id": "1"}, on_feature_usage=feature_usage_cb, @@ -236,7 +236,7 @@ def feature_usage_cb(key, result, user_context): ), } ) - + # Test eval_feature result1 = gb.eval_feature("feature-1") assert len(calls) == 1 @@ -244,14 +244,14 @@ def feature_usage_cb(key, result, user_context): assert calls[0][1].value is True assert calls[0][1].source == "defaultValue" assert calls[0][2].attributes == {"id": "1"} - + # Test is_on gb.is_on("feature-2") assert len(calls) == 2 assert calls[1][0] == "feature-2" assert calls[1][1].value is False assert calls[1][2].attributes == {"id": "1"} - + # Test get_feature_value value = gb.get_feature_value("feature-3", "blue") assert len(calls) == 3 @@ -259,27 +259,27 @@ def feature_usage_cb(key, result, user_context): assert calls[2][1].value == "red" assert value == "red" assert calls[2][2].attributes == {"id": "1"} - + # Test is_off gb.is_off("feature-1") assert len(calls) == 4 assert calls[3][0] == "feature-1" assert calls[3][2].attributes == {"id": "1"} - + # Calling same feature multiple times should trigger callback each time gb.eval_feature("feature-1") gb.eval_feature("feature-1") assert len(calls) == 6 - + gb.destroy() def test_feature_usage_callback_error_handling(): """Test that feature usage callback errors are handled gracefully""" - + def failing_callback(key, result, user_context): raise Exception("Callback error") - + gb = GrowthBook( attributes={"id": "1"}, on_feature_usage=failing_callback, @@ -287,14 +287,14 @@ def failing_callback(key, result, user_context): "feature-1": Feature(defaultValue=True), } ) - + # Should not raise an error even if callback fails result = gb.eval_feature("feature-1") assert result.value is True - + # Should work with is_on as well assert gb.is_on("feature-1") is True - + gb.destroy() @@ -324,7 +324,7 @@ def test_handles_weird_experiment_values(): def test_skip_all_experiments_flag(): """Test that skip_all_experiments flag prevents users from being put into experiments""" - + # Test with skip_all_experiments=True gb_skip = GrowthBook( attributes={"id": "1"}, @@ -342,22 +342,22 @@ def test_skip_all_experiments_flag(): ) } ) - + # User should NOT be in experiment due to skip_all_experiments flag result = gb_skip.eval_feature("feature-with-experiment") assert result.value == "control" # Should get default value assert result.source == "defaultValue" assert result.experiment is None # No experiment should be assigned assert result.experimentResult is None - + # Test running experiment directly exp = Experiment(key="direct-exp", variations=["a", "b"]) exp_result = gb_skip.run(exp) assert exp_result.inExperiment is False assert exp_result.value == "a" # Should get first variation (control) - + gb_skip.destroy() - + # Test with skip_all_experiments=False (default behavior) gb_normal = GrowthBook( attributes={"id": "1"}, @@ -375,13 +375,13 @@ def test_skip_all_experiments_flag(): ) } ) - + # User SHOULD be in experiment normally result_normal = gb_normal.eval_feature("feature-with-experiment") # With id="1", this user should be assigned a variation assert result_normal.value in ["control", "variation"] assert result_normal.source == "experiment" - + gb_normal.destroy() def test_force_variation(): @@ -1087,39 +1087,39 @@ def test_ttl_automatic_feature_refresh(mocker): {"features": {"test_feature": {"defaultValue": False}}, "savedGroups": {}}, {"features": {"test_feature": {"defaultValue": True}}, "savedGroups": {}} ] - + call_count = 0 def mock_fetch_features(api_host, client_key, decryption_key=""): nonlocal call_count response = mock_responses[min(call_count, len(mock_responses) - 1)] call_count += 1 return response - + # Clear cache and mock the fetch method feature_repo.clear_cache() m = mocker.patch.object(feature_repo, '_fetch_features', side_effect=mock_fetch_features) - + # Create GrowthBook instance with short TTL gb = GrowthBook( api_host="https://cdn.growthbook.io", client_key="test-key", cache_ttl=1 # 1 second TTL for testing ) - + try: # Initial evaluation - should trigger first load assert gb.is_on('test_feature') == False assert call_count == 1 - + # Manually expire the cache by setting expiry time to past cache_key = "https://cdn.growthbook.io::test-key" if hasattr(feature_repo.cache, 'cache') and cache_key in feature_repo.cache.cache: feature_repo.cache.cache[cache_key].expires = time() - 10 - + # Next evaluation should automatically refresh cache and update features assert gb.is_on('test_feature') == True assert call_count == 2 - + finally: gb.destroy() feature_repo.clear_cache() @@ -1131,42 +1131,42 @@ def test_multiple_instances_get_updated_on_cache_expiry(mocker): {"features": {"test_feature": {"defaultValue": "v1"}}, "savedGroups": {}}, {"features": {"test_feature": {"defaultValue": "v2"}}, "savedGroups": {}} ] - + call_count = 0 def mock_fetch_features(api_host, client_key, decryption_key=""): nonlocal call_count response = mock_responses[min(call_count, len(mock_responses) - 1)] call_count += 1 return response - + feature_repo.clear_cache() m = mocker.patch.object(feature_repo, '_fetch_features', side_effect=mock_fetch_features) - + # Create multiple GrowthBook instances gb1 = GrowthBook(api_host="https://cdn.growthbook.io", client_key="test-key") gb2 = GrowthBook(api_host="https://cdn.growthbook.io", client_key="test-key") - + try: # Initial evaluation from first instance - should trigger first load assert gb1.get_feature_value('test_feature', 'default') == "v1" assert call_count == 1 - + # Second instance should use cached value (no additional API call) assert gb2.get_feature_value('test_feature', 'default') == "v1" assert call_count == 1 # Still 1, used cache - + # Manually expire the cache cache_key = "https://cdn.growthbook.io::test-key" if hasattr(feature_repo.cache, 'cache') and cache_key in feature_repo.cache.cache: feature_repo.cache.cache[cache_key].expires = time() - 10 - + # Next evaluation should automatically refresh and notify both instances via callbacks assert gb1.get_feature_value('test_feature', 'default') == "v2" assert call_count == 2 - + # Second instance should also have the updated value due to callbacks assert gb2.get_feature_value('test_feature', 'default') == "v2" - + finally: gb1.destroy() gb2.destroy() @@ -1180,18 +1180,18 @@ def test_stale_while_revalidate_basic_functionality(mocker): {"features": {"test_feature": {"defaultValue": "v1"}}, "savedGroups": {}}, {"features": {"test_feature": {"defaultValue": "v2"}}, "savedGroups": {}} ] - + call_count = 0 def mock_fetch_features(api_host, client_key, decryption_key=""): nonlocal call_count response = mock_responses[min(call_count, len(mock_responses) - 1)] call_count += 1 return response - + # Clear cache and mock the fetch method feature_repo.clear_cache() m = mocker.patch.object(feature_repo, '_fetch_features', side_effect=mock_fetch_features) - + # Create GrowthBook instance with stale-while-revalidate enabled and short refresh interval gb = GrowthBook( api_host="https://cdn.growthbook.io", @@ -1200,22 +1200,22 @@ def mock_fetch_features(api_host, client_key, decryption_key=""): stale_while_revalidate=True, stale_ttl=1 # 1 second refresh interval for testing ) - + try: # Initial evaluation - should use initial loaded data assert gb.get_feature_value('test_feature', 'default') == "v1" assert call_count == 1 # Initial load - + # Wait for background refresh to happen import time as time_module time_module.sleep(1.5) # Wait longer than refresh interval - + # Should have triggered background refresh assert call_count >= 2 - + # Next evaluation should get updated data from background refresh assert gb.get_feature_value('test_feature', 'default') == "v2" - + finally: gb.destroy() feature_repo.clear_cache() @@ -1224,17 +1224,17 @@ def mock_fetch_features(api_host, client_key, decryption_key=""): def test_stale_while_revalidate_starts_background_task(mocker): """Test that stale-while-revalidate starts background refresh task""" mock_response = {"features": {"test_feature": {"defaultValue": "fresh"}}, "savedGroups": {}} - + call_count = 0 def mock_fetch_features(api_host, client_key, decryption_key=""): nonlocal call_count call_count += 1 return mock_response - + # Clear cache and mock the fetch method feature_repo.clear_cache() m = mocker.patch.object(feature_repo, '_fetch_features', side_effect=mock_fetch_features) - + # Create GrowthBook instance with stale-while-revalidate enabled gb = GrowthBook( api_host="https://cdn.growthbook.io", @@ -1242,16 +1242,16 @@ def mock_fetch_features(api_host, client_key, decryption_key=""): stale_while_revalidate=True, stale_ttl=5 ) - + try: # Should have started background refresh task assert feature_repo._refresh_thread is not None assert feature_repo._refresh_thread.is_alive() - + # Initial evaluation should work assert gb.get_feature_value('test_feature', 'default') == "fresh" assert call_count == 1 # Initial load - + finally: gb.destroy() feature_repo.clear_cache() @@ -1259,17 +1259,17 @@ def mock_fetch_features(api_host, client_key, decryption_key=""): def test_stale_while_revalidate_disabled_fallback(mocker): """Test that when stale_while_revalidate is disabled, it falls back to normal behavior""" mock_response = {"features": {"test_feature": {"defaultValue": "normal"}}, "savedGroups": {}} - + call_count = 0 def mock_fetch_features(api_host, client_key, decryption_key=""): nonlocal call_count call_count += 1 return mock_response - + # Clear cache and mock the fetch method feature_repo.clear_cache() m = mocker.patch.object(feature_repo, '_fetch_features', side_effect=mock_fetch_features) - + # Create GrowthBook instance with stale-while-revalidate disabled (default) gb = GrowthBook( api_host="https://cdn.growthbook.io", @@ -1277,24 +1277,24 @@ def mock_fetch_features(api_host, client_key, decryption_key=""): cache_ttl=1, # Short TTL stale_while_revalidate=False # Explicitly disabled ) - + try: # Should NOT have started background refresh task assert feature_repo._refresh_thread is None - + # Initial evaluation assert gb.get_feature_value('test_feature', 'default') == "normal" assert call_count == 1 - + # Manually expire the cache cache_key = "https://cdn.growthbook.io::test-key" if hasattr(feature_repo.cache, 'cache') and cache_key in feature_repo.cache.cache: feature_repo.cache.cache[cache_key].expires = time() - 10 - + # Next evaluation should fetch synchronously (normal behavior) assert gb.get_feature_value('test_feature', 'default') == "normal" assert call_count == 2 # Should have fetched again - + finally: gb.destroy() feature_repo.clear_cache() @@ -1303,31 +1303,31 @@ def mock_fetch_features(api_host, client_key, decryption_key=""): def test_stale_while_revalidate_cleanup(mocker): """Test that background refresh is properly cleaned up""" mock_response = {"features": {"test_feature": {"defaultValue": "test"}}, "savedGroups": {}} - + # Mock the fetch method feature_repo.clear_cache() m = mocker.patch.object(feature_repo, '_fetch_features', return_value=mock_response) - + # Create GrowthBook instance with stale-while-revalidate enabled gb = GrowthBook( api_host="https://cdn.growthbook.io", client_key="test-key", stale_while_revalidate=True ) - + try: # Should have started background refresh task assert feature_repo._refresh_thread is not None assert feature_repo._refresh_thread.is_alive() - + # Destroy should clean up the background task gb.destroy() - + # Background task should be stopped assert feature_repo._refresh_thread is None or not feature_repo._refresh_thread.is_alive() - + finally: # Ensure cleanup even if test fails if feature_repo._refresh_thread: feature_repo.stop_background_refresh() - feature_repo.clear_cache() \ No newline at end of file + feature_repo.clear_cache() diff --git a/tests/test_growthbook_client.py b/tests/test_growthbook_client.py index c88b024..60db77f 100644 --- a/tests/test_growthbook_client.py +++ b/tests/test_growthbook_client.py @@ -12,7 +12,7 @@ class AsyncMock(MagicMock): async def __call__(self, *args, **kwargs): return super(AsyncMock, self).__call__(*args, **kwargs) -from growthbook import InMemoryStickyBucketService +from growthbook import InMemoryStickyBucketService, AbstractAsyncFeatureCache import pytest import asyncio import os @@ -20,7 +20,7 @@ async def __call__(self, *args, **kwargs): from growthbook.common_types import Experiment, Options from growthbook.growthbook_client import ( - GrowthBookClient, + GrowthBookClient, UserContext, FeatureRefreshStrategy, EnhancedFeatureRepository @@ -1047,6 +1047,75 @@ async def test_skip_all_experiments_flag(): # User should be assigned to a variation assert result_normal.value in ["control", "variation"] assert result_normal.source == "experiment" - + finally: - await client.close() \ No newline at end of file + await client.close() + + +@pytest.fixture +def mock_async_cache(): + """A mock AbstractAsyncFeatureCache for testing""" + cache = AsyncMock(spec=AbstractAsyncFeatureCache) + cache.get = AsyncMock(return_value=None) + cache.set = AsyncMock(return_value=None) + return cache + + +@pytest.mark.asyncio +async def test_async_cache_hit_skips_http(mock_async_cache, mock_features_response): + """Cache HIT: async_cache.get() returns data, no HTTP call should be made""" + mock_async_cache.get.return_value = mock_features_response + + repo = EnhancedFeatureRepository( + api_host="https://test.growthbook.io", + client_key="test_key" + ) + repo.set_async_cache(mock_async_cache) + + with patch('growthbook.FeatureRepository.load_features_async', new_callable=AsyncMock) as mock_super: + result = await repo.load_features_async("https://test.growthbook.io", "test_key") + + assert result == mock_features_response + mock_async_cache.get.assert_called_once_with("https://test.growthbook.io::test_key") + mock_super.assert_not_called() + + +@pytest.mark.asyncio +async def test_async_cache_miss_fetches_and_populates(mock_async_cache, mock_features_response): + """Cache MISS: async_cache.get() returns None, HTTP fetch happens and result is stored""" + mock_async_cache.get.return_value = None + + repo = EnhancedFeatureRepository( + api_host="https://test.growthbook.io", + client_key="test_key" + ) + repo.set_async_cache(mock_async_cache) + + with patch('growthbook.FeatureRepository.load_features_async', + new_callable=AsyncMock, return_value=mock_features_response): + result = await repo.load_features_async("https://test.growthbook.io", "test_key") + + assert result == mock_features_response + mock_async_cache.set.assert_called_once() + + +@pytest.mark.asyncio +async def test_async_cache_ttl_passed_correctly(mock_async_cache, mock_features_response): + """TTL from cache_ttl is passed to async_cache.set()""" + mock_async_cache.get.return_value = None + expected_ttl = 120 + + repo = EnhancedFeatureRepository( + api_host="https://test.growthbook.io", + client_key="test_key", + cache_ttl=expected_ttl + ) + repo.set_async_cache(mock_async_cache) + + with patch('growthbook.FeatureRepository.load_features_async', + new_callable=AsyncMock, return_value=mock_features_response): + await repo.load_features_async("https://test.growthbook.io", "test_key") + + _, call_kwargs = mock_async_cache.set.call_args + call_args = mock_async_cache.set.call_args[0] + assert call_args[2] == expected_ttl \ No newline at end of file