|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import logging |
| 6 | +import time |
| 7 | +from collections import defaultdict |
| 8 | +from collections.abc import Callable |
| 9 | +from dataclasses import dataclass |
6 | 10 | from typing import Any, Protocol |
7 | 11 |
|
8 | 12 | from .enums import SafetyClass, SensitivityTag |
9 | | -from .errors import PolicyDenied |
| 13 | +from .errors import AgentKernelError, PolicyDenied |
10 | 14 | from .models import Capability, CapabilityRequest, PolicyDecision, Principal |
11 | 15 |
|
12 | 16 | logger = logging.getLogger(__name__) |
|
18 | 22 | _MAX_ROWS_USER = 50 |
19 | 23 | _MAX_ROWS_SERVICE = 500 |
20 | 24 |
|
| 25 | +# Default rate limits per safety class: (invocations, window_seconds). |
| 26 | +_DEFAULT_RATE_LIMITS: dict[SafetyClass, tuple[int, float]] = { |
| 27 | + SafetyClass.READ: (60, 60.0), |
| 28 | + SafetyClass.WRITE: (10, 60.0), |
| 29 | + SafetyClass.DESTRUCTIVE: (2, 60.0), |
| 30 | +} |
| 31 | + |
| 32 | +# Service role multiplier for rate limits. |
| 33 | +_SERVICE_RATE_MULTIPLIER = 10 |
| 34 | + |
| 35 | + |
| 36 | +@dataclass(slots=True) |
| 37 | +class _RateEntry: |
| 38 | + """Timestamps for a single rate-limit key.""" |
| 39 | + |
| 40 | + timestamps: list[float] |
| 41 | + |
| 42 | + |
| 43 | +class RateLimiter: |
| 44 | + """Sliding-window rate limiter using monotonic clock. |
| 45 | +
|
| 46 | + Args: |
| 47 | + clock: Callable returning the current time in seconds. |
| 48 | + Defaults to :func:`time.monotonic`. |
| 49 | + """ |
| 50 | + |
| 51 | + def __init__(self, clock: Callable[[], float] | None = None) -> None: |
| 52 | + self._clock = clock or time.monotonic |
| 53 | + self._windows: dict[str, _RateEntry] = defaultdict(lambda: _RateEntry(timestamps=[])) |
| 54 | + |
| 55 | + def check(self, key: str, limit: int, window_seconds: float) -> bool: |
| 56 | + """Return ``True`` if the next invocation would be within the limit. |
| 57 | +
|
| 58 | + Prunes expired timestamps as a side-effect. |
| 59 | +
|
| 60 | + Args: |
| 61 | + key: Rate-limit key (e.g. ``"principal:capability"``). |
| 62 | + limit: Maximum allowed invocations per window. |
| 63 | + window_seconds: Sliding window duration in seconds. |
| 64 | +
|
| 65 | + Returns: |
| 66 | + ``True`` if under limit, ``False`` if limit would be exceeded. |
| 67 | + """ |
| 68 | + now = self._clock() |
| 69 | + cutoff = now - window_seconds |
| 70 | + entry = self._windows[key] |
| 71 | + entry.timestamps = [t for t in entry.timestamps if t > cutoff] |
| 72 | + if not entry.timestamps: |
| 73 | + del self._windows[key] |
| 74 | + return True |
| 75 | + return len(entry.timestamps) < limit |
| 76 | + |
| 77 | + def record(self, key: str) -> None: |
| 78 | + """Record an invocation for *key*. |
| 79 | +
|
| 80 | + Args: |
| 81 | + key: Rate-limit key. |
| 82 | + """ |
| 83 | + self._windows[key].timestamps.append(self._clock()) |
| 84 | + |
21 | 85 |
|
22 | 86 | class PolicyEngine(Protocol): |
23 | 87 | """Interface for a policy engine.""" |
@@ -61,8 +125,40 @@ class DefaultPolicyEngine: |
61 | 125 | ``"secrets_reader"`` and a justification of at least 15 characters. |
62 | 126 | 6. **max_rows** — 50 for regular users; 500 for principals with the |
63 | 127 | ``"service"`` role. |
| 128 | + 7. **Rate limiting** — sliding-window rate limit per |
| 129 | + ``(principal_id, capability_id)`` pair, with defaults by safety class. |
| 130 | + Principals with the ``"service"`` role get 10× the default limits. |
64 | 131 | """ |
65 | 132 |
|
def __init__(
    self,
    *,
    rate_limits: dict[SafetyClass, tuple[int, float]] | None = None,
    clock: Callable[[], float] | None = None,
) -> None:
    """Build the engine's rate-limit table and its limiter.

    Args:
        rate_limits: Optional per-safety-class overrides, each given as
            ``(max_invocations, window_seconds)``.  Overrides are layered
            on top of the module defaults, so safety classes that are not
            mentioned keep their default limit.
        clock: Time source handed to the rate limiter.  ``None`` leaves
            the choice of clock to the limiter.

    Raises:
        AgentKernelError: If any resulting limit is below 1 or any
            window is not greater than 0.
    """
    merged = _DEFAULT_RATE_LIMITS.copy()
    if rate_limits is not None:
        merged.update(rate_limits)

    # Validate the merged table, so defaults and overrides are held to
    # the same rule before anything is stored on the instance.
    for safety_class, (max_calls, window_s) in merged.items():
        is_invalid = max_calls < 1 or window_s <= 0
        if is_invalid:
            raise AgentKernelError(
                f"Invalid rate limit for {safety_class.value}: "
                f"limit must be >= 1 and window must be > 0, "
                f"got limit={max_calls}, window={window_s}."
            )

    self._rate_limits = merged
    self._limiter = RateLimiter(clock=clock)
| 161 | + |
66 | 162 | @staticmethod |
67 | 163 | def _deny(reason: str, *, principal_id: str, capability_id: str) -> PolicyDenied: |
68 | 164 | """Log a policy denial at WARNING and return the exception to raise.""" |
@@ -197,6 +293,22 @@ def evaluate( |
197 | 293 | else: |
198 | 294 | constraints["max_rows"] = max_rows |
199 | 295 |
|
| 296 | + # ── Rate limiting ───────────────────────────────────────────────── |
| 297 | + |
| 298 | + rate_key = f"{pid}:{cid}" |
| 299 | + if capability.safety_class in self._rate_limits: |
| 300 | + limit, window = self._rate_limits[capability.safety_class] |
| 301 | + if "service" in roles: |
| 302 | + limit *= _SERVICE_RATE_MULTIPLIER |
| 303 | + if not self._limiter.check(rate_key, limit, window): |
| 304 | + raise self._deny( |
| 305 | + f"Rate limit exceeded: {limit} {capability.safety_class.value} " |
| 306 | + f"invocations per {window}s for principal '{pid}'", |
| 307 | + principal_id=pid, |
| 308 | + capability_id=cid, |
| 309 | + ) |
| 310 | + self._limiter.record(rate_key) |
| 311 | + |
200 | 312 | reason = "Request approved by DefaultPolicyEngine." |
201 | 313 | logger.info( |
202 | 314 | "policy_allowed", |
|
0 commit comments