Skip to content

Commit c159b2d

Browse files
committed
feat: select wave target by fewest outputs, tiebreak by success rate
Change weakest_stage to prioritize the stage with the fewest distinct outputs (the bottleneck), breaking ties first by lowest success rate (distinct_outputs / attempts) and then by earliest stage index. This avoids sunk-cost traps in which a hard stage with a low success rate keeps being targeted ahead of underexplored stages. Also print immediate feedback on Ctrl-C before the clean shutdown.
1 parent 1a41d28 commit c159b2d

3 files changed

Lines changed: 38 additions & 15 deletions

File tree

vxsort/smallsort/codegen/src/transition_table.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -203,32 +203,33 @@ def stage_stats(self, stage: int) -> dict:
203203
}
204204

205205
def weakest_stage(self, exclude_exhausted: set[int] | None = None) -> int | None:
206-
"""Return the index of the stage with the lowest success rate.
206+
"""Return the stage most in need of exploration.
207+
208+
Selection criteria (in priority order):
209+
1. Fewest distinct outputs (the bottleneck stage)
210+
2. Lowest success rate (distinct_outputs / attempts) as tiebreaker
211+
3. Earliest stage index as final tiebreaker
207212
208213
Parameters
209214
----------
210215
exclude_exhausted :
211216
Stage indices to skip. If all stages are excluded (or the
212217
table is empty), returns ``None``.
213-
214-
Ties are broken by earliest stage index. Stages with zero attempts
215-
are treated as having a success rate of 0.0.
216218
"""
217219
excluded = exclude_exhausted or set()
218220
best_idx: int | None = None
219-
best_rate = float("inf")
221+
best_key = (float("inf"), float("inf"))
220222

221223
for i, sd in enumerate(self.stages):
222224
if i in excluded:
223225
continue
224226

225-
if sd.attempts > 0:
226-
rate = len(sd.unique_outputs) / sd.attempts
227-
else:
228-
rate = 0.0
227+
n_outputs = len(sd.unique_outputs)
228+
rate = n_outputs / sd.attempts if sd.attempts > 0 else 0.0
229+
key = (n_outputs, rate)
229230

230-
if rate < best_rate:
231-
best_rate = rate
231+
if key < best_key:
232+
best_key = key
232233
best_idx = i
233234

234235
return best_idx

vxsort/smallsort/codegen/src/wave_engine.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,12 @@ def run(self) -> dict:
837837

838838
def _sigint_handler(_signum, _frame):
839839
self._interrupted = True
840+
import sys
841+
842+
print(
843+
"\nCtrl-C detected — finishing current work and saving checkpoint...",
844+
file=sys.stderr,
845+
)
840846

841847
signal.signal(signal.SIGINT, _sigint_handler)
842848

vxsort/smallsort/codegen/tests/test_transition_table.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -386,23 +386,23 @@ def test_all_empty_returns_first(self):
386386
tt = TransitionTable(num_stages=3)
387387
assert tt.weakest_stage() == 0
388388

389-
def test_returns_lowest_success_rate(self):
389+
def test_returns_fewest_outputs(self):
390390
tt = TransitionTable(num_stages=3)
391391
inp = _vs([0, 1, 2, 3], [4, 5, 6, 7])
392392

393-
# Stage 0: 2 outputs from 4 attempts => 50%
393+
# Stage 0: 2 outputs from 4 attempts
394394
tt.record_attempt(0, count=4)
395395
out0a = _vs([0, 1, 2, 3], [4, 5, 6, 7])
396396
out0b = _vs([1, 0, 3, 2], [5, 4, 7, 6])
397397
tt.add_transition(0, inp, out0a, _make_gadget(top_args={"ctrl": 1}))
398398
tt.add_transition(0, inp, out0b, _make_gadget(top_args={"ctrl": 2}))
399399

400-
# Stage 1: 1 output from 4 attempts => 25%
400+
# Stage 1: 1 output from 4 attempts — fewest outputs, selected
401401
tt.record_attempt(1, count=4)
402402
out1a = _vs([0, 1, 2, 3], [4, 5, 6, 7])
403403
tt.add_transition(1, inp, out1a, _make_gadget(top_args={"ctrl": 3}))
404404

405-
# Stage 2: 3 outputs from 4 attempts => 75%
405+
# Stage 2: 3 outputs from 4 attempts
406406
tt.record_attempt(2, count=4)
407407
out2a = _vs([0, 1, 2, 3], [4, 5, 6, 7])
408408
out2b = _vs([1, 0, 3, 2], [5, 4, 7, 6])
@@ -413,6 +413,22 @@ def test_returns_lowest_success_rate(self):
413413

414414
assert tt.weakest_stage() == 1
415415

416+
def test_tiebreak_by_success_rate(self):
417+
"""When two stages have the same output count, pick the one struggling more."""
418+
tt = TransitionTable(num_stages=2)
419+
inp = _vs([0, 1, 2, 3], [4, 5, 6, 7])
420+
out_a = _vs([1, 0, 3, 2], [5, 4, 7, 6])
421+
422+
# Stage 0: 1 output from 10 attempts (10% rate)
423+
tt.record_attempt(0, count=10)
424+
tt.add_transition(0, inp, out_a, _make_gadget(top_args={"ctrl": 1}))
425+
426+
# Stage 1: 1 output from 100 attempts (1% rate) — same outputs, worse rate
427+
tt.record_attempt(1, count=100)
428+
tt.add_transition(1, inp, out_a, _make_gadget(top_args={"ctrl": 2}))
429+
430+
assert tt.weakest_stage() == 1 # same outputs, lower success rate
431+
416432
def test_tie_breaks_by_earliest_stage(self):
417433
tt = TransitionTable(num_stages=3)
418434
inp = _vs([0, 1, 2, 3], [4, 5, 6, 7])

0 commit comments

Comments
 (0)