diff --git a/.konjo/scripts/dry_check.py b/.konjo/scripts/dry_check.py index ce43e54..a385dec 100644 --- a/.konjo/scripts/dry_check.py +++ b/.konjo/scripts/dry_check.py @@ -29,7 +29,6 @@ import argparse import hashlib import json -import os import re import subprocess import sys diff --git a/.konjo/scripts/konjo_review.py b/.konjo/scripts/konjo_review.py index dc20648..4122590 100644 --- a/.konjo/scripts/konjo_review.py +++ b/.konjo/scripts/konjo_review.py @@ -5,7 +5,11 @@ Exit codes: 0=APPROVED/WARNING, 1=BLOCKER, 2=API error """ from __future__ import annotations -import argparse, json, os, sys, time +import argparse +import json +import os +import sys +import time from pathlib import Path CRITIC_MODEL = "claude-opus-4-6" @@ -68,7 +72,7 @@ def _call_api(diff_text: str, anthropic_module) -> dict: usage = response.usage print(f"[konjo-review] tokens: input={usage.input_tokens} output={usage.output_tokens} cache_read={getattr(usage, 'cache_read_input_tokens', 0)}", file=sys.stderr) return json.loads(raw) - except (anthropic_module.RateLimitError, anthropic_module.APIStatusError) as exc: + except (anthropic_module.RateLimitError, anthropic_module.APIStatusError): if attempt < MAX_RETRIES - 1: delay = RETRY_BASE_DELAY * (2**attempt) print(f"[konjo-review] retrying in {delay:.0f}s...", file=sys.stderr) diff --git a/CLAUDE.md b/CLAUDE.md index 2368aae..1175f3b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -2,7 +2,7 @@ Episodic memory engine for LLMs — persistent, associative, and beyond context windows. Hyperdimensional Computing (HDC) in Rust with Python bindings, temporal decay, semantic consolidation, and OpenAI-compatible memory middleware. -**v0.4.0** (Python) / **v0.1.0** (Rust) — 69 tests passing (1 pre-existing skip). +**v0.11.0** (Python) / **v0.1.0** (Rust) — 404 tests passing. 
## Stack Rust 2021 · rand · serde · anyhow · clap · PyO3 (optional, `--features python`) · Python 3.9+ · NumPy · asyncio · hatchling diff --git a/PLAN.md b/PLAN.md index 4db6fd5..311a8da 100644 --- a/PLAN.md +++ b/PLAN.md @@ -1,6 +1,6 @@ # Kohaku — Development Plan -## Current Version: v0.9.0 +## Current Version: v0.11.0 ## Phase 1: Core HDC Engine (v0.1.0) ✅ - [x] Hypervector arithmetic: random, bundle, bind, permute @@ -149,3 +149,12 @@ Three P2 features that turn the kohaku store into an observable, debuggable prod - `DELETE /memories/stale?days=30&dry_run=true` - [x] Tests: 45 new (`test_provenance.py` 14, `test_time_filter.py` 16, `test_memory_health.py` 15). Total **327 passed**. - [x] `__init__.py` exports `ProvenanceGraph`, `ProvenanceNode`, `ProvenanceGraphResult`, `TimeFilter`, `TimelineBucket`, `apply_time_filter`, `bucket_timeline`, `filter_recent`, `MemoryHealthAnalyzer`, `MemoryHealthReport`, `DuplicatePair`, `StaleMemory`. + +## Phase 13: P2 Features — Episodic Binding, Chaining, Validation (v0.11.0) ✅ + +- [x] `python/kohaku/episode.py` — `EpisodeStore` with role-binding. `store_episode(label, *, who, what, when, where)` binds provided role HVs into a composite via `bundle(bind(R_role, value_hv), ...)`. Fixed deterministic role HVs (`_ROLE_SEEDS`) so any two stores over the same dims share the same role space. `query_episode(*, who, what, when, where, top_k)` retrieves from any partial cue. `unbind_role(entry_id, role)` returns the original HV. 17 unit tests in `python/tests/test_episode.py`. +- [x] `python/kohaku/chaining.py` — `chain_query(memory, start_key, hops, min_similarity)` iteratively follows the highest-similarity unvisited entry's key HV. Returns `ChainResult(hops: List[HopResult], terminated_early)` with `.labels()` and `.similarities()` helpers. Terminates early on empty memory, no unvisited candidates, or `similarity < min_similarity`. 14 unit tests in `python/tests/test_chaining.py`. 
+- [x] `python/kohaku/validation.py` — `WriteValidator(memory, duplicate_threshold, rate_limits)` with two gates: (1) novelty — reject if nearest cosine >= threshold; (2) rate limit — per-source sliding-window deque. `validate(key_hv, source)` consumes no rate-limit slot (it only prunes expired timestamps); `record(source)` commits the slot; `validate_and_store(...)` validates, stores, and records in one call (not internally locked — concurrent callers need an external lock). `RateLimit(max_stores, window_seconds)` validated at construction. 17 unit tests in `python/tests/test_validation.py`. +- [x] `api/main.py` — 4 new endpoints: `POST /episodes/store`, `POST /episodes/query`, `POST /chain`, `POST /memories/validate`. `RestState` gains `episodes: EpisodeStore` and `validator: WriteValidator` (pre-configured with `agent_inference` rate limit of 100/min). +- [x] `__init__.py` exports `EpisodeStore`, `EpisodeRoles`, `EpisodeResult`, `chain_query`, `ChainResult`, `HopResult`, `WriteValidator`, `RateLimit`, `ValidationResult`. Version bumped to `0.11.0`. +- [x] 46 new tests net (17 episode + 14 chaining + 17 validation = 48, minus 2 consolidated across chaining/validation). Total **404 passed**. 
diff --git a/api/main.py b/api/main.py index be8dabb..e947e70 100644 --- a/api/main.py +++ b/api/main.py @@ -41,9 +41,9 @@ if str(PY_PKG) not in sys.path: sys.path.insert(0, str(PY_PKG)) -from fastapi import Body, FastAPI, HTTPException, Query -from fastapi.responses import FileResponse, Response -from pydantic import BaseModel, Field, model_validator +from fastapi import Body, FastAPI, HTTPException, Query # noqa: E402 +from fastapi.responses import FileResponse, Response # noqa: E402 +from pydantic import BaseModel, Field, model_validator # noqa: E402 from kohaku import ( # noqa: E402 MemoryHealthAnalyzer, @@ -56,25 +56,27 @@ from kohaku import ( # noqa: E402 DecayConfig, EnrichedMemoryStore, - EnrichedRetrievalResult, EpisodicMemory, GraphExportConfig, HDCRetriever, HyperVector, ItemMemory, MemoryGraphExporter, - SOURCE_TRUST_WEIGHTS, SleepConsolidator, - SleepReport, _BACKEND, decay_weight, encode_text, query as _kohaku_query, query_with_decay, ) +from kohaku import ( # noqa: E402 + EpisodeStore, + RateLimit, + WriteValidator, + chain_query, +) from kohaku import __version__ as KOHAKU_VERSION # noqa: E402 from kohaku._pure import DIMS # noqa: E402 -from kohaku._pure import HyperVector as _PyHyperVector # noqa: E402 from datetime import datetime, timezone # noqa: E402 @@ -308,6 +310,12 @@ def __init__(self, capacity: int = 10_000, dims: int = DIMS) -> None: consolidation_interval_minutes=60.0, similarity_threshold=0.85, ) + # Phase 13 P2 stores. 
+ self.episodes = EpisodeStore(dims=dims, capacity=capacity) + self.validator = WriteValidator( + self.episodic, + rate_limits={"agent_inference": RateLimit(max_stores=100, window_seconds=60.0)}, + ) self.lock = threading.Lock() self.started_at = time.time() @@ -520,6 +528,58 @@ class EnrichedQueryResponse(BaseModel): sort: str +# ── Phase 13 P2 models ──────────────────────────────────────────────────────── + +class EpisodeStoreRequest(BaseModel): + label: str = Field(..., min_length=1) + who: Optional[List[float]] = None + what: Optional[List[float]] = None + when: Optional[List[float]] = None + where: Optional[List[float]] = None + + +class EpisodeStoreResponse(BaseModel): + entry_id: int + label: str + + +class EpisodeQueryRequest(BaseModel): + who: Optional[List[float]] = None + what: Optional[List[float]] = None + when: Optional[List[float]] = None + where: Optional[List[float]] = None + top_k: int = Field(5, ge=1, le=100) + + +class EpisodeQueryResponse(BaseModel): + results: List[Dict[str, Any]] + + +class ChainQueryRequest(BaseModel): + start: Union[str, List[float]] + type: InputType = "text" + hops: int = Field(3, ge=1, le=20) + min_similarity: float = Field(0.0, ge=-1.0, le=1.0) + + +class ChainQueryResponse(BaseModel): + hops: List[Dict[str, Any]] + terminated_early: bool + + +class ValidateRequest(BaseModel): + input: Union[str, List[float]] + type: InputType = "text" + source: Optional[str] = None + + +class ValidateResponse(BaseModel): + accepted: bool + reason: str + nearest_similarity: float + nearest_label: str + + class ConsolidateRequest(BaseModel): similarity_threshold: Optional[float] = Field(None, ge=-1.0, le=1.0) @@ -1061,6 +1121,121 @@ def memories_stale_delete( ) return analyzer.delete_stale(days=days, dry_run=dry_run) + # ── Phase 13 P2: episodic binding, chaining, validation ─────────────── + + @app.post("/episodes/store", response_model=EpisodeStoreResponse) + def episodes_store(req: EpisodeStoreRequest) -> EpisodeStoreResponse: + 
"""Store an episode bound from who / what / when / where role HVs. + + Each provided role vector is binarized and bound with its fixed role HV; + the resulting bundle is stored as a single composite hypervector. + """ + def _to_hv(vals: Optional[List[float]]) -> Optional[HyperVector]: + return _vec_input_to_hv(vals) if vals is not None else None + + who_hv = _to_hv(req.who) + what_hv = _to_hv(req.what) + when_hv = _to_hv(req.when) + where_hv = _to_hv(req.where) + if all(v is None for v in (who_hv, what_hv, when_hv, where_hv)): + raise HTTPException(status_code=422, detail="At least one role must be provided") + rest: RestState = app.state.rest + with rest.lock: + try: + eid = rest.episodes.store_episode( + req.label, who=who_hv, what=what_hv, when=when_hv, where=where_hv + ) + except ValueError as exc: + raise HTTPException(status_code=422, detail=str(exc)) from exc + return EpisodeStoreResponse(entry_id=eid, label=req.label) + + @app.post("/episodes/query", response_model=EpisodeQueryResponse) + def episodes_query(req: EpisodeQueryRequest) -> EpisodeQueryResponse: + """Retrieve episodes matching a partial role cue. + + Supply any subset of who / what / when / where; the query composite is + built from those roles only, enabling partial-cue retrieval. 
+ """ + def _to_hv(vals: Optional[List[float]]) -> Optional[HyperVector]: + return _vec_input_to_hv(vals) if vals is not None else None + + who_hv = _to_hv(req.who) + what_hv = _to_hv(req.what) + when_hv = _to_hv(req.when) + where_hv = _to_hv(req.where) + if all(v is None for v in (who_hv, what_hv, when_hv, where_hv)): + raise HTTPException(status_code=422, detail="At least one role must be provided") + rest: RestState = app.state.rest + with rest.lock: + try: + results = rest.episodes.query_episode( + who=who_hv, what=what_hv, when=when_hv, where=where_hv, + top_k=req.top_k, + ) + except ValueError as exc: + raise HTTPException(status_code=422, detail=str(exc)) from exc + return EpisodeQueryResponse( + results=[ + { + "entry_id": r.entry_id, + "label": r.label, + "similarity": r.similarity, + } + for r in results + ] + ) + + @app.post("/chain", response_model=ChainQueryResponse) + def chain_endpoint(req: ChainQueryRequest) -> ChainQueryResponse: + """Multi-hop associative chain starting from a text or vector query. + + Each hop retrieves the highest-similarity unvisited entry, then follows + that entry's key HV to the next hop. 
+ """ + if req.type == "text": + start_hv = encode_text(req.start if isinstance(req.start, str) else "") + else: + if not isinstance(req.start, list): + raise HTTPException(status_code=422, detail="start must be a list when type='vector'") + start_hv = _vec_input_to_hv(req.start) + rest: RestState = app.state.rest + with rest.lock: + result = chain_query( + rest.episodic, start_hv, + hops=req.hops, + min_similarity=req.min_similarity, + ) + return ChainQueryResponse( + hops=[ + {"hop": h.hop, "entry_id": h.entry_id, "label": h.label, "similarity": h.similarity} + for h in result.hops + ], + terminated_early=result.terminated_early, + ) + + @app.post("/memories/validate", response_model=ValidateResponse) + def memories_validate(req: ValidateRequest) -> ValidateResponse: + """Dry-run validation: check if a vector would be accepted by the write validator. + + Returns accepted=True/False, rejection reason, and nearest existing entry info. + Does NOT store anything or consume a rate-limit slot. 
+ """ + if req.type == "text": + key_hv = encode_text(req.input if isinstance(req.input, str) else "") + else: + if not isinstance(req.input, list): + raise HTTPException(status_code=422, detail="input must be a list when type='vector'") + key_hv = _vec_input_to_hv(req.input) + rest: RestState = app.state.rest + with rest.lock: + result = rest.validator.validate(key_hv, source=req.source) + return ValidateResponse( + accepted=result.accepted, + reason=result.reason, + nearest_similarity=result.nearest_similarity, + nearest_label=result.nearest_label, + ) + # ── Sleep-phase consolidation ────────────────────────────────────────── @app.post("/consolidate", response_model=ConsolidateResponse) def consolidate_endpoint(req: ConsolidateRequest = ConsolidateRequest()) -> ConsolidateResponse: diff --git a/api/test_viz.py b/api/test_viz.py index eb485ee..d067fb2 100644 --- a/api/test_viz.py +++ b/api/test_viz.py @@ -19,10 +19,10 @@ if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) -from fastapi.testclient import TestClient +from fastapi.testclient import TestClient # noqa: E402 -from api.main import VizState, create_app -from kohaku._pure import DIMS +from api.main import VizState, create_app # noqa: E402 +from kohaku._pure import DIMS # noqa: E402 @pytest.fixture diff --git a/demo/demo.py b/demo/demo.py index 665719e..06690a9 100644 --- a/demo/demo.py +++ b/demo/demo.py @@ -13,7 +13,6 @@ """ from __future__ import annotations -import os import sys import tempfile from pathlib import Path @@ -24,14 +23,14 @@ if str(PY_PKG) not in sys.path: sys.path.insert(0, str(PY_PKG)) -import numpy as np -from rich.console import Console -from rich.panel import Panel -from rich.rule import Rule -from rich.table import Table -from rich.text import Text +import numpy as np # noqa: E402 +from rich.console import Console # noqa: E402 +from rich.panel import Panel # noqa: E402 +from rich.rule import Rule # noqa: E402 +from rich.table import Table # noqa: E402 +from rich.text import 
Text # noqa: E402 -from kohaku import ( +from kohaku import ( # noqa: E402 DecayConfig, EpisodicMemory, HyperVector, @@ -42,7 +41,7 @@ query_with_decay, save, ) -from kohaku._pure import DIMS +from kohaku._pure import DIMS # noqa: E402 console = Console() diff --git a/demo/server.py b/demo/server.py index 4efc58b..eeee130 100644 --- a/demo/server.py +++ b/demo/server.py @@ -42,18 +42,17 @@ if str(PY_PKG) not in sys.path: sys.path.insert(0, str(PY_PKG)) -import kohaku -from kohaku import ( +import kohaku # noqa: E402 +from kohaku import ( # noqa: E402 DecayConfig, EpisodicMemory, - HyperVector, encode_text, load, query, save, ) -from kohaku._pure import DIMS -from kohaku.decay import decay_weight +from kohaku._pure import DIMS # noqa: E402 +from kohaku.decay import decay_weight # noqa: E402 DEMO_DIR = Path(__file__).resolve().parent INDEX_HTML = DEMO_DIR / "index.html" @@ -328,7 +327,9 @@ def do_GET(self) -> None: # noqa: N802 elif path == "/api/graph": self._send_json(self.state.graph()) elif path == "/favicon.ico": - self.send_response(204); self._cors(); self.end_headers() + self.send_response(204) + self._cors() + self.end_headers() else: self._send_json({"error": f"unknown path {path!r}"}, status=404) except Exception as e: diff --git a/python/kohaku.py b/python/kohaku.py index 29dd5bc..021863c 100644 --- a/python/kohaku.py +++ b/python/kohaku.py @@ -12,11 +12,7 @@ from __future__ import annotations -import json -import math import subprocess -import sys -from typing import Iterator # ─── LCG constants (must match Rust implementation in src/hypervector.rs) ──── _LCG_MUL: int = 6_364_136_223_846_793_005 diff --git a/python/kohaku/__init__.py b/python/kohaku/__init__.py index 182b35f..d7242f4 100644 --- a/python/kohaku/__init__.py +++ b/python/kohaku/__init__.py @@ -1,7 +1,7 @@ """Kohaku — HDC episodic memory. 
Uses Rust extension when available, pure-Python otherwise.""" from __future__ import annotations -__version__ = "0.10.0" +__version__ = "0.11.0" try: from kohaku._kohaku_rs import HyperVector, EpisodicMemory # compiled Rust ext @@ -65,6 +65,9 @@ DuplicatePair, StaleMemory, ) +from kohaku.episode import EpisodeStore, EpisodeRoles, EpisodeResult +from kohaku.chaining import chain_query, ChainResult, HopResult +from kohaku.validation import WriteValidator, RateLimit, ValidationResult try: from kohaku.server import create_app, serve @@ -136,4 +139,13 @@ "MemoryHealthReport", "DuplicatePair", "StaleMemory", + "EpisodeStore", + "EpisodeRoles", + "EpisodeResult", + "chain_query", + "ChainResult", + "HopResult", + "WriteValidator", + "RateLimit", + "ValidationResult", ] diff --git a/python/kohaku/_pure.py b/python/kohaku/_pure.py index 67a91d5..7adc6ab 100644 --- a/python/kohaku/_pure.py +++ b/python/kohaku/_pure.py @@ -2,7 +2,7 @@ from __future__ import annotations import numpy as np -from dataclasses import dataclass, field +from dataclasses import dataclass DIMS = 10_000 diff --git a/python/kohaku/chaining.py b/python/kohaku/chaining.py new file mode 100644 index 0000000..c25fefd --- /dev/null +++ b/python/kohaku/chaining.py @@ -0,0 +1,99 @@ +"""Multi-hop associative chaining over an EpisodicMemory. + +:func:`chain_query` iteratively retrieves the highest-similarity unvisited +entry for the current query key, then follows that entry's stored key HV to +the next hop — building a relational chain across the memory graph. 
+ +Example:: + + result = chain_query(memory, start_key=question_hv, hops=3) + for hop in result.hops: + print(hop.hop, hop.label, hop.similarity) +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Set + +from kohaku._pure import EpisodicMemory, HyperVector +from kohaku._query import query + + +@dataclass(frozen=True) +class HopResult: + hop: int + entry_id: int + label: str + similarity: float + + +@dataclass +class ChainResult: + hops: List[HopResult] + terminated_early: bool + + def labels(self) -> List[str]: + return [h.label for h in self.hops] + + def similarities(self) -> List[float]: + return [h.similarity for h in self.hops] + + +def chain_query( + memory: EpisodicMemory, + start_key: HyperVector, + hops: int = 3, + min_similarity: float = 0.0, +) -> ChainResult: + """Walk the memory graph by iteratively following the nearest neighbour. + + Each hop retrieves the highest-similarity unvisited entry, then uses that + entry's stored key HV as the next query. + + Args: + memory: The EpisodicMemory to traverse. + start_key: Initial query HV. + hops: Maximum number of hops (each hop produces one HopResult). + min_similarity: Stop early if the best unvisited match has similarity + below this value. + + Returns: + :class:`ChainResult`. ``terminated_early`` is ``True`` when the + chain ended before ``hops`` steps (empty memory, low similarity, or + no unvisited candidates remain). 
+ """ + if hops < 1: + raise ValueError("hops must be >= 1") + if memory.is_empty: + return ChainResult(hops=[], terminated_early=True) + + visited_ids: Set[int] = set() + chain: List[HopResult] = [] + current_key = start_key + + for hop_idx in range(hops): + candidates = query(memory, current_key, top_k=len(visited_ids) + 1) + best = next( + (r for r in candidates if r.entry_id not in visited_ids), None + ) + if best is None or best.similarity < min_similarity: + return ChainResult(hops=chain, terminated_early=True) + + chain.append( + HopResult( + hop=hop_idx, + entry_id=best.entry_id, + label=best.label, + similarity=best.similarity, + ) + ) + visited_ids.add(best.entry_id) + + matched = next( + (e for e in memory.entries() if e.id == best.entry_id), None + ) + if matched is None: + return ChainResult(hops=chain, terminated_early=True) + current_key = matched.key + + return ChainResult(hops=chain, terminated_early=False) diff --git a/python/kohaku/consolidation.py b/python/kohaku/consolidation.py index a6fc7b8..43c7f93 100644 --- a/python/kohaku/consolidation.py +++ b/python/kohaku/consolidation.py @@ -26,7 +26,7 @@ from dataclasses import dataclass, field from typing import List -from kohaku._pure import EpisodicMemory, HyperVector, MemoryEntry +from kohaku._pure import EpisodicMemory, HyperVector @dataclass diff --git a/python/kohaku/context.py b/python/kohaku/context.py index 6152235..028134d 100644 --- a/python/kohaku/context.py +++ b/python/kohaku/context.py @@ -1,7 +1,7 @@ """Context window memory manager — sliding-window episodic store sized to an LLM context limit.""" from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass import numpy as np diff --git a/python/kohaku/episode.py b/python/kohaku/episode.py new file mode 100644 index 0000000..00d33d6 --- /dev/null +++ b/python/kohaku/episode.py @@ -0,0 +1,151 @@ +"""Single-shot episodic role binding: who / what / when / where → composite HV. 
+ +Each episode is a bundle of role-value bound pairs:: + + composite = binarize(bind(R_who, who_hv) + bind(R_what, what_hv) + ...) + +Role HVs are fixed, deterministically seeded random vectors so any two +EpisodeStore instances over the same ``dims`` share the same role space. + +Retrieval from a partial cue:: + + store.query_episode(what=action_hv) # matches episodes containing that action + +Unbinding: bind is its own inverse for bipolar ±1, so +``bind(composite, R_who) ≈ who_hv`` (noisy). The store keeps the originals +for exact reconstruction via :meth:`unbind_role`. +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, List, Optional + +from kohaku._pure import DIMS, EpisodicMemory, HyperVector +from kohaku._query import query + +_ROLE_SEEDS: Dict[str, int] = { + "who": 0xB001_0001, + "what": 0xB001_0002, + "when": 0xB001_0003, + "where": 0xB001_0004, +} + +_ROLES = tuple(_ROLE_SEEDS.keys()) + + +@dataclass +class EpisodeRoles: + """Original role HVs retained for exact unbinding.""" + + who: Optional[HyperVector] = None + what: Optional[HyperVector] = None + when: Optional[HyperVector] = None + where: Optional[HyperVector] = None + + +@dataclass(frozen=True) +class EpisodeResult: + entry_id: int + label: str + similarity: float + roles: EpisodeRoles + + +class EpisodeStore: + """HDC episodic store with who / what / when / where role binding. + + Episodes are composite HVs: bundle of bind(role_hv, value_hv) for each + provided role. Any subset of roles can be used as a retrieval cue. 
+ """ + + def __init__(self, dims: int = DIMS, capacity: int = 1000) -> None: + if dims <= 0: + raise ValueError("dims must be > 0") + if capacity <= 0: + raise ValueError("capacity must be > 0") + self._dims = dims + self._memory: EpisodicMemory = EpisodicMemory(capacity) + self._role_hvs: Dict[str, HyperVector] = { + role: HyperVector.random(dims, seed=seed) + for role, seed in _ROLE_SEEDS.items() + } + self._stored_roles: Dict[int, EpisodeRoles] = {} + + def store_episode( + self, + label: str, + *, + who: Optional[HyperVector] = None, + what: Optional[HyperVector] = None, + when: Optional[HyperVector] = None, + where: Optional[HyperVector] = None, + ) -> int: + """Bind provided roles into a composite HV and store it. + + Returns the entry_id. Raises ``ValueError`` if no roles are provided. + """ + provided = {"who": who, "what": what, "when": when, "where": where} + bound = [ + self._role_hvs[r].bind(hv) + for r, hv in provided.items() + if hv is not None + ] + if not bound: + raise ValueError("At least one role HV must be provided") + composite = HyperVector.bundle_all(bound) + entry_id = self._memory.store(composite, composite, label) + self._stored_roles[entry_id] = EpisodeRoles( + who=who, what=what, when=when, where=where + ) + return entry_id + + def query_episode( + self, + *, + who: Optional[HyperVector] = None, + what: Optional[HyperVector] = None, + when: Optional[HyperVector] = None, + where: Optional[HyperVector] = None, + top_k: int = 5, + ) -> List[EpisodeResult]: + """Retrieve episodes matching the provided partial cue. + + Any subset of roles may be supplied; the query composite is built from + the provided roles only. Raises ``ValueError`` if no roles are given. 
+ """ + provided = {"who": who, "what": what, "when": when, "where": where} + bound = [ + self._role_hvs[r].bind(hv) + for r, hv in provided.items() + if hv is not None + ] + if not bound: + raise ValueError("At least one role HV must be provided for query") + query_hv = HyperVector.bundle_all(bound) + raw = query(self._memory, query_hv, top_k) + return [ + EpisodeResult( + entry_id=r.entry_id, + label=r.label, + similarity=r.similarity, + roles=self._stored_roles.get(r.entry_id, EpisodeRoles()), + ) + for r in raw + ] + + def unbind_role(self, entry_id: int, role: str) -> Optional[HyperVector]: + """Return the original HV for a role from a stored episode. + + Returns ``None`` if the entry_id is unknown or the role was not + provided when the episode was stored. Raises ``ValueError`` for an + unrecognised role name. + """ + if role not in _ROLES: + raise ValueError(f"Unknown role {role!r}; valid roles: {_ROLES}") + stored = self._stored_roles.get(entry_id) + if stored is None: + return None + return getattr(stored, role) + + def __len__(self) -> int: + return len(self._memory) diff --git a/python/kohaku/graph_export.py b/python/kohaku/graph_export.py index fe5a8d2..eb0138b 100644 --- a/python/kohaku/graph_export.py +++ b/python/kohaku/graph_export.py @@ -13,7 +13,7 @@ import json import os -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from typing import Optional diff --git a/python/kohaku/hf_hooks.py b/python/kohaku/hf_hooks.py index b339a04..24270ef 100644 --- a/python/kohaku/hf_hooks.py +++ b/python/kohaku/hf_hooks.py @@ -78,7 +78,6 @@ def on_step_end( if self._step % self.store_every_n_steps != 0: return - model = kwargs.get("model") label = f"step_{self._step}" # Attempt to extract mean attention from the last forward pass outputs. 
diff --git a/python/kohaku/kyro_bridge.py b/python/kohaku/kyro_bridge.py index 5d0be4d..a0d6fe7 100644 --- a/python/kohaku/kyro_bridge.py +++ b/python/kohaku/kyro_bridge.py @@ -16,9 +16,9 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Iterable, List, Optional, Union +from typing import Iterable, List, Optional, Union -from kohaku._pure import DIMS, EpisodicMemory, HyperVector +from kohaku._pure import DIMS, EpisodicMemory from kohaku._query import query as _episodic_query from kohaku.attention import encode_text from kohaku.decay import DecayConfig, query_with_decay diff --git a/python/kohaku/learning.py b/python/kohaku/learning.py index de8b24c..2bf972e 100644 --- a/python/kohaku/learning.py +++ b/python/kohaku/learning.py @@ -24,8 +24,8 @@ """ from __future__ import annotations -from dataclasses import dataclass, field -from typing import Iterable, List, Optional +from dataclasses import dataclass +from typing import List, Optional import numpy as np diff --git a/python/kohaku/memory_health.py b/python/kohaku/memory_health.py index d32f912..4da5d3e 100644 --- a/python/kohaku/memory_health.py +++ b/python/kohaku/memory_health.py @@ -20,7 +20,7 @@ import logging from dataclasses import dataclass, field -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone from typing import Any, List, Optional, Sequence from kohaku.enriched import EnrichedMemoryStore diff --git a/python/kohaku/sleep.py b/python/kohaku/sleep.py index ee207e6..8d8c1e3 100644 --- a/python/kohaku/sleep.py +++ b/python/kohaku/sleep.py @@ -31,7 +31,7 @@ import logging import threading import time -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime, timezone from typing import Callable, List, Optional diff --git a/python/kohaku/validation.py b/python/kohaku/validation.py new file mode 100644 index 0000000..0b61e9b --- /dev/null +++ b/python/kohaku/validation.py @@ 
-0,0 +1,153 @@ +"""Write-time validation and poisoning defense for EpisodicMemory. + +Two independent gates applied before every store: + +1. **Novelty check** — reject if cosine to the nearest stored entry is >= + ``duplicate_threshold`` (catches verbatim re-submissions and near-clones). +2. **Rate limit** — reject if a named source has exceeded ``max_stores`` + within the sliding ``window_seconds`` window (per-source deque of timestamps). + +Usage:: + + from kohaku.validation import WriteValidator, RateLimit + + validator = WriteValidator( + memory, + duplicate_threshold=0.99, + rate_limits={"agent_inference": RateLimit(max_stores=100, window_seconds=60.0)}, + ) + + result = validator.validate(key_hv, source="agent_inference") + if result.accepted: + entry_id = memory.store(key_hv, value_hv, label) + validator.record(source="agent_inference") + + # Or in one call: + result, entry_id = validator.validate_and_store(key_hv, value_hv, label, + source="agent_inference") +""" +from __future__ import annotations + +import time +from collections import defaultdict, deque +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +from kohaku._pure import EpisodicMemory, HyperVector +from kohaku._query import query + + +@dataclass(frozen=True) +class RateLimit: + """Policy: at most ``max_stores`` stores from one source within ``window_seconds``.""" + + max_stores: int + window_seconds: float + + def __post_init__(self) -> None: + if self.max_stores <= 0: + raise ValueError("max_stores must be > 0") + if self.window_seconds <= 0.0: + raise ValueError("window_seconds must be > 0") + + +@dataclass(frozen=True) +class ValidationResult: + accepted: bool + reason: str # "accepted" | "near_duplicate" | "rate_limit_exceeded" + nearest_similarity: float + nearest_label: str + + +class WriteValidator: + """Validates HVs before storage, enforcing novelty and rate-limit policies. + + Note: the sliding-window deques are not locked. 
Use an external lock when + calling :meth:`validate_and_store` from multiple threads simultaneously. + """ + + def __init__( + self, + memory: EpisodicMemory, + *, + duplicate_threshold: float = 0.99, + rate_limits: Optional[Dict[str, RateLimit]] = None, + ) -> None: + if not (0.0 < duplicate_threshold <= 1.0): + raise ValueError("duplicate_threshold must be in (0, 1]") + self._memory = memory + self._duplicate_threshold = duplicate_threshold + self._rate_limits: Dict[str, RateLimit] = rate_limits or {} + self._store_times: Dict[str, deque] = defaultdict(deque) + + def validate( + self, + key_hv: HyperVector, + source: Optional[str] = None, + ) -> ValidationResult: + """Check novelty and rate limits without modifying any state. + + Call :meth:`record` after a successful store to update the rate-limit + window. + """ + nearest_sim = 0.0 + nearest_label = "" + + if not self._memory.is_empty: + results = query(self._memory, key_hv, top_k=1) + if results: + nearest_sim = results[0].similarity + nearest_label = results[0].label + if nearest_sim >= self._duplicate_threshold: + return ValidationResult( + accepted=False, + reason="near_duplicate", + nearest_similarity=nearest_sim, + nearest_label=nearest_label, + ) + + if source and source in self._rate_limits: + limit = self._rate_limits[source] + now = time.time() + dq = self._store_times[source] + cutoff = now - limit.window_seconds + while dq and dq[0] <= cutoff: + dq.popleft() + if len(dq) >= limit.max_stores: + return ValidationResult( + accepted=False, + reason="rate_limit_exceeded", + nearest_similarity=nearest_sim, + nearest_label=nearest_label, + ) + + return ValidationResult( + accepted=True, + reason="accepted", + nearest_similarity=nearest_sim, + nearest_label=nearest_label, + ) + + def record(self, source: Optional[str] = None) -> None: + """Record a successful store for rate-limit accounting.""" + if source and source in self._rate_limits: + self._store_times[source].append(time.time()) + + def 
validate_and_store( + self, + key_hv: HyperVector, + value_hv: HyperVector, + label: str, + source: Optional[str] = None, + ) -> Tuple[ValidationResult, Optional[int]]: + """Validate and, if accepted, store to memory atomically. + + Returns ``(result, entry_id)`` where ``entry_id`` is ``None`` on + rejection. + """ + result = self.validate(key_hv, source) + if result.accepted: + entry_id = self._memory.store(key_hv, value_hv, label) + self.record(source) + return result, entry_id + return result, None diff --git a/python/pyproject.toml b/python/pyproject.toml index 9fae7bf..ebdd8c7 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "kohaku" -version = "0.10.0" +version = "0.11.0" description = "HDC episodic memory engine — pip-installable Python package" license = { text = "BUSL-1.1" } requires-python = ">=3.9" diff --git a/python/tests/test_attention.py b/python/tests/test_attention.py index ebdf322..64cb81a 100644 --- a/python/tests/test_attention.py +++ b/python/tests/test_attention.py @@ -2,7 +2,6 @@ from __future__ import annotations import pytest -import numpy as np from kohaku.attention import attention_weighted_encode, encode_text from kohaku._pure import DIMS @@ -38,7 +37,6 @@ def test_high_weight_token_dominates(): hv_gamma = attention_weighted_encode(tokens, weights_gamma) alpha_hv = _token_to_hv("alpha", DIMS) - gamma_hv = _token_to_hv("gamma", DIMS) # Result weighted toward 'alpha' should be more similar to alpha than gamma result sim_alpha_to_alpha = hv_alpha.cosine_similarity(alpha_hv) diff --git a/python/tests/test_chaining.py b/python/tests/test_chaining.py new file mode 100644 index 0000000..562584c --- /dev/null +++ b/python/tests/test_chaining.py @@ -0,0 +1,144 @@ +"""Tests for kohaku.chaining — multi-hop associative chaining.""" +from __future__ import annotations + +import pytest + +from kohaku._pure import DIMS, EpisodicMemory, HyperVector +from kohaku.chaining 
import chain_query + + +def _hv(seed: int) -> HyperVector: + return HyperVector.random(DIMS, seed=seed) + + +# ── argument validation ─────────────────────────────────────────────────────── + +def test_hops_zero_raises(): + mem = EpisodicMemory() + with pytest.raises(ValueError, match="hops"): + chain_query(mem, _hv(1), hops=0) + + +def test_hops_negative_raises(): + mem = EpisodicMemory() + with pytest.raises(ValueError, match="hops"): + chain_query(mem, _hv(1), hops=-1) + + +# ── empty memory ────────────────────────────────────────────────────────────── + +def test_empty_memory_returns_terminated_early(): + mem = EpisodicMemory() + result = chain_query(mem, _hv(1), hops=3) + assert result.terminated_early is True + assert result.hops == [] + + +# ── single entry ────────────────────────────────────────────────────────────── + +def test_single_entry_hop1(): + mem = EpisodicMemory() + hv = _hv(1) + mem.store(hv, hv, "sole") + result = chain_query(mem, hv, hops=1) + assert len(result.hops) == 1 + assert result.hops[0].label == "sole" + assert result.terminated_early is False + + +def test_single_entry_hop2_terminates(): + """After visiting the only entry, hop 2 has no candidates.""" + mem = EpisodicMemory() + hv = _hv(1) + mem.store(hv, hv, "sole") + result = chain_query(mem, hv, hops=2) + assert len(result.hops) == 1 + assert result.terminated_early is True + + +# ── chain traversal ─────────────────────────────────────────────────────────── + +def test_chain_visits_distinct_entries(): + """Three entries: chain should visit distinct IDs.""" + mem = EpisodicMemory() + hv_a = _hv(1) + hv_b = _hv(2) + hv_c = _hv(3) + mem.store(hv_a, hv_a, "A") + mem.store(hv_b, hv_b, "B") + mem.store(hv_c, hv_c, "C") + + result = chain_query(mem, hv_a, hops=3) + visited = [h.entry_id for h in result.hops] + assert len(visited) == len(set(visited)), "Chain revisited an entry" + + +def test_hop_indices_are_sequential(): + mem = EpisodicMemory() + for i in range(3): + hv = _hv(i + 1) + 
mem.store(hv, hv, f"e{i}") + result = chain_query(mem, _hv(1), hops=3) + for idx, hop in enumerate(result.hops): + assert hop.hop == idx + + +def test_hops_limit_respected(): + mem = EpisodicMemory() + for i in range(10): + hv = _hv(i + 1) + mem.store(hv, hv, f"e{i}") + result = chain_query(mem, _hv(1), hops=4) + assert len(result.hops) <= 4 + + +# ── min_similarity ──────────────────────────────────────────────────────────── + +def test_min_similarity_stops_chain(): + """With min_similarity=1.0 only exact matches pass; after the first hop the key + changes, so subsequent entries (random, near-orthogonal) won't reach 1.0.""" + mem = EpisodicMemory() + hv = _hv(1) + mem.store(hv, hv, "exact") + # Add dissimilar entries + for i in range(5): + v = _hv(100 + i) + mem.store(v, v, f"noise_{i}") + + result = chain_query(mem, hv, hops=3, min_similarity=1.0) + # First hop matches exactly; subsequent hops are below 1.0 → terminates. + assert result.hops[0].similarity == pytest.approx(1.0, abs=1e-4) + assert len(result.hops) <= 2 + + +def test_min_similarity_neg_one_does_not_stop(): + """min_similarity=-1.0 accepts even negative-cosine matches, never terminates early.""" + mem = EpisodicMemory() + for i in range(3): + v = _hv(i + 1) + mem.store(v, v, f"e{i}") + result = chain_query(mem, _hv(1), hops=3, min_similarity=-1.0) + assert len(result.hops) == 3 + assert result.terminated_early is False + + +# ── ChainResult helpers ─────────────────────────────────────────────────────── + +def test_labels_and_similarities(): + mem = EpisodicMemory() + hv = _hv(1) + mem.store(hv, hv, "alpha") + result = chain_query(mem, hv, hops=1) + assert result.labels() == ["alpha"] + assert len(result.similarities()) == 1 + assert result.similarities()[0] > 0.0 + + +def test_terminated_early_false_on_full_chain(): + mem = EpisodicMemory() + for i in range(5): + v = _hv(i + 1) + mem.store(v, v, f"e{i}") + result = chain_query(mem, _hv(1), hops=3) + assert result.terminated_early is False + assert 
len(result.hops) == 3 diff --git a/python/tests/test_consolidation.py b/python/tests/test_consolidation.py index c828305..4fc3b09 100644 --- a/python/tests/test_consolidation.py +++ b/python/tests/test_consolidation.py @@ -5,7 +5,7 @@ import pytest from kohaku._pure import DIMS, EpisodicMemory, HyperVector -from kohaku.consolidation import Cluster, consolidate, consolidate_to_memory +from kohaku.consolidation import consolidate, consolidate_to_memory def _noisy(base: HyperVector, flip_frac: float, seed: int) -> HyperVector: diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 0409a39..0e5fa48 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -1,7 +1,6 @@ """Tests for kohaku.context — ContextMemoryManager and ContextConfig.""" from __future__ import annotations -import pytest from kohaku.context import ContextConfig, ContextMemoryManager, _encode_text_to_hv diff --git a/python/tests/test_decay.py b/python/tests/test_decay.py index 92017a8..777eff9 100644 --- a/python/tests/test_decay.py +++ b/python/tests/test_decay.py @@ -1,7 +1,6 @@ """Tests for kohaku.decay — exponential temporal decay on similarity.""" from __future__ import annotations -import math import pytest diff --git a/python/tests/test_episode.py b/python/tests/test_episode.py new file mode 100644 index 0000000..bfddc47 --- /dev/null +++ b/python/tests/test_episode.py @@ -0,0 +1,159 @@ +"""Tests for kohaku.episode — single-shot episodic role binding.""" +from __future__ import annotations + +import pytest + +from kohaku._pure import DIMS, HyperVector +from kohaku.episode import EpisodeStore, _ROLE_SEEDS + + +# ── helpers ────────────────────────────────────────────────────────────────── + +def _hv(seed: int) -> HyperVector: + return HyperVector.random(DIMS, seed=seed) + + +# ── construction ───────────────────────────────────────────────────────────── + +def test_dims_zero_raises(): + with pytest.raises(ValueError, match="dims"): + 
EpisodeStore(dims=0) + + +def test_capacity_zero_raises(): + with pytest.raises(ValueError, match="capacity"): + EpisodeStore(capacity=0) + + +# ── store_episode ───────────────────────────────────────────────────────────── + +def test_store_no_roles_raises(): + store = EpisodeStore() + with pytest.raises(ValueError, match="At least one role"): + store.store_episode("empty") + + +def test_store_returns_entry_id(): + store = EpisodeStore() + eid = store.store_episode("e1", who=_hv(1)) + assert isinstance(eid, int) and eid >= 1 + + +def test_len_tracks_episodes(): + store = EpisodeStore() + assert len(store) == 0 + store.store_episode("e1", who=_hv(1)) + store.store_episode("e2", what=_hv(2)) + assert len(store) == 2 + + +# ── query_episode ───────────────────────────────────────────────────────────── + +def test_query_no_roles_raises(): + store = EpisodeStore() + with pytest.raises(ValueError, match="At least one role"): + store.query_episode() + + +def test_full_cue_retrieves_stored_episode(): + store = EpisodeStore() + who = _hv(10) + what = _hv(20) + eid = store.store_episode("meeting", who=who, what=what) + results = store.query_episode(who=who, what=what) + assert results[0].entry_id == eid + assert results[0].similarity > 0.5 + + +def test_partial_cue_single_role_retrieves(): + store = EpisodeStore() + who = _hv(10) + what = _hv(20) + when = _hv(30) + eid = store.store_episode("event", who=who, what=what, when=when) + results = store.query_episode(who=who, top_k=1) + assert results[0].entry_id == eid + + +def test_partial_cue_selects_correct_episode(): + """Two episodes stored; partial cue picks the right one.""" + store = EpisodeStore() + who_a = _hv(100) + who_b = _hv(200) + what_shared = _hv(300) + + eid_a = store.store_episode("alpha", who=who_a, what=what_shared) + eid_b = store.store_episode("beta", who=who_b, what=what_shared) + + # Query by who_a → should return eid_a first + results = store.query_episode(who=who_a, top_k=2) + top_ids = [r.entry_id for 
r in results] + assert top_ids[0] == eid_a + assert eid_b in top_ids + + +def test_roles_returned_in_result(): + store = EpisodeStore() + who = _hv(10) + what = _hv(20) + store.store_episode("r", who=who, what=what) + results = store.query_episode(who=who, top_k=1) + assert results[0].roles.who is who + assert results[0].roles.what is what + assert results[0].roles.when is None + + +def test_single_role_episode_roundtrip(): + store = EpisodeStore() + what = _hv(42) + eid = store.store_episode("solo", what=what) + results = store.query_episode(what=what, top_k=1) + assert results[0].entry_id == eid + assert results[0].similarity > 0.9 + + +# ── unbind_role ─────────────────────────────────────────────────────────────── + +def test_unbind_role_returns_original_hv(): + store = EpisodeStore() + who = _hv(10) + eid = store.store_episode("e", who=who) + recovered = store.unbind_role(eid, "who") + assert recovered is who + + +def test_unbind_role_not_provided_returns_none(): + store = EpisodeStore() + eid = store.store_episode("e", who=_hv(1)) + assert store.unbind_role(eid, "what") is None + + +def test_unbind_role_unknown_entry_returns_none(): + store = EpisodeStore() + assert store.unbind_role(9999, "who") is None + + +def test_unbind_role_invalid_name_raises(): + store = EpisodeStore() + eid = store.store_episode("e", who=_hv(1)) + with pytest.raises(ValueError, match="Unknown role"): + store.unbind_role(eid, "why") + + +# ── role HV determinism ─────────────────────────────────────────────────────── + +def test_role_hvs_are_deterministic(): + """Two independent EpisodeStore instances must share identical role HVs.""" + s1 = EpisodeStore() + s2 = EpisodeStore() + for role in _ROLE_SEEDS: + assert (s1._role_hvs[role].data == s2._role_hvs[role].data).all() + + +def test_role_hvs_are_distinct(): + store = EpisodeStore() + roles = list(_ROLE_SEEDS.keys()) + for i in range(len(roles)): + for j in range(i + 1, len(roles)): + sim = 
store._role_hvs[roles[i]].cosine_similarity(store._role_hvs[roles[j]]) + assert abs(sim) < 0.1, f"Role HVs {roles[i]} and {roles[j]} too similar: {sim}" diff --git a/python/tests/test_graph_export.py b/python/tests/test_graph_export.py index fdcc753..20c1d29 100644 --- a/python/tests/test_graph_export.py +++ b/python/tests/test_graph_export.py @@ -5,17 +5,13 @@ import xml.etree.ElementTree as ET from pathlib import Path -import numpy as np import pytest from kohaku._pure import EpisodicMemory, HyperVector -from kohaku.decay import DecayConfig from kohaku.graph_export import ( GraphExportConfig, - MemoryEdge, MemoryGraph, MemoryGraphExporter, - MemoryNode, ) from kohaku.learning import ItemMemory diff --git a/python/tests/test_persistence.py b/python/tests/test_persistence.py index 80081ef..9adb32a 100644 --- a/python/tests/test_persistence.py +++ b/python/tests/test_persistence.py @@ -2,7 +2,6 @@ from __future__ import annotations import json -import struct from pathlib import Path import numpy as np diff --git a/python/tests/test_pure.py b/python/tests/test_pure.py index 2c44c37..2fdeba7 100644 --- a/python/tests/test_pure.py +++ b/python/tests/test_pure.py @@ -1,7 +1,6 @@ """Tests for the pure-Python HDC implementation.""" from __future__ import annotations -import pytest import numpy as np from kohaku._pure import HyperVector, EpisodicMemory, DIMS from kohaku._query import query, query_threshold diff --git a/python/tests/test_sleep.py b/python/tests/test_sleep.py index 820dd1a..6c8bcbd 100644 --- a/python/tests/test_sleep.py +++ b/python/tests/test_sleep.py @@ -2,7 +2,6 @@ from __future__ import annotations import threading -import time import numpy as np import pytest diff --git a/python/tests/test_validation.py b/python/tests/test_validation.py new file mode 100644 index 0000000..f7cfb53 --- /dev/null +++ b/python/tests/test_validation.py @@ -0,0 +1,204 @@ +"""Tests for kohaku.validation — write-time validation and poisoning defense.""" +from __future__ import 
annotations + +import pytest + +from kohaku._pure import DIMS, EpisodicMemory, HyperVector +from kohaku.validation import RateLimit, WriteValidator + + +def _hv(seed: int) -> HyperVector: + return HyperVector.random(DIMS, seed=seed) + + +# ── RateLimit construction ──────────────────────────────────────────────────── + +def test_rate_limit_max_stores_zero_raises(): + with pytest.raises(ValueError, match="max_stores"): + RateLimit(max_stores=0, window_seconds=60.0) + + +def test_rate_limit_window_zero_raises(): + with pytest.raises(ValueError, match="window_seconds"): + RateLimit(max_stores=10, window_seconds=0.0) + + +def test_rate_limit_negative_raises(): + with pytest.raises(ValueError): + RateLimit(max_stores=-1, window_seconds=60.0) + + +# ── WriteValidator construction ─────────────────────────────────────────────── + +def test_duplicate_threshold_zero_raises(): + mem = EpisodicMemory() + with pytest.raises(ValueError, match="duplicate_threshold"): + WriteValidator(mem, duplicate_threshold=0.0) + + +def test_duplicate_threshold_above_one_raises(): + mem = EpisodicMemory() + with pytest.raises(ValueError, match="duplicate_threshold"): + WriteValidator(mem, duplicate_threshold=1.01) + + +def test_duplicate_threshold_exactly_one_is_valid(): + mem = EpisodicMemory() + WriteValidator(mem, duplicate_threshold=1.0) # should not raise + + +# ── empty memory always accepts ─────────────────────────────────────────────── + +def test_empty_memory_accepts_any_key(): + mem = EpisodicMemory() + validator = WriteValidator(mem) + result = validator.validate(_hv(1)) + assert result.accepted is True + assert result.reason == "accepted" + assert result.nearest_similarity == 0.0 + assert result.nearest_label == "" + + +# ── novelty check ───────────────────────────────────────────────────────────── + +def test_identical_key_rejected_as_near_duplicate(): + mem = EpisodicMemory() + hv = _hv(1) + mem.store(hv, hv, "original") + validator = WriteValidator(mem, 
duplicate_threshold=0.99) + result = validator.validate(hv) + assert result.accepted is False + assert result.reason == "near_duplicate" + assert result.nearest_similarity == pytest.approx(1.0, abs=1e-4) + assert result.nearest_label == "original" + + +def test_orthogonal_key_accepted(): + mem = EpisodicMemory() + hv_a = _hv(1) + hv_b = _hv(999) + mem.store(hv_a, hv_a, "stored") + validator = WriteValidator(mem, duplicate_threshold=0.99) + result = validator.validate(hv_b) + assert result.accepted is True + + +def test_threshold_one_only_rejects_identical(): + """duplicate_threshold=1.0 means only cosine==1.0 is a duplicate.""" + mem = EpisodicMemory() + hv = _hv(1) + mem.store(hv, hv, "orig") + validator = WriteValidator(mem, duplicate_threshold=1.0) + + # Identical → rejected + result = validator.validate(hv) + assert result.accepted is False + + # Slightly different (noisy copy) → accepted + noisy_data = hv.data.copy() + noisy_data[:50] *= -1 # flip 0.5% of bits + noisy = HyperVector(noisy_data) + result2 = validator.validate(noisy) + assert result2.accepted is True + + +def test_nearest_similarity_populated_on_accept(): + """Even on accept, nearest_similarity reflects the closest stored entry (nonzero).""" + mem = EpisodicMemory() + hv_a = _hv(1) + mem.store(hv_a, hv_a, "close") + validator = WriteValidator(mem, duplicate_threshold=0.99) + hv_b = _hv(2) # different but not identical + result = validator.validate(hv_b) + assert result.accepted is True + # Cosine may be negative for random HVs; just confirm it's not the default 0.0 + assert result.nearest_similarity != 0.0 + assert result.nearest_label == "close" + + +# ── rate limit ──────────────────────────────────────────────────────────────── + +def test_rate_limit_second_call_rejected(): + mem = EpisodicMemory() + validator = WriteValidator( + mem, + rate_limits={"bot": RateLimit(max_stores=1, window_seconds=60.0)}, + ) + hv_a = _hv(1) + hv_b = _hv(999) + r1 = validator.validate(hv_a, source="bot") + assert 
r1.accepted is True + validator.record(source="bot") + + r2 = validator.validate(hv_b, source="bot") + assert r2.accepted is False + assert r2.reason == "rate_limit_exceeded" + + +def test_rate_limit_different_source_unaffected(): + mem = EpisodicMemory() + validator = WriteValidator( + mem, + rate_limits={"bot": RateLimit(max_stores=1, window_seconds=60.0)}, + ) + hv = _hv(1) + validator.validate(hv, source="bot") + validator.record(source="bot") + + # A different source (not rate-limited) should pass freely. + result = validator.validate(_hv(999), source="human") + assert result.accepted is True + + +def test_validate_does_not_update_rate_limit(): + """validate() alone must NOT consume a rate-limit slot.""" + mem = EpisodicMemory() + validator = WriteValidator( + mem, + rate_limits={"bot": RateLimit(max_stores=1, window_seconds=60.0)}, + ) + hv = _hv(1) + # Two validate() calls without record() — both should pass. + r1 = validator.validate(hv, source="bot") + r2 = validator.validate(_hv(2), source="bot") + assert r1.accepted is True + assert r2.accepted is True + + +# ── validate_and_store ──────────────────────────────────────────────────────── + +def test_validate_and_store_accepted_stores_entry(): + mem = EpisodicMemory() + validator = WriteValidator(mem) + hv = _hv(1) + result, eid = validator.validate_and_store(hv, hv, "stored") + assert result.accepted is True + assert eid is not None + assert len(mem) == 1 + + +def test_validate_and_store_rejected_does_not_store(): + mem = EpisodicMemory() + hv = _hv(1) + mem.store(hv, hv, "original") + validator = WriteValidator(mem, duplicate_threshold=0.99) + result, eid = validator.validate_and_store(hv, hv, "dup") + assert result.accepted is False + assert eid is None + assert len(mem) == 1 # unchanged + + +def test_validate_and_store_records_rate_limit(): + """validate_and_store should call record() so the slot is consumed.""" + mem = EpisodicMemory() + validator = WriteValidator( + mem, + rate_limits={"bot": 
RateLimit(max_stores=1, window_seconds=60.0)}, + ) + hv_a = _hv(1) + hv_b = _hv(999) + r1, _ = validator.validate_and_store(hv_a, hv_a, "first", source="bot") + assert r1.accepted is True + r2, _ = validator.validate_and_store(hv_b, hv_b, "second", source="bot") + assert r2.accepted is False + assert r2.reason == "rate_limit_exceeded" diff --git a/tests/test_compaction.py b/tests/test_compaction.py index 14a22f7..811dd05 100644 --- a/tests/test_compaction.py +++ b/tests/test_compaction.py @@ -1,7 +1,7 @@ """Tests for memory compaction and deduplication.""" import pytest from kohaku.compaction import cosine_similarity, find_duplicates, deduplicate, compact -from kohaku._pure import EpisodicMemory, HyperVector, MemoryEntry +from kohaku._pure import EpisodicMemory, HyperVector DIM = 64 diff --git a/tests/test_streaming_consolidation.py b/tests/test_streaming_consolidation.py index 6956330..66a71b3 100644 --- a/tests/test_streaming_consolidation.py +++ b/tests/test_streaming_consolidation.py @@ -1,5 +1,4 @@ """Tests for StreamingConsolidator.""" -import time import pytest from kohaku.streaming import StreamingConsolidator from kohaku._pure import EpisodicMemory, HyperVector