Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .bumpversion.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# https://peps.python.org/pep-0440/

[tool.bumpversion]
current_version = "0.4.4"
current_version = "0.4.5"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"grpcio-status==1.78.0",
"pydantic==2.12.5",
]
version = "0.4.4"
version = "0.4.5"

[project.optional-dependencies]
profiling = [
Expand Down
2 changes: 1 addition & 1 deletion src/digitalkin/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
try:
__version__ = version("digitalkin")
except PackageNotFoundError:
__version__ = "0.4.4"
__version__ = "0.4.5"
19 changes: 13 additions & 6 deletions src/digitalkin/grpc_servers/module_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import os
import time
from argparse import ArgumentParser, Namespace
from collections.abc import AsyncGenerator
from typing import Any, cast
Expand Down Expand Up @@ -86,8 +87,9 @@ def __init__(self, module_class: type[BaseModule]) -> None:
extra={"job_manager": self.job_manager},
)
self.setup = GrpcSetup() if self.args.services_mode == ServicesMode.REMOTE else DefaultSetup()
self._setup_cache: dict[str, SetupVersionData] = {}
self._setup_cache: dict[str, tuple[float, SetupVersionData]] = {}
self._setup_cache_max = int(os.environ.get("DIGITALKIN_SETUP_CACHE_MAX", "100"))
self._setup_cache_ttl = float(os.environ.get("DIGITALKIN_SETUP_CACHE_TTL", "600.0"))
self._setup_inflight: dict[str, asyncio.Future[SetupVersionData]] = {}
self._completion_timeout = float(os.environ.get("DIGITALKIN_COMPLETION_TIMEOUT", "300.0"))

Expand Down Expand Up @@ -129,11 +131,11 @@ def _get_registry(self) -> RegistryStrategy | None:
return self._registry_cache

def _cache_setup(self, setup_id: str, version_data: SetupVersionData) -> None:
"""Cache setup version data, evicting oldest entry if at capacity."""
"""Cache setup version data with fetch timestamp, evicting oldest entry if at capacity."""
if len(self._setup_cache) >= self._setup_cache_max:
oldest_key = next(iter(self._setup_cache))
del self._setup_cache[oldest_key]
self._setup_cache[setup_id] = version_data
self._setup_cache[setup_id] = (time.monotonic(), version_data)

async def _resolve_setup(self, setup_id: str, mission_id: str) -> SetupVersionData:
"""Return setup version data from cache or remote service.
Expand All @@ -151,10 +153,13 @@ async def _resolve_setup(self, setup_id: str, mission_id: str) -> SetupVersionDa
ServerError: gRPC communication failed.
ValidationError: Setup data failed validation.
"""
# Fast path: cache hit
# Fast path: cache hit within TTL
if (cached := self._setup_cache.get(setup_id)) is not None:
logger.debug("debug:_resolve_setup cache hit setup_id=%s", setup_id)
return cached
if time.monotonic() - cached[0] < self._setup_cache_ttl:
logger.debug("debug:_resolve_setup cache hit setup_id=%s", setup_id)
return cached[1]
del self._setup_cache[setup_id]
logger.debug("debug:_resolve_setup cache expired setup_id=%s", setup_id)

# Coalesce concurrent misses: first caller fetches, others await the same future
if setup_id in self._setup_inflight:
Expand Down Expand Up @@ -220,6 +225,8 @@ async def ConfigSetupModule(
},
)
setup_version = request.setup_version
# Invalidate cached setup so concurrent/subsequent starts refetch the reconfigured version
self._setup_cache.pop(setup_version.setup_id, None)
config_setup_data = self.module_class.create_config_setup_model(json_format.MessageToDict(request.content))
setup_version_data = await self.module_class.create_setup_model(
json_format.MessageToDict(request.setup_version.content),
Expand Down
66 changes: 66 additions & 0 deletions tests/grpc_server/test_module_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import asyncio
import time
from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
Expand All @@ -22,6 +23,7 @@
from digitalkin.core.job_manager.base_job_manager import BaseJobManager
from digitalkin.grpc_servers.module_servicer import ModuleServicer
from digitalkin.modules._base_module import BaseModule
from digitalkin.services.setup.setup_strategy import SetupVersionData
from tests.fixtures.grpc_fixtures import FakeContext


Expand Down Expand Up @@ -116,6 +118,7 @@ def module_servicer(mock_job_manager, mock_setup_strategy):
servicer.setup = mock_setup_strategy
servicer._setup_cache = {}
servicer._setup_cache_max = 100
servicer._setup_cache_ttl = 60.0
servicer._setup_inflight: dict[str, asyncio.Future] = {}
servicer._completion_timeout = 300.0

Expand Down Expand Up @@ -281,6 +284,69 @@ async def mock_stream_with_exception() -> AsyncGenerator[dict[str, Any], None]:
pass


class TestSetupCache:
"""Tests for setup cache TTL and invalidation behavior."""

@pytest.mark.asyncio
async def test_resolve_setup_cache_hit_within_ttl(self, module_servicer):
"""Test cache hit within TTL does not call the setup service."""
cached_version = SetupVersionData.model_construct(
id="version-123",
setup_id="setup-123",
content={"test": "setup"},
)
module_servicer._setup_cache["setup-123"] = (time.monotonic(), cached_version)

result = await module_servicer._resolve_setup("setup-123", "mission-456")

assert result is cached_version
module_servicer.setup.get_setup.assert_not_called()

@pytest.mark.asyncio
async def test_resolve_setup_cache_expired_refetches(self, module_servicer):
"""Test expired cache entry triggers a refetch and is replaced."""
stale_version = SetupVersionData.model_construct(
id="old-version",
setup_id="setup-123",
content={"old": "setup"},
)
module_servicer._setup_cache["setup-123"] = (time.monotonic() - 120.0, stale_version)

result = await module_servicer._resolve_setup("setup-123", "mission-456")

module_servicer.setup.get_setup.assert_called_once_with({"setup_id": "setup-123", "mission_id": "mission-456"})
assert result.id == "version-123"
assert module_servicer._setup_cache["setup-123"][1] is result

@pytest.mark.asyncio
async def test_config_setup_module_invalidates_cache(self, module_servicer, fake_context):
"""Test ConfigSetupModule removes the stale entry and re-caches updated content."""
stale_version = SetupVersionData.model_construct(
id="old-version",
setup_id="setup-123",
content={"old": "setup"},
)
module_servicer._setup_cache["setup-123"] = (time.monotonic(), stale_version)

setup_version = setup_pb2.SetupVersion(
id="version-123",
setup_id="setup-123",
content=json_format.ParseDict({"existing": "config"}, struct_pb2.Struct()),
)
request = lifecycle_pb2.ConfigSetupModuleRequest(
mission_id="mission-456",
setup_version=setup_version,
content=json_format.ParseDict({"new": "config"}, struct_pb2.Struct()),
)

response = await module_servicer.ConfigSetupModule(request, fake_context)

assert response.success is True
cached = module_servicer._setup_cache["setup-123"][1]
assert cached.id == "version-123"
assert cached.content == {"updated": "config"}


class TestStopModule:
"""Tests for StopModule endpoint."""

Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading