Skip to content

Commit 1ebd6bb

Browse files
committed
feat: GOAP integration, MC robustness gate, learned posterior variance, spawn-aware planning
1 parent 3c021e1 commit 1ebd6bb

7 files changed

Lines changed: 232 additions & 53 deletions

File tree

docs/compass-hero.png

-1.17 MB
Loading

src/brain/completion.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@
2020

2121
log = logging.getLogger(__name__)
2222

23+
# Routines returning RUNNING longer than this are force-exited.
24+
# Must exceed the combat casting pipeline (2-3 s legit) to avoid
25+
# killing active fights mid-cast.
26+
HARD_KILL_THRESHOLD_MS = 5000
27+
2328

2429
def tick_active_routine(brain: Brain, state: GameState, now: float) -> None:
2530
"""Tick the active routine and handle SUCCESS/FAILURE/hard-kill outcomes."""
@@ -36,7 +41,7 @@ def tick_active_routine(brain: Brain, state: GameState, now: float) -> None:
3641
brain._ticked_routine_name = brain._active_name
3742

3843
if status == RoutineStatus.RUNNING:
39-
if brain.routine_tick_ms > 5000:
44+
if brain.routine_tick_ms > HARD_KILL_THRESHOLD_MS:
4045
hard_kill_routine(brain, state, now)
4146
return
4247

@@ -128,9 +133,7 @@ def notify_cycle_tracker(brain: Brain, state: GameState, status: RoutineStatus)
128133

129134

130135
def hard_kill_routine(brain: Brain, state: GameState, now: float) -> None:
131-
"""Force-exit a routine that returned RUNNING but took >5 s."""
132-
# Threshold must exceed combat casting pipeline (2-3s legit)
133-
# to avoid killing active fights mid-cast.
136+
"""Force-exit a routine that exceeded HARD_KILL_THRESHOLD_MS."""
134137
log.error("[DECISION] HARD KILL: %s took %.0fms, forcing exit", brain._active_name, brain.routine_tick_ms)
135138
assert brain._active is not None
136139
brain._active.failure_reason = "hard_kill"

src/brain/goap/actions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,23 @@ def apply_effects(self, ws: PlanWorldState) -> PlanWorldState:
199199
return ws.with_changes(targets_available=1)
200200

201201
def estimate_cost(self, ctx: AgentContext | None) -> float:
    """Estimate the cost of wandering, spawn-prediction-aware when possible.

    When a spawn predictor is attached to the context, the shortest
    predicted wait among the top predicted cells replaces the default
    heuristic, clamped between a 5 s travel-time floor and the default
    ceiling. This steers the planner toward wander-then-fight plans when
    a respawn is imminent, turning random wandering into positioning.
    """
    default_cost = _DEFAULT_COSTS["wander"]
    predictor = ctx.spawn_predictor if ctx else None
    if not predictor:
        return default_cost
    import time as _time

    predictions = predictor.best_cells(3, _time.time())
    if not predictions:
        return default_cost
    # Shortest predicted wait among candidate cells, bounded below by the
    # 5 s travel floor and above by the default wander heuristic.
    shortest_wait = min(wait for _, wait in predictions)
    return max(5.0, min(shortest_wait, default_cost))
203220

204221

src/brain/goap/planner.py

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
SATISFACTION_THRESHOLD = 0.70 # goal is "achieved enough" at this level
4242
PLAN_BUDGET_MS = 50.0 # max time for plan generation
4343
MC_ROLLOUTS = 20 # Monte Carlo rollouts per candidate plan
44-
MC_NOISE_SIGMA = 0.15 # noise on action effects during rollouts
44+
MC_NOISE_SIGMA = 0.15 # fallback noise when no learned variance available
45+
MC_ROBUSTNESS_THRESHOLD = 0.50 # reject plans below this MC satisfaction
4546

4647

4748
@dataclass(slots=True)
@@ -402,24 +403,28 @@ def _mc_evaluate(
402403
"""Evaluate a candidate plan via Monte Carlo rollouts.
403404
404405
Runs MC_ROLLOUTS stochastic simulations of the plan. In each rollout,
405-
action effects are perturbed with Gaussian noise on continuous fields
406-
(hp_pct, mana_pct) to simulate outcome uncertainty. Returns the mean
407-
goal satisfaction across rollouts.
406+
action effects are perturbed with noise drawn from learned posterior
407+
variance (encounter history) when available, or fixed sigma as
408+
fallback. Returns the mean goal satisfaction across rollouts.
408409
409410
A plan that achieves high satisfaction across noisy rollouts is robust
410411
to the inherent uncertainty in combat outcomes, rest durations, etc.
411412
"""
412413
if not plan_actions:
413414
return goal.satisfaction(start)
414415

416+
# Derive noise sigma from learned posterior variance when available.
417+
# Wider posteriors (less data) produce more noise, naturally penalising
418+
# plans that depend on uncertain outcomes.
419+
hp_sigma, mana_sigma = self._learned_mc_sigma(ctx)
420+
415421
total_sat = 0.0
416422
for _ in range(MC_ROLLOUTS):
417423
ws = start
418424
for action in plan_actions:
419425
ws = action.apply_effects(ws)
420-
# Stochastic perturbation on continuous resource fields
421-
hp_noise = random.gauss(0, MC_NOISE_SIGMA)
422-
mana_noise = random.gauss(0, MC_NOISE_SIGMA)
426+
hp_noise = random.gauss(0, hp_sigma)
427+
mana_noise = random.gauss(0, mana_sigma)
423428
ws = ws.with_changes(
424429
hp_pct=max(0.0, min(1.0, ws.hp_pct + hp_noise)),
425430
mana_pct=max(0.0, min(1.0, ws.mana_pct + mana_noise)),
@@ -428,6 +433,37 @@ def _mc_evaluate(
428433

429434
return total_sat / MC_ROLLOUTS
430435

436+
@staticmethod
def _learned_mc_sigma(ctx: AgentContext | None) -> tuple[float, float]:
    """Derive Monte Carlo noise sigmas from encounter posterior variance.

    With enough fight history, the posterior variance on HP loss and mana
    cost reflects actual outcome uncertainty: wider posteriors (fewer
    observations) yield a larger sigma, so plans depending on poorly-known
    actions are penalised more heavily.

    Falls back to MC_NOISE_SIGMA when no learned data is available.
    """
    if not ctx or not ctx.fight_history:
        return MC_NOISE_SIGMA, MC_NOISE_SIGMA
    per_entity = ctx.fight_history.get_all_stats()
    if not per_entity:
        return MC_NOISE_SIGMA, MC_NOISE_SIGMA

    # Collect strictly-positive posterior variances across entity types.
    danger_vars = [s.danger_post_var for s in per_entity.values() if s.danger_post_var > 0]
    mana_vars = [s.mana_post_var for s in per_entity.values() if s.mana_post_var > 0]

    def to_sigma(variances: list[float]) -> float:
        # Mean posterior variance -> standard deviation, clamped so the
        # rollout noise stays within a sane band.
        sigma = (sum(variances) / len(variances)) ** 0.5 if variances else MC_NOISE_SIGMA
        return max(0.02, min(0.40, sigma))

    return to_sigma(danger_vars), to_sigma(mana_vars)
466+
431467
# -- Internal: A* Search ----------------------------------------------------
432468

433469
def _search(self, start: PlanWorldState, goal: Goal, ctx: AgentContext | None) -> Plan | None:
@@ -462,9 +498,20 @@ def _search(self, start: PlanWorldState, goal: Goal, ctx: AgentContext | None) -
462498
# Goal test: deterministic satisfaction check
463499
sat = goal.satisfaction(node.state)
464500
if sat >= SATISFACTION_THRESHOLD:
465-
# Monte Carlo robustness check: verify the plan holds
466-
# under stochastic action outcomes.
501+
# Monte Carlo robustness gate: reject plans that don't hold
502+
# under stochastic action outcomes. Uses learned posterior
503+
# variance when available, fixed sigma as fallback.
467504
mc_sat = self._mc_evaluate(node.actions, start, goal, ctx)
505+
if mc_sat < MC_ROBUSTNESS_THRESHOLD:
506+
log.log(
507+
VERBOSE,
508+
"[GOAP] Plan rejected (mc_sat=%.2f < %.2f): %d steps, cost=%.1f",
509+
mc_sat,
510+
MC_ROBUSTNESS_THRESHOLD,
511+
len(node.actions),
512+
node.g_cost,
513+
)
514+
continue # keep searching for a more robust plan
468515
log.log(
469516
VERBOSE,
470517
"[GOAP] Plan found: %d steps, %d nodes, cost=%.1f, sat=%.2f, mc_sat=%.2f",

src/brain/scoring_phases.py

Lines changed: 31 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,29 @@
2121

2222
log = logging.getLogger(__name__)
2323

24+
# Score multiplier applied when GOAP planner suggests a specific action
25+
GOAP_BOOST = 1.5
26+
27+
28+
def _resolve_phase_context(brain: Brain) -> tuple[str, str]:
    """Return (session_phase, goap_hint) derived from brain diagnostics.

    Defaults to ("grinding", "") when no diagnostics context is attached.
    """
    ctx = brain._ctx
    diag = getattr(ctx, "diag", None) if ctx else None
    if not diag:
        return "grinding", ""
    detector = getattr(diag, "phase_detector", None)
    phase = detector.current_phase if detector is not None else "grinding"
    return phase, getattr(diag, "goap_suggestion", "")
38+
39+
40+
def _apply_modifiers(score: float, phase: str, rule_name: str, goap_hint: str) -> float:
    """Scale a raw rule score by the session-phase modifier, then apply the
    GOAP boost when this rule is the planner's suggested action.
    """
    adjusted = score * get_phase_modifier(phase, rule_name)
    is_goap_pick = bool(goap_hint) and rule_name == goap_hint
    if is_goap_pick and adjusted > 0:
        adjusted *= GOAP_BOOST
    return adjusted
46+
2447

2548
def compute_divergence(brain: Brain, state: GameState, now: float, binary_winner: str) -> None:
2649
"""Phase 1: compute scores for all rules, log when score-based
@@ -74,25 +97,13 @@ def select_by_tier(
7497
continue
7598
tier_groups[r.tier].append(r)
7699

77-
# Get session phase for contextual score modifiers
78-
phase = "grinding"
79-
goap_hint = ""
80-
if brain._ctx and hasattr(brain._ctx, "diag") and brain._ctx.diag:
81-
pd = getattr(brain._ctx.diag, "phase_detector", None)
82-
if pd is not None:
83-
phase = pd.current_phase
84-
goap_hint = getattr(brain._ctx.diag, "goap_suggestion", "")
100+
phase, goap_hint = _resolve_phase_context(brain)
85101

86102
for tier in sorted(tier_groups):
87103
scored: list[tuple[float, RuleDef]] = []
88104
for r in tier_groups[tier]:
89105
t0 = time.perf_counter()
90-
s = r.score_fn(state)
91-
# Apply session phase modifier (startup, incident, idle, etc.)
92-
s *= get_phase_modifier(phase, r.name)
93-
# GOAP planner boost: prefer the planned action
94-
if goap_hint and r.name == goap_hint and s > 0:
95-
s *= 1.5 # 50% score boost for GOAP-suggested action
106+
s = _apply_modifiers(r.score_fn(state), phase, r.name, goap_hint)
96107
rule_times[r.name] = (time.perf_counter() - t0) * 1000
97108
rule_eval[r.name] = f"{s:.2f}" if s > 0 else "0"
98109
diag_results.append(f"{r.name}={s:.2f}")
@@ -114,14 +125,7 @@ def select_weighted(
114125
emergency: list[tuple[float, RuleDef]] = []
115126
normal: list[tuple[float, RuleDef]] = []
116127

117-
# Session phase for contextual modifiers
118-
phase = "grinding"
119-
goap_hint = ""
120-
if brain._ctx and hasattr(brain._ctx, "diag") and brain._ctx.diag:
121-
pd = getattr(brain._ctx.diag, "phase_detector", None)
122-
if pd is not None:
123-
phase = pd.current_phase
124-
goap_hint = getattr(brain._ctx.diag, "goap_suggestion", "")
128+
phase, goap_hint = _resolve_phase_context(brain)
125129

126130
for r in brain._rules:
127131
if r.name in brain._cooldowns and now < brain._cooldowns[r.name]:
@@ -131,12 +135,7 @@ def select_weighted(
131135
rule_times[r.name] = 0.0
132136
continue
133137
t0 = time.perf_counter()
134-
s = r.score_fn(state)
135-
# Apply session phase modifier
136-
s *= get_phase_modifier(phase, r.name)
137-
# GOAP planner boost
138-
if goap_hint and r.name == goap_hint and s > 0:
139-
s *= 1.5
138+
s = _apply_modifiers(r.score_fn(state), phase, r.name, goap_hint)
140139
rule_times[r.name] = (time.perf_counter() - t0) * 1000
141140
weighted = r.weight * s
142141
rule_eval[r.name] = f"{weighted:.1f}" if s > 0 else "0"
@@ -171,13 +170,7 @@ def select_with_considerations(
171170
emergency: list[tuple[float, RuleDef]] = []
172171
normal: list[tuple[float, RuleDef]] = []
173172

174-
phase = "grinding"
175-
goap_hint = ""
176-
if brain._ctx and hasattr(brain._ctx, "diag") and brain._ctx.diag:
177-
pd = getattr(brain._ctx.diag, "phase_detector", None)
178-
if pd is not None:
179-
phase = pd.current_phase
180-
goap_hint = getattr(brain._ctx.diag, "goap_suggestion", "")
173+
phase, goap_hint = _resolve_phase_context(brain)
181174

182175
for r in brain._rules:
183176
if r.name in brain._cooldowns and now < brain._cooldowns[r.name]:
@@ -189,12 +182,10 @@ def select_with_considerations(
189182
t0 = time.perf_counter()
190183
# Phase 4: prefer considerations over score_fn when defined
191184
if r.considerations and brain._ctx:
192-
s = score_from_considerations(r.considerations, state, brain._ctx)
185+
raw = score_from_considerations(r.considerations, state, brain._ctx)
193186
else:
194-
s = r.score_fn(state)
195-
s *= get_phase_modifier(phase, r.name)
196-
if goap_hint and r.name == goap_hint and s > 0:
197-
s *= 1.5
187+
raw = r.score_fn(state)
188+
s = _apply_modifiers(raw, phase, r.name, goap_hint)
198189
rule_times[r.name] = (time.perf_counter() - t0) * 1000
199190
weighted = r.weight * s
200191
rule_eval[r.name] = f"{weighted:.1f}" if s > 0 else "0"

tests/test_goap_actions.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,54 @@ def test_defeat_effects_use_class_defaults(self) -> None:
426426
# ---------------------------------------------------------------------------
427427

428428

429+
class TestWanderActionSpawnPrediction:
    """WanderAction.estimate_cost consults spawn predictions when available."""

    def _wander(self) -> WanderAction:
        return WanderAction(name="wander", routine_name="WANDER")

    def _predicting_ctx(self, cells):
        """Build a context whose spawn predictor always returns *cells*."""
        from types import SimpleNamespace

        predictor = SimpleNamespace(best_cells=lambda n, now: cells)
        return SimpleNamespace(spawn_predictor=predictor)

    def test_no_ctx_returns_default(self) -> None:
        assert self._wander().estimate_cost(None) == 30.0

    def test_no_spawn_predictor_returns_default(self) -> None:
        from types import SimpleNamespace

        ctx = SimpleNamespace(spawn_predictor=None)
        assert self._wander().estimate_cost(ctx) == 30.0

    def test_imminent_respawn_reduces_cost(self) -> None:
        from core.types import Point

        ctx = self._predicting_ctx([(Point(100, 100, 0), 8.0)])
        # An 8 s predicted wait lies between the 5 s floor and the default,
        # so it becomes the cost directly.
        assert self._wander().estimate_cost(ctx) == 8.0

    def test_very_short_wait_clamps_to_minimum(self) -> None:
        from core.types import Point

        ctx = self._predicting_ctx([(Point(50, 50, 0), 2.0)])
        # 2 s predicted wait is clamped up to the 5 s travel-time floor.
        assert self._wander().estimate_cost(ctx) == 5.0

    def test_empty_predictions_returns_default(self) -> None:
        assert self._wander().estimate_cost(self._predicting_ctx([])) == 30.0
475+
476+
429477
class TestAcquireActionCost:
430478
def test_estimate_cost(self) -> None:
431479
assert _acquire().estimate_cost(None) == 5.0

0 commit comments

Comments
 (0)