From 473ede5094d7864d1b342cb99a99e5f6242f8a76 Mon Sep 17 00:00:00 2001 From: John Pearson Date: Wed, 27 May 2026 11:28:56 -0400 Subject: [PATCH 01/15] Add Arneodo-2021 syrinx-ODE parameterization with autonomous integration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an alternative to the full-polynomial Ouroboros that parameterizes the second derivative with the Arneodo et al. 2021 biomechanical syrinx ODE (ẍ = γ²α + γ²βx + γ²x² − γ²x³ − γδẋ − γxẋ − γx²ẋ). The network outputs the control time series α(t), β(t), δ(t) and learns γ as a single positive scalar; the nonlinear coefficients are fixed to γ as in the paper. The δ term is a paper-consistent linear-damping extension (zero-init, so the model starts at the strict paper form). Because the RHS is an explicit function of state, a trained model can be integrated fully autonomously, feeding the generated (x, ẋ) back in. - model/model.py: ArneodoOuroboros (α/β/δ encoders + scalar γ, rescaled-time RHS, get_funcs); α/β/δ heads zero-initialized for conditioning. - train/model_cv.py: train_arneodo (single fit, no λ-CV, reg_weights=False). - train/train_model.py: parameterization="poly"|"arneodo" dispatch. - train/train.py: define total_loss before the reg_weights block so the unregularized path runs; save_model/load_model record & dispatch on a "parameterization" tag (poly checkpoints unchanged). - train/eval.py: integrate_model_autonomous (closed-loop, state-feedback RHS). - examples/: arneodo_autonomous.py (data-gen → train → autonomous rollout), plot_autonomous.py (waveforms + spectrograms), plot_drives.py (learned vs data-generation control inputs, affine maps inverted). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .gitignore | 3 + examples/arneodo_autonomous.py | 167 ++++++++++++++++++++++ examples/plot_autonomous.py | 137 ++++++++++++++++++ examples/plot_drives.py | 163 ++++++++++++++++++++++ model/model.py | 245 +++++++++++++++++++++++++++++++++ train/eval.py | 111 +++++++++++++++ train/model_cv.py | 112 ++++++++++++++- train/train.py | 58 +++++--- train/train_model.py | 49 +++++-- 9 files changed, 1009 insertions(+), 36 deletions(-) create mode 100644 examples/arneodo_autonomous.py create mode 100644 examples/plot_autonomous.py create mode 100644 examples/plot_drives.py diff --git a/.gitignore b/.gitignore index a3c3180..43eb49e 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,6 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# example run outputs (generated data, checkpoints, plots) +arneodo_run/ diff --git a/examples/arneodo_autonomous.py b/examples/arneodo_autonomous.py new file mode 100644 index 0000000..d601df5 --- /dev/null +++ b/examples/arneodo_autonomous.py @@ -0,0 +1,167 @@ +""" +End-to-end demo for the Arneodo-2021 parameterization + autonomous integration. + +Pipeline +-------- +1. Generate a small synthetic dataset from the Mindlin model (`generate_vocal_dataset`). +2. Train an `ArneodoOuroboros` model (the biomechanical syrinx ODE parameterization), + exactly as the polynomial model is trained -- teacher-forced one-step prediction of + the second derivative -- via `train_model(..., parameterization="arneodo")`. +3. Run *fully autonomous* integration on a held-out segment: the ODE is integrated while + feeding the generated state (x, x') back into the right-hand side, with only the + control time series alpha(t), beta(t) (produced once by the encoder) and the initial + condition coming from data. Plot generated vs. 
true waveform and report a simple + pitch / correlation comparison. + +The defaults below are a fast smoke test (a handful of vocalizations, a few epochs, a +short integration window). Scale up `n_vocs`, `n_epochs`, `batch_size`, and +`n_integrate` for a real run. + +Run from the repo root: + python -m examples.arneodo_autonomous --out-dir ./arneodo_run +(needs a working CUDA torch build -- the model is hardwired to CUDA.) +""" + +import argparse +import os + +# headless plotting; model_vis sets text.usetex=True at import, which breaks on +# machines without LaTeX -- force it off after matplotlib is configured. +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +import numpy as np +import torch +import jax.random as jr + +from data.generate_data_gabo import generate_vocal_dataset +from data.load_data import get_segmented_audio +from train.train_model import train_model +from train.eval import integrate_model_autonomous + +plt.rcParams["text.usetex"] = False + + +def peak_freq(x: np.ndarray, dt: float) -> float: + """dominant (non-DC) frequency of a 1-D signal, in Hz.""" + x = x - np.mean(x) + spec = np.abs(np.fft.rfft(x)) + freqs = np.fft.rfftfreq(len(x), d=dt) + spec[0] = 0.0 + return float(freqs[np.argmax(spec)]) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--out-dir", default="./arneodo_run", help="output directory") + parser.add_argument("--n-vocs", type=int, default=10, help="synthetic vocalizations to generate") + parser.add_argument("--n-epochs", type=int, default=3, help="training epochs") + parser.add_argument("--batch-size", type=int, default=4) + parser.add_argument("--context-len", type=float, default=0.25, help="chunk length (s)") + parser.add_argument("--n-integrate", type=int, default=4000, help="samples to autonomously integrate") + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--method", default="rk4", help="ODE integration method") + args 
= parser.parse_args() + + out_dir = os.path.abspath(args.out_dir) + data_dir = os.path.join(out_dir, "gabo_data") + model_dir = os.path.join(out_dir, "model") + os.makedirs(out_dir, exist_ok=True) + os.makedirs(model_dir, exist_ok=True) + + # 1. generate synthetic Mindlin data (audio .wav + onset/offset .txt) --------------- + if not os.path.isdir(data_dir) or len( + [f for f in os.listdir(data_dir) if f.endswith(".wav")] + ) < args.n_vocs: + print(f"Generating {args.n_vocs} synthetic vocalizations into {data_dir} ...") + generate_vocal_dataset( + jr.PRNGKey(args.seed), + n_vocs=args.n_vocs, + audio_loc=data_dir, + seg_loc=data_dir, + func_loc=data_dir, + ) + else: + print(f"Reusing existing data in {data_dir}") + + # 2. train the Arneodo model -------------------------------------------------------- + print("Training ArneodoOuroboros ...") + model = train_model( + audio_dirs=[data_dir], + seg_dirs=[data_dir], + model_dir=model_dir, + max_vocs=max(args.n_vocs * 4, 50), + context_len=args.context_len, + seed=args.seed, + batch_size=args.batch_size, + n_epochs=args.n_epochs, + save_freq=max(args.n_epochs // 2, 1), + parameterization="arneodo", + ) + model.eval() + + # 3. 
autonomous integration on a held-out segment ----------------------------------- + chunks, sr = get_segmented_audio( + data_dir, + data_dir, + max_vocs=args.n_vocs * 4, + context_len=args.context_len, + seed=args.seed + 1, # different seed -> different shuffle than training selection + training=True, + extend=True, + shuffle_order=True, + ) + dt = 1 / sr + segment = np.asarray(chunks[0]).squeeze() + n = min(args.n_integrate, len(segment)) + segment = segment[:n] + + print(f"\nAutonomously integrating {n} samples ({n * dt * 1e3:.1f} ms) ...") + x_gen = integrate_model_autonomous( + model, segment, dt, method=args.method, detrend=True, verbose=True + ) + print() # newline after the progress \r + + # detrend the reference the same way for a fair visual comparison + from train.eval import correct + + true_detrended = correct(segment.astype(np.float64)) + + f_true = peak_freq(true_detrended, dt) + f_gen = peak_freq(x_gen, dt) + corr = float(np.corrcoef(true_detrended[: len(x_gen)], x_gen)[0, 1]) + print(f"peak frequency true: {f_true:7.1f} Hz generated: {f_gen:7.1f} Hz") + print(f"waveform correlation (true vs generated): {corr:+.3f}") + print(f"generated range: [{x_gen.min():.3g}, {x_gen.max():.3g}]") + + # 4. 
plot --------------------------------------------------------------------------- + t_ms = np.arange(len(x_gen)) * dt * 1e3 + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 6), sharex=False) + ax1.plot(t_ms, true_detrended[: len(x_gen)], label="true (detrended)", color="tab:orange") + ax1.plot(t_ms, x_gen, label="autonomous", color="tab:blue", alpha=0.8) + ax1.set_title( + f"Autonomous integration | peak freq true {f_true:.0f} Hz / gen {f_gen:.0f} Hz | corr {corr:+.2f}" + ) + ax1.set_xlabel("time (ms)") + ax1.set_ylabel("a.u.") + ax1.legend() + + zoom = slice(0, min(len(x_gen), int(round(0.02 / dt)))) # first 20 ms + ax2.plot(t_ms[zoom], true_detrended[: len(x_gen)][zoom], color="tab:orange", label="true") + ax2.plot(t_ms[zoom], x_gen[zoom], color="tab:blue", alpha=0.8, label="autonomous") + ax2.set_title("first 20 ms (zoom)") + ax2.set_xlabel("time (ms)") + ax2.set_ylabel("a.u.") + ax2.legend() + + fig.tight_layout() + plot_path = os.path.join(out_dir, "autonomous_vs_true.svg") + fig.savefig(plot_path) + plt.close(fig) + print(f"\nSaved comparison plot to {plot_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/plot_autonomous.py b/examples/plot_autonomous.py new file mode 100644 index 0000000..5c46e86 --- /dev/null +++ b/examples/plot_autonomous.py @@ -0,0 +1,137 @@ +""" +Plot target vs. autonomously-generated waveforms and their spectrograms for a trained +`ArneodoOuroboros` model. 
+ +Loads the most recent checkpoint under OUT_DIR/model/arneodo, picks a held-out segment from the gabo +dataset, runs `integrate_model_autonomous`, and writes a 2x2 figure: + row 0: waveforms (true | autonomous) + row 1: spectrograms (true | autonomous) + +Run from the repo root: + python -m examples.plot_autonomous --out-dir ./arneodo_run +""" + +import argparse +import os + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +import numpy as np +import torch +from scipy.signal import spectrogram + +from data.load_data import get_segmented_audio +from train.train import load_model +from train.eval import integrate_model_autonomous, correct + +plt.rcParams["text.usetex"] = False + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--out-dir", default="./arneodo_run") + parser.add_argument("--n-integrate", type=int, default=11000, help="samples to integrate") + parser.add_argument("--pre-onset-ms", type=float, default=10.0, + help="start the rollout this many ms before the vocalization onset") + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--fmax", type=float, default=10000.0, help="max spectrogram freq (Hz)") + parser.add_argument("--method", default="rk4") + args = parser.parse_args() + + out_dir = os.path.abspath(args.out_dir) + data_dir = os.path.join(out_dir, "gabo_data") + model_dir = os.path.join(out_dir, "model", "arneodo") + + # load the trained model ----------------------------------------------------------- + model, _, _, epoch = load_model(model_dir) + model.eval() + print(f"loaded {type(model).__name__} (epoch {epoch})") + + # held-out vocalization, windowed to start `pre_onset_ms` BEFORE the first onset, so the + # rollout is seeded from (near) rest just before vocalization and must spin the + # oscillation up itself. analysis mode returns aud[onset - padding : offset], so passing + # padding = pre_onset gives exactly that window (scale-consistent with training loading). 
+ pre_onset_s = args.pre_onset_ms / 1e3 + chunks, sr = get_segmented_audio( + data_dir, + data_dir, + max_vocs=40, + seed=args.seed + 1, + training=False, + padding=pre_onset_s, + shuffle_order=True, + ) + dt = 1 / sr + segment = np.asarray(chunks[0]).squeeze() + n = min(args.n_integrate, len(segment)) + segment = segment[:n] + + print( + f"rollout starts {args.pre_onset_ms:.0f} ms before onset; " + f"autonomously integrating {n} samples ({n * dt * 1e3:.1f} ms) ..." + ) + x_gen = integrate_model_autonomous( + model, segment, dt, method=args.method, detrend=True, verbose=True + ) + print() + + # detrend the reference the same way the autonomous output is detrended + true = correct(segment.astype(np.float64))[: len(x_gen)] + t_ms = np.arange(len(x_gen)) * dt * 1e3 + + def spec(x): + nperseg = min(256, len(x)) + f, tt, Sxx = spectrogram( + x, fs=sr, nperseg=nperseg, noverlap=int(nperseg * 0.9) + ) + Sxx = 10 * np.log10(Sxx + 1e-12) + return f, tt * 1e3, Sxx + + f_t, tt_t, S_t = spec(true) + f_g, tt_g, S_g = spec(x_gen) + vmin = min(S_t.max() - 80, S_g.max() - 80) # ~80 dB dynamic range + vmax = max(S_t.max(), S_g.max()) + + # figure ----------------------------------------------------------------------------- + fig, axes = plt.subplots(2, 2, figsize=(14, 7)) + + # shared y-scale across the two waveform panels, so amplitudes are directly comparable + axes[0, 0].plot(t_ms, true, color="tab:orange", lw=0.8) + axes[0, 0].set_title("Target waveform (detrended)") + axes[0, 1].plot(t_ms, x_gen, color="tab:blue", lw=0.8) + axes[0, 1].set_title("Autonomous waveform") + ymax = max(np.abs(true).max(), np.abs(x_gen).max()) * 1.05 + for ax in axes[0]: + ax.set_xlabel("time (ms)") + ax.set_ylabel("a.u.") + ax.set_ylim(-ymax, ymax) + + for ax, (f, tt, S), title in [ + (axes[1, 0], (f_t, tt_t, S_t), "Target spectrogram"), + (axes[1, 1], (f_g, tt_g, S_g), "Autonomous spectrogram"), + ]: + pcm = ax.pcolormesh(tt, f, S, shading="auto", vmin=vmin, vmax=vmax, cmap="magma") + 
ax.set_ylim(0, args.fmax) + ax.set_title(title) + ax.set_xlabel("time (ms)") + ax.set_ylabel("frequency (Hz)") + fig.colorbar(pcm, ax=ax, label="power (dB)") + + fig.suptitle( + f"Arneodo autonomous integration | start {args.pre_onset_ms:.0f} ms pre-onset | " + f"{n * dt * 1e3:.0f} ms | sr {sr} Hz", + y=1.00, + ) + fig.tight_layout() + + out_path = os.path.join(out_dir, "waveforms_and_spectrograms.png") + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved figure to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/plot_drives.py b/examples/plot_drives.py new file mode 100644 index 0000000..208f6d0 --- /dev/null +++ b/examples/plot_drives.py @@ -0,0 +1,163 @@ +""" +Compare a trained `ArneodoOuroboros`'s learned drives against the control inputs used to +generate the Mindlin/gabo data -- by unpacking the algebra between the two ODEs. + +The two right-hand sides, grouped by monomial in (x, xdot): + + generator ẍ = -(delta*D) - (eps1+eps2*K) x + (beta1+beta2*P) xdot - C x^2 xdot + model ẍ = g_phys^2 alpha + g_phys^2 beta x - g_phys delta_m xdot + + g_phys^2 x^2 - g_phys^2 x^3 - g_phys x xdot - g_phys x^2 xdot + +(g_phys = gamma / tau is the physical time-scaling; gamma is the model's rescaled scalar.) + +Equating coefficients of like monomials, the model drives map to the true control inputs by +INVERTING the generator's affine coefficient maps: + + D_implied = -(g_phys^2 alpha) / delta + K_implied = (-(g_phys^2 beta) - eps1) / eps2 + P_implied = ((-g_phys delta_m) - beta1) / beta2 + +so a faithful model should have D_implied ~ D, K_implied ~ K, P_implied ~ P. This is NOT a +1-to-1 identity of the raw drives -- it's the affine relation between matching ODE +coefficients. 
Two reasons the match is still imperfect: (1) the model has x^2, x^3, x*xdot +terms (all tied to one gamma) and an x^2*xdot coeff -g_phys != -C that the generator lacks, +so it is structurally a different ODE; (2) per timestep ẍ is one equation in three unknown +drives, so the split is unidentifiable unless the drives are slow -- they are not (they +oscillate at the carrier), so we compare their low-pass (slow) component. + +Run from the repo root: + python -m examples.plot_drives --out-dir ./arneodo_run +""" + +import argparse +import os + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +import numpy as np +import torch +from scipy.io import wavfile +from scipy.signal import butter, sosfiltfilt + +from utils import deriv_approx_dy +from train.train import load_model +from data.generate_data_gabo import eps1, eps2, beta1, beta2, C, delta as delta_gen + +plt.rcParams["text.usetex"] = False + + +def lowpass(x, cutoff_hz, fs, order=4): + """zero-phase Butterworth low-pass; extracts the slow (control-rate) component.""" + sos = butter(order, cutoff_hz, btype="low", fs=fs, output="sos") + return sosfiltfilt(sos, x) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--out-dir", default="./arneodo_run") + parser.add_argument("--voc", type=int, default=0, help="which gabo_artificial_ file") + parser.add_argument("--vocalization", type=int, default=0, help="which onset/offset pair") + parser.add_argument("--pad-ms", type=float, default=20.0, help="window padding each side (ms)") + parser.add_argument("--cutoff-hz", type=float, default=50.0, + help="low-pass cutoff for the learned drives' slow component") + args = parser.parse_args() + + out_dir = os.path.abspath(args.out_dir) + data_dir = os.path.join(out_dir, "gabo_data") + model_dir = os.path.join(out_dir, "model", "arneodo") + tag = f"gabo_artificial_{args.voc}" + + model, _, _, epoch = load_model(model_dir) + model.eval() + g_phys = 
float(model.gamma) / model.tau + print(f"loaded {type(model).__name__} (epoch {epoch}); g_phys = gamma/tau = {g_phys:.1f}") + print( + "model's SPURIOUS / mismatched constant coefficients (generator has 0 / -C):\n" + f" x^2 coeff = +g_phys^2 = {g_phys**2:+.3e} (generator: 0)\n" + f" x^3 coeff = -g_phys^2 = {-g_phys**2:+.3e} (generator: 0)\n" + f" x*xd coeff = -g_phys = {-g_phys:+.3e} (generator: 0)\n" + f" x^2xd coeff = -g_phys = {-g_phys:+.3e} (generator -C = {-C:+.3e})" + ) + + # data: audio (x), the three control inputs, and the vocalization interval ----------- + sr, audio_full = wavfile.read(os.path.join(data_dir, f"{tag}.wav")) + audio_full = audio_full.astype(np.float64) + pkd = np.loadtxt(os.path.join(data_dir, f"{tag}_PKD.txt")) # cols: P, K, D + P_full, K_full, D_full = pkd[:, 0], pkd[:, 1], pkd[:, 2] + onoffs = np.atleast_2d(np.loadtxt(os.path.join(data_dir, f"{tag}.txt"))) + on_s, off_s = onoffs[args.vocalization] + dt = 1 / sr + + pad = int(round(args.pad_ms / 1e3 * sr)) + on_i, off_i = int(round(on_s * sr)), int(round(off_s * sr)) + a, b = max(0, on_i - pad), min(len(audio_full), off_i + pad) + t_s = np.arange(a, b) * dt + audio = audio_full[a:b] + P, K, D = P_full[a:b], K_full[a:b], D_full[a:b] + + # learned drives over the same window ------------------------------------------------ + audio_t = torch.from_numpy(audio[None, :, None]).to(torch.float32).to("cuda") + dy_t = torch.from_numpy(deriv_approx_dy(audio[None, :, None])).to(torch.float32).to("cuda") + with torch.no_grad(): + alpha, beta, delta_m, _ = model.get_funcs(audio_t, dy_t, dt) + alpha = alpha.detach().cpu().numpy().squeeze() + beta = beta.detach().cpu().numpy().squeeze() + delta_m = delta_m.detach().cpu().numpy().squeeze() + + # unpack the algebra: model-implied control inputs (true control-input units) -------- + g2 = g_phys**2 + D_impl = -(g2 * alpha) / delta_gen + K_impl = (-(g2 * beta) - eps1) / eps2 + P_impl = ((-g_phys * delta_m) - beta1) / beta2 + + # the implied inputs oscillate 
at the carrier (drives are unconstrained per-sample); + # compare their slow component to the (slow) true control inputs + D_impl_lp = lowpass(D_impl, args.cutoff_hz, sr) + K_impl_lp = lowpass(K_impl, args.cutoff_hz, sr) + P_impl_lp = lowpass(P_impl, args.cutoff_hz, sr) + + voc = slice(on_i - a, off_i - a) + + def corr(m, d): + m, d = m[voc], d[voc] + if m.std() < 1e-12 or d.std() < 1e-12: + return float("nan") + return float(np.corrcoef(m, d)[0, 1]) + + panels = [ + ("D (dynamics)", D, D_impl, D_impl_lp, "constant forcing (alpha)"), + ("K (tension)", K, K_impl, K_impl_lp, "restoring / frequency (beta)"), + ("P (pressure)", P, P_impl, P_impl_lp, "linear damping (delta)"), + ] + + fig, axes = plt.subplots(3, 1, figsize=(12, 9), sharex=True) + for ax, (name, true, impl_raw, impl_lp, slot) in zip(axes, panels): + r = corr(impl_lp, true) + ax.plot(t_s, impl_raw, color="tab:blue", lw=0.5, alpha=0.12) + ax.plot(t_s, true, color="tab:orange", lw=1.8, label=f"true {name}") + ax.plot(t_s, impl_lp, color="tab:blue", lw=1.6, + label=f"model-implied {name} (<{args.cutoff_hz:.0f} Hz)") + ax.axvline(on_s, color="0.6", ls="--", lw=0.8) + ax.axvline(off_s, color="0.6", ls="--", lw=0.8) + ax.set_ylabel(name) + ax.set_title(f"{slot}: model-implied vs true {name} (r over vocalization = {r:+.2f})") + ax.legend(loc="upper right", fontsize=8) + axes[-1].set_xlabel("time (s)") + + fig.suptitle( + f"Model-implied vs true control inputs (affine maps inverted) | {tag} | " + f"vocalization {args.vocalization}", + y=1.00, + ) + fig.tight_layout() + out_path = os.path.join(out_dir, "learned_vs_true_drives.png") + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved figure to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/model/model.py b/model/model.py index c4d1d43..918b85f 100644 --- a/model/model.py +++ b/model/model.py @@ -1,5 +1,8 @@ +import math + import torch from torch import nn +import torch.nn.functional as F from mambapy.mamba 
import Mamba, MambaConfig from model.model_utils import smooth from typing import Tuple, Union @@ -377,3 +380,245 @@ def funcs_by_step( gammas = smooth(gammas[None, :, None], smooth_len).squeeze() return yhat, omegas, gammas, kernel, weights + + +class ArneodoOuroboros(nn.Module): + """ + Variant of `Ouroboros` that parameterizes the second derivative using the + biomechanical syrinx ODE of Arneodo et al. 2021 (Current Biology) instead of a + full polynomial kernel. In physical time the (paper) model reads: + + dx/dt = y + dy/dt = gamma^2 alpha + gamma^2 beta x + gamma^2 x^2 + - gamma^2 x^3 - gamma x y - gamma x^2 y + + where x is the labial displacement, gamma is a (constant) time-scaling factor, + and alpha, beta are the two bird-controlled parameters (sub-syringeal pressure + and syringeal-muscle tension). + + This implementation additionally includes a *linear* damping term -gamma * delta * y + (delta(t) a third control series), which the strict paper polynomial lacks. It lets the + model represent sources with a linear damping/anti-damping term -- e.g. the Mindlin/gabo + data generator's `B * xdot` (van-der-Pol-style) term. delta can be either sign; delta < 0 + is anti-damping (pumps energy, sustaining oscillation). Three parallel Mamba encoders + output alpha(t), beta(t), delta(t); gamma is a single learned positive scalar + (`self.gamma`), as in the paper where gamma is constant. delta's head is zero-initialized, + so the model *starts* at delta == 0 (the strict paper form) and only learns linear + damping if it reduces the loss -- the extended form is a strict superset of the paper one. + + As in `Ouroboros`, the model works in rescaled time s = t/tau: `forward` scales the + input first derivative to x' = dx/ds = (tau/dt) * dxdt, and the training target is + d2x/ds2 = tau^2 * d2x/dt2. 
Substituting s = t/tau into the ODE above (the linear term + -gamma*delta*y rescales to -g*delta*x', matching the other damping terms) leaves the + form identical with a rescaled scalar g = tau * gamma, so this is a drop-in replacement + for the polynomial RHS and the existing training target scaling is unchanged: + + d2x/ds2 = g^2 alpha + g^2 beta x + g^2 x^2 - g^2 x^3 - g (delta + x + x^2) x' + + `self.gamma` is exactly this rescaled g. (With tau = dt, g = dt * gamma_phys, an O(1) + learnable scalar.) + """ + + def __init__( + self, + d_data: int, + n_layers: int = 2, + d_state: int = 16, + d_conv: int = 4, + expand_factor: int = 1, + device: str = "cuda", + tau: float = 1 / 10000, + smooth_len: float = 0.001, + gamma_init: float = 1.0, + ): + + super().__init__() + + self.device = device + + # we stack x and (scaled) dx on the time dimension, so d_model = 2 * d_data + alphaConfig = MambaConfig( + d_model=2 * d_data, + n_layers=n_layers, + d_state=d_state, + d_conv=d_conv, + expand_factor=expand_factor, + ) + betaConfig = MambaConfig( + d_model=2 * d_data, + n_layers=n_layers, + d_state=d_state, + d_conv=d_conv, + expand_factor=expand_factor, + ) + deltaConfig = MambaConfig( + d_model=2 * d_data, + n_layers=n_layers, + d_state=d_state, + d_conv=d_conv, + expand_factor=expand_factor, + ) + + self.alpha_mamba = Mamba(alphaConfig).to(device) + self.beta_mamba = Mamba(betaConfig).to(device) + self.delta_mamba = Mamba(deltaConfig).to(device) + + self.alpha_net = nn.Linear( + in_features=2 * d_data, out_features=d_data, device=device + ) # output unconstrained: alpha can be either sign + self.beta_net = nn.Linear( + in_features=2 * d_data, out_features=d_data, device=device + ) # output unconstrained: beta can be either sign + self.delta_net = nn.Linear( + in_features=2 * d_data, out_features=d_data, device=device + ) # output unconstrained: delta can be either sign (negative => anti-damping) + + # Zero-initialize the control heads so the model starts at alpha = beta = 
delta = 0, + # i.e. yhat = g^2 (x^2 - x^3) - g (x + x^2) x' (bounded, structural terms only -- + # exactly the strict-paper-form dynamics). Without this the heads emit O(1) forcing + # that is ~10-100x the target second-derivative scale, which makes the loss explode + # and conditioning poor at the start of training. Because delta is also zero-init, + # the extended (linear-damping) form is a strict superset of the paper form: the + # model starts strict and learns linear damping only if it reduces the loss. + for net in (self.alpha_net, self.beta_net, self.delta_net): + nn.init.zeros_(net.weight) + nn.init.zeros_(net.bias) + + # gamma (rescaled time-scaling factor) is a single positive scalar, learned as + # softplus(log_gamma). Initialize log_gamma so that softplus(log_gamma) ~ gamma_init. + inv_softplus = math.log(math.expm1(gamma_init)) # inverse of softplus + self.log_gamma = nn.Parameter(torch.tensor(float(inv_softplus), device=device)) + + self.tau = tau + self.smooth_len = smooth_len + self.names = [r"$\alpha$", r"$\beta$", r"$\delta$", r"$\gamma$"] + + @property + def gamma(self) -> torch.FloatTensor: + """rescaled, strictly-positive time-scaling scalar g = softplus(log_gamma).""" + return F.softplus(self.log_gamma) + + def _encode( + self, + x: torch.FloatTensor, + dxdt: torch.FloatTensor, + dt: float, + smoothing: bool = False, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """ + shared input prep + encoders. Returns (alpha, beta, delta, z) where z = [x, x'] is + the state in rescaled time (x' = dx/ds = (tau/dt) * dxdt). + """ + + # scale first derivative to rescaled time: x' = dx/ds = (tau/dt) * dxdt. + # use out-of-place ops so we never mutate the caller's tensor. 
+ dxdt = dxdt * (self.tau / dt) + + z = torch.cat([x, dxdt], dim=-1) + L = z.shape[1] + + # feed the state and its time-reversed copy to each encoder, then keep the + # forward-time half (matches Ouroboros.forward) + x_in = torch.cat([torch.flip(z, [1]), z], dim=1) + + alphaControl = self.alpha_mamba(x_in)[:, L:, :] + betaControl = self.beta_mamba(x_in)[:, L:, :] + deltaControl = self.delta_mamba(x_in)[:, L:, :] + + alpha = self.alpha_net(alphaControl) + beta = self.beta_net(betaControl) + delta = self.delta_net(deltaControl) + + if smoothing: + smooth_len = int(round(self.smooth_len / dt)) + alpha = smooth(alpha, smooth_len) + beta = smooth(beta, smooth_len) + delta = smooth(delta, smooth_len) + + return alpha, beta, delta, z + + def _rhs( + self, + alpha: torch.FloatTensor, + beta: torch.FloatTensor, + delta: torch.FloatTensor, + z: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + the Arneodo RHS in rescaled time, given alpha, beta, delta and state z = [x, x']. + returns the predicted second derivative d2x/ds2. The damping is + -g (delta + x + x^2) x': the (delta) term is the linear-damping extension, the + (x + x^2) terms are the strict paper nonlinear damping. + """ + g = self.gamma + x = z[:, :, :1] + xp = z[:, :, 1:] + g2 = g * g + return ( + g2 * alpha + + g2 * beta * x + + g2 * x**2 + - g2 * x**3 + - g * delta * xp + - g * x * xp + - g * x**2 * xp + ) + + def forward( + self, + x: torch.FloatTensor, + dxdt: torch.FloatTensor, + dt: float, + smoothing: bool = False, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """ + predicts the second derivative at time t via the Arneodo ODE. + + inputs + ------ + - x: cleaned audio segment (B, L, d_data) + - dxdt: first derivative estimate (unscaled; scaled internally by tau/dt) + - dt: sample interval + - smoothing: whether to smooth alpha, beta. We do not, but you can + + outputs + ------ + - yhat: predicted second derivative, scaled by tau^2 (i.e. d2x/ds2) + - alpha: the alpha(t) time series. 
Returned in the second slot (where the + polynomial model returns kernel weights) so the train loop signature + matches; it is NOT regularized (train with reg_weights=False). + """ + + alpha, beta, delta, z = self._encode(x, dxdt, dt, smoothing=smoothing) + yhat = self._rhs(alpha, beta, delta, z) + return yhat, alpha + + def get_funcs( + self, + x: torch.FloatTensor, + dxdt: torch.FloatTensor, + dt: float, + smoothing: bool = False, + ) -> Tuple[ + torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor + ]: + """ + given data, returns the learned model functions for the Arneodo parameterization. + + inputs + ------ + - x: audio segment (B, L, d_data) + - dxdt: first derivative estimate (unscaled) + - dt: sampling timestep + - smoothing: whether to smooth alpha, beta, delta + + returns + ------ + - alpha: alpha(t) control time series (B, L, d_data) + - beta: beta(t) control time series (B, L, d_data) + - delta: delta(t) linear-damping control time series (B, L, d_data) + - gamma: the learned scalar g (rescaled time-scaling factor) + """ + + alpha, beta, delta, _ = self._encode(x, dxdt, dt, smoothing=smoothing) + return alpha, beta, delta, self.gamma diff --git a/train/eval.py b/train/eval.py index 07b590a..80422e6 100644 --- a/train/eval.py +++ b/train/eval.py @@ -222,6 +222,117 @@ def dz_hat(t, z): return yhat +def integrate_model_autonomous( + model: torch.nn.Module, + audio: np.ndarray, + dt: float, + method: str = "rk4", + detrend: bool = True, + verbose: bool = True, +) -> np.ndarray: + """ + fully autonomous (closed-loop) integration of an `ArneodoOuroboros` model. + + Unlike `integrate_model_d2`, which replays the model's predicted second derivative + evaluated at the *data* points, this integrates the biomechanical syrinx ODE while + feeding the last generated state (x, x') back into the right-hand side. 
The state is + therefore generated entirely by the integrator -- only the control time series + alpha(t), beta(t) (produced once by the encoder from `audio`, as in the paper's + neurally-driven synthesis), the scalar gamma, and the initial condition come from + outside the integrator. + + We integrate in the model's rescaled time s = t / tau, where the state is z = [x, x'] + with x' = dx/ds, gamma is the learned (rescaled) scalar `model.gamma`, and + + dx/ds = x' + dx'/ds = g^2 a + g^2 b x + g^2 x^2 - g^2 x^3 - g (d + x + x^2) x' + + with g = gamma, a = alpha(s), b = beta(s), d = delta(s) (the linear-damping series). + + inputs + ----- + - model: a trained ArneodoOuroboros + - audio: 1-D audio segment used to produce alpha(t), beta(t) and the IC + - dt: audio sampling spacing (seconds) + - method: integration method (passed to torchdiffeq) + - detrend: whether to low-pass detrend the generated waveform (as in `correct`) + - verbose: print integration progress + + returns + ----- + - the autonomously generated waveform x, sampled at the same points as `audio` + """ + + L = len(audio) + t_steps = np.arange(0, L * dt + dt / 2, dt)[:L] + s_steps = t_steps / model.tau # rescaled time s = t / tau + + audio_3d = audio[None, :, None] + dy = deriv_approx_dy(audio_3d) # per-sample first derivative dx/dn + + audio_t = torch.from_numpy(audio_3d).to(torch.float32).to("cuda") + dy_t = torch.from_numpy(dy).to(torch.float32).to("cuda") + + with torch.no_grad(): + alpha, beta, delta, gamma = model.get_funcs(audio_t, dy_t, dt) + + alpha = alpha.detach().cpu().numpy().squeeze() + beta = beta.detach().cpu().numpy().squeeze() + delta = delta.detach().cpu().numpy().squeeze() + gamma = float(gamma.detach().cpu().numpy()) + g2 = gamma**2 + + # control parameters as smooth functions of rescaled time + alpha_interp = make_interp_spline(s_steps, alpha) + beta_interp = make_interp_spline(s_steps, beta) + delta_interp = make_interp_spline(s_steps, delta) + + # initial condition in rescaled 
time: x(0) and x'(0) = dx/ds = (tau/dt) * dx/dn + x0 = float(audio[0]) + xp0 = (model.tau / dt) * float(dy[0, 0, 0]) + ic = torch.tensor([x0, xp0], dtype=torch.float32, device="cuda") + + def dz_hat(s, z): + if verbose: + print( + f"{(s - s_steps[0]) / (s_steps[-1] - s_steps[0]) * 100:0.3f}%,", + end="\r", + ) + + s_np = s.detach().cpu().numpy() + a = float(alpha_interp(s_np)) + b = float(beta_interp(s_np)) + d = float(delta_interp(s_np)) + + x = z[0] + xp = z[1] + + dx = xp + dxp = ( + g2 * a + + g2 * b * x + + g2 * x**2 + - g2 * x**3 + - gamma * d * xp + - gamma * x * xp + - gamma * x**2 * xp + ) + + return torch.hstack([dx.reshape(1), dxp.reshape(1)]) + + eval_times = torch.from_numpy(s_steps).to(ic.device) + + with torch.no_grad(): + sol = odeint_adjoint( + dz_hat, ic, eval_times, adjoint_params=(), method=method, options=dict() + ).transpose(0, 1) + + x_gen = sol[0].detach().cpu().numpy().squeeze() + if detrend: + x_gen = correct(x_gen) + return x_gen + + def eval_model_error( dls: dict, model: torch.nn.Module, dt: float, comparison: str = "val" ) -> tuple[tuple[float, float], tuple[float, float], tuple[np.ndarray, np.ndarray]]: diff --git a/train/model_cv.py b/train/model_cv.py index 38f1f32..01aad2d 100644 --- a/train/model_cv.py +++ b/train/model_cv.py @@ -1,6 +1,6 @@ from train.train import train, save_model, load_model from model.kernels import fullPolyModule -from model.model import Ouroboros +from model.model import Ouroboros, ArneodoOuroboros from utils import sse from visualization.model_vis import loss_plot from train.eval import eval_model_error @@ -212,3 +212,113 @@ def model_cv_lambdas( data_df.to_csv(os.path.join(model_path, "cv_errs.csv")) return full_model_poly + + +def train_arneodo( + dls: dict, + dt: float, + n_epochs: int = 100, + lr: float = 1e-3, + expand_factor: int = 10, + n_layers: int = 4, + d_state: int = 1, + d_conv: int = 4, + tau: float = 1 / 1000, + smooth_len: float = 0.001, + model_path: str = "", + save_freq: int = 5, +) -> 
torch.nn.Module: + """ + trains a single `ArneodoOuroboros` model (the biomechanical syrinx parameterization). + + Unlike `model_cv_lambdas`, there is no regularization-strength cross-validation: the + Arneodo RHS has no polynomial kernel weights to penalize, so we fit one model with + `reg_weights=False`. The trained model is saved and returned in memory. + + inputs + ----- + - dls: dictionary of dataloaders (train / val / test) + - dt: spacing between audio samples, in seconds + - n_epochs: number of passes through the training data + - lr: learning rate + - expand_factor: expansion from audio to mamba input + - n_layers: number of mamba layers in each encoder + - d_state: internal state size of mamba model + - d_conv: length of internal convolution of mamba model + - tau: timescale for model decoder, to de-dimensionalize the data + - smooth_len: smoothing length for model functions (not used in training) + - model_path: place to save the model and training artifacts + - save_freq: how often (in epochs) to checkpoint the model + + returns + ----- + - the trained ArneodoOuroboros + """ + + model_info = { + "n layers": n_layers, + "d state": d_state, + "d conv": d_conv, + "expand factor": expand_factor, + } + + model = ArneodoOuroboros( + d_data=1, + n_layers=n_layers, + d_state=d_state, + d_conv=d_conv, + expand_factor=expand_factor, + tau=tau, + smooth_len=smooth_len, + ) + + opt = Adam(model.parameters(), lr=lr) + scheduler = ReduceLROnPlateau( + opt, factor=0.5, patience=max(n_epochs // 25, 2), min_lr=1e-10 + ) + + run_dir = os.path.join(model_path, "arneodo") + os.makedirs(run_dir, exist_ok=True) + save_loc = os.path.join(run_dir, f"checkpoint_{n_epochs}.tar") + + tl, vl, model, opt = train( + model, + opt, + loss_fn=lambda y, yhat: sse(yhat, y, reduction="mean"), + loaders=dls, + scheduler=scheduler, + nEpochs=n_epochs, + val_freq=1, + runDir=run_dir, + dt=dt, + vis_freq=max(n_epochs // 10, 1), + smoothing=False, + reg_weights=False, # no kernel weights to 
regularize in the Arneodo RHS + start_epoch=0, + save_freq=save_freq, + model_info=model_info, + ) + + loss_plot(tl, vl, save_loc=run_dir, show=False) + + save_model( + model, + opt, + save_loc, + n_layers=n_layers, + d_state=d_state, + expand_factor=expand_factor, + d_conv=d_conv, + ) + + model.eval() + with torch.no_grad(): + (train_mu, test_mu), (train_sd, test_sd), _ = eval_model_error( + dls, model, dt=dt, comparison="test" + ) + print( + f"Arneodo model R2 -- train: {train_mu:.4f} +- {train_sd:.4f}, " + f"test: {test_mu:.4f} +- {test_sd:.4f}" + ) + + return model diff --git a/train/train.py b/train/train.py index ef2f768..09216ca 100644 --- a/train/train.py +++ b/train/train.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt import os import glob -from model.model import Ouroboros +from model.model import Ouroboros, ArneodoOuroboros from model.kernels import fullPolyModule from torch.optim import Adam from torch.optim.lr_scheduler import ReduceLROnPlateau @@ -55,6 +55,8 @@ def save_model( ordered_saves = [current_saves[o] for o in save_order] for ii in range(len(current_saves) - max_saved + 1): os.remove(ordered_saves[ii]) + # "poly" models carry a kernel module; the Arneodo parameterization does not. + parameterization = "poly" if hasattr(model, "kernel") else "arneodo" sd = { "ouroboros": model.state_dict(), "opt": opt.state_dict(), @@ -64,10 +66,11 @@ def save_model( "d_state": d_state, "d_conv": d_conv, "expand_factor": expand_factor, + "parameterization": parameterization, } try: sd["n_kernel"] = model.kernel.nTerms - except KeyError: + except (KeyError, AttributeError): pass torch.save(sd, location) @@ -111,19 +114,9 @@ def load_model( d_state = 1 d_conv = 4 expand_factor = 4 - try: - # since this is a trained model and we only use lambda during training, i set it to 1 here... - # but probably should have saved it. oh well! we set to 1 for compatibility with all my saves. 
- kernel = fullPolyModule( - nTerms=sd["n_kernel"], - device="cuda", - x_dim=1, - z_dim=2, - activation=lambda x: x, - lam=1, - ) - - model = Ouroboros( + parameterization = sd.get("parameterization", "poly") + if parameterization == "arneodo": + model = ArneodoOuroboros( d_data=1, n_layers=n_layers, d_state=d_state, @@ -131,11 +124,33 @@ def load_model( expand_factor=expand_factor, tau=sd["tau"], smooth_len=sd["smooth_len"], - kernel=kernel, ) - except: - print("no kernel in savefile!") - raise + else: + try: + # since this is a trained model and we only use lambda during training, i set it to 1 here... + # but probably should have saved it. oh well! we set to 1 for compatibility with all my saves. + kernel = fullPolyModule( + nTerms=sd["n_kernel"], + device="cuda", + x_dim=1, + z_dim=2, + activation=lambda x: x, + lam=1, + ) + + model = Ouroboros( + d_data=1, + n_layers=n_layers, + d_state=d_state, + d_conv=d_conv, + expand_factor=expand_factor, + tau=sd["tau"], + smooth_len=sd["smooth_len"], + kernel=kernel, + ) + except: + print("no kernel in savefile!") + raise print(f"model tau: {model.tau}") opt = Adam(model.parameters(), lr=1e-3) @@ -289,7 +304,10 @@ def px(x): train_loss = loss_fn(y, yhat[:, :L, :]) - #l = loss + # default objective is the data loss; the polynomial parameterization adds a + # weight-complexity penalty when reg_weights=True. Parameterizations without + # kernel weights (e.g. ArneodoOuroboros) train with reg_weights=False. 
+ total_loss = train_loss if reg_weights: B, L, P, P = weights.shape lam_mat = torch.arange( diff --git a/train/train_model.py b/train/train_model.py index 10691c2..3113222 100644 --- a/train/train_model.py +++ b/train/train_model.py @@ -1,7 +1,7 @@ from data.load_data import get_segmented_audio from data.data_utils import get_loaders -from train.model_cv import model_cv_lambdas +from train.model_cv import model_cv_lambdas, train_arneodo from typing import Union import os @@ -22,6 +22,7 @@ def train_model( batch_size: int = 32, n_epochs: int = 100, save_freq: int = 5, + parameterization: str = "poly", ) -> torch.nn.Module: """ function for training a model. takes audio from @@ -41,6 +42,9 @@ def train_model( batch_size: batch size during training n_epochs: max number of passes through the data during training save_freq: how often (in epochs) we want to checkpoint model + parameterization: "poly" for the full-polynomial Ouroboros (with lambda + cross-validation), or "arneodo" for the Arneodo 2021 syrinx ODE + parameterization (single fit, no regularization CV) returns -------- best model after hyperparameter cross-validation @@ -86,20 +90,35 @@ def train_model( dt=dt, ) - best_model = model_cv_lambdas( - dls=dataloaders, - dt=dt, - n_epochs=n_epochs, - lr=1e-3, - n_kernels=15, - expand_factor=10, - n_layers=3, - d_state=1, - d_conv=4, - tau=dt, - model_path=model_dir, - save_freq=save_freq, - ) + if parameterization == "arneodo": + best_model = train_arneodo( + dls=dataloaders, + dt=dt, + n_epochs=n_epochs, + lr=1e-3, + expand_factor=10, + n_layers=3, + d_state=1, + d_conv=4, + tau=dt, + model_path=model_dir, + save_freq=save_freq, + ) + else: + best_model = model_cv_lambdas( + dls=dataloaders, + dt=dt, + n_epochs=n_epochs, + lr=1e-3, + n_kernels=15, + expand_factor=10, + n_layers=3, + d_state=1, + d_conv=4, + tau=dt, + model_path=model_dir, + save_freq=save_freq, + ) return best_model From dc70ff2070bfa5501affee0c6d26e29bbc4135bd Mon Sep 17 00:00:00 2001 From: John 
Pearson Date: Wed, 27 May 2026 13:27:33 -0400 Subject: [PATCH 02/15] Expose model capacity in train API; add parallel gen + monitored training scripts Thread lr/n_layers/d_state/d_conv/expand_factor through train_model into both train_arneodo and model_cv_lambdas (defaults unchanged, so existing behavior is preserved). Bumping d_state from 1 to 4 on a larger dataset reaches test R^2 > 0.98 with the Arneodo parameterization. - train/train_model.py: capacity kwargs forwarded to both parameterizations. - examples/gen_gabo.py: single-shard generator for parallel data generation. - examples/train_arneodo_big.py: configurable, monitored training (segmented train + per-K-epoch R^2 eval + early-stop at a target R^2). - .gitignore: ignore the example run-output dirs. Co-Authored-By: Claude Opus 4.7 (1M context) --- .gitignore | 3 + examples/gen_gabo.py | 17 +++++ examples/train_arneodo_big.py | 121 ++++++++++++++++++++++++++++++++++ train/train_model.py | 33 +++++++--- 4 files changed, 164 insertions(+), 10 deletions(-) create mode 100644 examples/gen_gabo.py create mode 100644 examples/train_arneodo_big.py diff --git a/.gitignore b/.gitignore index 43eb49e..7f6fee9 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,6 @@ cython_debug/ # example run outputs (generated data, checkpoints, plots) arneodo_run/ +arneodo_big/ +arneodo_feas/ +data500/ diff --git a/examples/gen_gabo.py b/examples/gen_gabo.py new file mode 100644 index 0000000..e4a10fc --- /dev/null +++ b/examples/gen_gabo.py @@ -0,0 +1,17 @@ +"""Generate a shard of synthetic Mindlin/gabo vocalizations (for parallel data gen).""" +import argparse +import jax.random as jr +from data.generate_data_gabo import generate_vocal_dataset + +p = argparse.ArgumentParser() +p.add_argument("--out-dir", required=True) +p.add_argument("--n-vocs", type=int, required=True) +p.add_argument("--seed", type=int, required=True) +p.add_argument("--noise-sd", type=float, default=0.0002) +a = p.parse_args() + +generate_vocal_dataset( + 
jr.PRNGKey(a.seed), n_vocs=a.n_vocs, noise_sd=a.noise_sd, + audio_loc=a.out_dir, seg_loc=a.out_dir, func_loc=a.out_dir, +) +print(f"shard {a.out_dir}: wrote {a.n_vocs} vocs (seed {a.seed})") diff --git a/examples/train_arneodo_big.py b/examples/train_arneodo_big.py new file mode 100644 index 0000000..f6e68c6 --- /dev/null +++ b/examples/train_arneodo_big.py @@ -0,0 +1,121 @@ +""" +Configurable, monitored training of an ArneodoOuroboros toward a target R^2. + +Gathers chunks from one or more gabo data directories, builds dataloaders, and trains in +segments -- evaluating train/test R^2 after each segment, checkpointing, and stopping early +once the target test R^2 is reached. Capacity (n_layers, d_state, d_conv, expand_factor) is +exposed so we can give the model enough capacity to fit the finite-difference target. + +Run from the repo root, e.g.: + python -m examples.train_arneodo_big --data-glob './data500/gabo_*' \ + --out-dir ./arneodo_big --d-state 16 --n-layers 4 --epochs 200 --seg 10 \ + --batch-size 32 --target-r2 0.97 +""" + +import argparse +import glob +import os + +import numpy as np +import torch +from torch.optim import Adam +from torch.optim.lr_scheduler import ReduceLROnPlateau + +from data.load_data import get_segmented_audio +from data.data_utils import get_loaders +from model.model import ArneodoOuroboros +from train.train import train, save_model +from train.eval import eval_model_error +from utils import sse + + +def gather_loaders(data_dirs, max_vocs, context_len, batch_size, seed, n_jobs): + chunks = [] + sr = None + per_dir = max(1, max_vocs // len(data_dirs)) + for d in data_dirs: + audio, sr = get_segmented_audio( + d, d, max_vocs=per_dir, context_len=context_len, seed=seed, + training=True, extend=True, shuffle_order=True, + ) + chunks += audio + print(f"gathered {len(chunks)} chunks from {len(data_dirs)} dir(s); sr={sr}") + dt = 1 / sr + dls = get_loaders( + np.stack(chunks, axis=0), num_workers=n_jobs, batch_size=batch_size, + 
train_size=0.6, cv=True, seed=seed, dt=dt, + ) + return dls, dt + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--data-glob", required=True, help="glob matching gabo data dir(s)") + p.add_argument("--out-dir", default="./arneodo_big") + p.add_argument("--max-vocs", type=int, default=100000) + p.add_argument("--context-len", type=float, default=0.25) + p.add_argument("--batch-size", type=int, default=32) + p.add_argument("--epochs", type=int, default=200) + p.add_argument("--seg", type=int, default=10, help="eval/checkpoint every this many epochs") + p.add_argument("--target-r2", type=float, default=0.97) + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--n-layers", type=int, default=4) + p.add_argument("--d-state", type=int, default=16) + p.add_argument("--d-conv", type=int, default=4) + p.add_argument("--expand-factor", type=int, default=10) + p.add_argument("--seed", type=int, default=1234) + p.add_argument("--n-jobs", type=int, default=4) + args = p.parse_args() + + out_dir = os.path.abspath(args.out_dir) + run_dir = os.path.join(out_dir, "arneodo") + os.makedirs(run_dir, exist_ok=True) + + data_dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + assert data_dirs, f"no dirs matched {args.data_glob}" + + dls, dt = gather_loaders( + data_dirs, args.max_vocs, args.context_len, args.batch_size, args.seed, args.n_jobs + ) + + model = ArneodoOuroboros( + d_data=1, n_layers=args.n_layers, d_state=args.d_state, d_conv=args.d_conv, + expand_factor=args.expand_factor, tau=dt, + ) + n_params = sum(q.numel() for q in model.parameters()) + print(f"model: n_layers={args.n_layers} d_state={args.d_state} d_conv={args.d_conv} " + f"expand={args.expand_factor} -> {n_params} params; tau={dt:.2e}") + + opt = Adam(model.parameters(), lr=args.lr) + scheduler = ReduceLROnPlateau(opt, factor=0.5, patience=max(args.seg, 3), min_lr=1e-10) + model_info = {"n layers": args.n_layers, "d state": args.d_state, + 
"d conv": args.d_conv, "expand factor": args.expand_factor} + + best = -np.inf + for start in range(0, args.epochs, args.seg): + end = min(args.epochs, start + args.seg) + train( + model, opt, + loss_fn=lambda y, yhat: sse(yhat, y, reduction="mean"), + loaders=dls, scheduler=scheduler, nEpochs=end, val_freq=1, runDir=run_dir, + dt=dt, vis_freq=0, smoothing=False, reg_weights=False, start_epoch=start, + save_freq=max(args.seg, 1), model_info=model_info, + ) + model.eval() + with torch.no_grad(): + (tr, te), (trsd, tesd), _ = eval_model_error(dls, model, dt=dt, comparison="test") + print(f"[epoch {end}] train R2={tr:.4f}+-{trsd:.4f} test R2={te:.4f}+-{tesd:.4f}", + flush=True) + save_model(model, opt, os.path.join(run_dir, f"checkpoint_{end}.tar"), + n_layers=args.n_layers, d_state=args.d_state, + d_conv=args.d_conv, expand_factor=args.expand_factor) + best = max(best, te) + if te >= args.target_r2: + print(f"reached target test R2 {args.target_r2} at epoch {end}", flush=True) + break + + print(f"DONE. best test R2 = {best:.4f}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/train/train_model.py b/train/train_model.py index 3113222..9d06bb8 100644 --- a/train/train_model.py +++ b/train/train_model.py @@ -23,6 +23,11 @@ def train_model( n_epochs: int = 100, save_freq: int = 5, parameterization: str = "poly", + lr: float = 1e-3, + n_layers: int = 3, + d_state: int = 1, + d_conv: int = 4, + expand_factor: int = 10, ) -> torch.nn.Module: """ function for training a model. takes audio from @@ -45,6 +50,14 @@ def train_model( parameterization: "poly" for the full-polynomial Ouroboros (with lambda cross-validation), or "arneodo" for the Arneodo 2021 syrinx ODE parameterization (single fit, no regularization CV) + lr: learning rate + n_layers: number of mamba layers in each encoder + d_state: internal SSM state size of the mamba encoders. The default (1) is + small; bumping it (e.g. 
4) substantially improves fit -- with enough data + the arneodo model reaches R^2 > 0.98 at d_state=4. Mind GPU memory: the + parallel scan allocates ~batch * npo2(2*seq) * 2*expand_factor * d_state. + d_conv: width of the mamba convolutional kernel + expand_factor: channel expansion from audio to mamba input returns -------- best model after hyperparameter cross-validation @@ -95,11 +108,11 @@ def train_model( dls=dataloaders, dt=dt, n_epochs=n_epochs, - lr=1e-3, - expand_factor=10, - n_layers=3, - d_state=1, - d_conv=4, + lr=lr, + expand_factor=expand_factor, + n_layers=n_layers, + d_state=d_state, + d_conv=d_conv, tau=dt, model_path=model_dir, save_freq=save_freq, @@ -109,12 +122,12 @@ def train_model( dls=dataloaders, dt=dt, n_epochs=n_epochs, - lr=1e-3, + lr=lr, n_kernels=15, - expand_factor=10, - n_layers=3, - d_state=1, - d_conv=4, + expand_factor=expand_factor, + n_layers=n_layers, + d_state=d_state, + d_conv=d_conv, tau=dt, model_path=model_dir, save_freq=save_freq, From 15082e4466ddcfe307ad83c00cf0d5a2dfc49811 Mon Sep 17 00:00:00 2001 From: John Pearson Date: Wed, 27 May 2026 14:00:49 -0400 Subject: [PATCH 03/15] plot_autonomous: configurable model/data dirs and IC start offset - --model-dir / --data-dir to point at any checkpoint and gabo dataset (not just the conventional {model/arneodo,gabo_data} layout under --out-dir). - Replace --pre-onset-ms with --start-offset-ms (relative to onset): positive seeds the rollout on the established limit cycle (stable IC), negative is a cold start from near-silence. Mid-vocalization ICs avoid the unstable near-origin region.
Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/plot_autonomous.py | 40 ++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/examples/plot_autonomous.py b/examples/plot_autonomous.py index 5c46e86..7f85767 100644 --- a/examples/plot_autonomous.py +++ b/examples/plot_autonomous.py @@ -33,44 +33,52 @@ def main(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--out-dir", default="./arneodo_run") - parser.add_argument("--n-integrate", type=int, default=11000, help="samples to integrate") - parser.add_argument("--pre-onset-ms", type=float, default=10.0, - help="start the rollout this many ms before the vocalization onset") + parser.add_argument("--n-integrate", type=int, default=8000, help="samples to integrate") + parser.add_argument("--start-offset-ms", type=float, default=50.0, + help="begin the rollout this many ms relative to onset; positive = into the " + "vocalization (seed on the limit cycle), negative = before onset (cold start)") parser.add_argument("--seed", type=int, default=1234) parser.add_argument("--fmax", type=float, default=10000.0, help="max spectrogram freq (Hz)") parser.add_argument("--method", default="rk4") + parser.add_argument("--model-dir", default=None, + help="checkpoint dir (default /model/arneodo)") + parser.add_argument("--data-dir", default=None, + help="gabo data dir (default /gabo_data)") args = parser.parse_args() out_dir = os.path.abspath(args.out_dir) - data_dir = os.path.join(out_dir, "gabo_data") - model_dir = os.path.join(out_dir, "model", "arneodo") + data_dir = os.path.abspath(args.data_dir) if args.data_dir else os.path.join(out_dir, "gabo_data") + model_dir = os.path.abspath(args.model_dir) if args.model_dir else os.path.join(out_dir, "model", "arneodo") # load the trained model ----------------------------------------------------------- model, _, _, epoch = load_model(model_dir) model.eval() print(f"loaded {type(model).__name__} (epoch 
{epoch})") - # held-out vocalization, windowed to start `pre_onset_ms` BEFORE the first onset, so the - # rollout is seeded from (near) rest just before vocalization and must spin the - # oscillation up itself. analysis mode returns aud[onset - padding : offset], so passing - # padding = pre_onset gives exactly that window (scale-consistent with training loading). - pre_onset_s = args.pre_onset_ms / 1e3 + # held-out vocalization. analysis mode returns aud[onset - padding : offset]; we then begin + # the rollout at `start_offset_ms` relative to onset. A positive offset seeds the integrator + # on the established limit cycle (the stable, sensible IC); a negative offset is a cold start + # from near-silence (which can be near an unstable fixed point of the learned field). + analysis_pad_s = max(0.03, -args.start_offset_ms / 1e3 + 0.005) chunks, sr = get_segmented_audio( data_dir, data_dir, max_vocs=40, seed=args.seed + 1, training=False, - padding=pre_onset_s, + padding=analysis_pad_s, shuffle_order=True, ) dt = 1 / sr - segment = np.asarray(chunks[0]).squeeze() - n = min(args.n_integrate, len(segment)) - segment = segment[:n] + full = np.asarray(chunks[0]).squeeze() + onset_idx = int(round(analysis_pad_s * sr)) # onset position within the analysis window + start_idx = max(0, onset_idx + int(round(args.start_offset_ms / 1e3 * sr))) + segment = full[start_idx : start_idx + args.n_integrate] + n = len(segment) + rel = "after" if args.start_offset_ms >= 0 else "before" print( - f"rollout starts {args.pre_onset_ms:.0f} ms before onset; " + f"rollout starts {abs(args.start_offset_ms):.0f} ms {rel} onset; " f"autonomously integrating {n} samples ({n * dt * 1e3:.1f} ms) ..." ) x_gen = integrate_model_autonomous( @@ -121,7 +129,7 @@ def spec(x): fig.colorbar(pcm, ax=ax, label="power (dB)") fig.suptitle( - f"Arneodo autonomous integration | start {args.pre_onset_ms:.0f} ms pre-onset | " + f"Arneodo autonomous integration | start {args.start_offset_ms:+.0f} ms rel. 
onset | " f"{n * dt * 1e3:.0f} ms | sr {sr} Hz", y=1.00, ) From 0ed5c3be848386078cef765eaffa6067aa4d4401 Mon Sep 17 00:00:00 2001 From: John Pearson Date: Wed, 27 May 2026 19:12:39 -0400 Subject: [PATCH 04/15] Add drive low-pass in the loop + polynomial-model autonomous reconstruction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ports the "low-pass the drives" idea to the original polynomial Ouroboros and adds a closed-loop autonomous integrator for it, to test how well the original model reconstructs vocalizations autonomously. - model/model.py: Ouroboros gains drive_lowpass_ms — a zero-phase Gaussian low-pass on omega(t), gamma(t) and the kernel weights w(t) (recompute the nonlinearity from the low-passed weights). Shares the low-pass infrastructure used by ArneodoOuroboros. - train/train.py: save_model/load_model record & restore drive_lowpass_ms for both parameterizations. - train/eval.py: integrate_poly_autonomous — closed-loop integration of the polynomial RHS (-omega^2 x - gamma x' - kernel(x,x';w)) feeding the generated state back, with an optional Euler-Maruyama process-noise term (noise_sd) to sustain a noise-driven oscillation. - examples/train_poly_lowpass.py: train the polynomial Ouroboros with 2ms low-pass on the 500-voc set (monitored; reg_weights=True). - examples/plot_poly_compare.py: target vs teacher-forced vs autonomous (waveform + spectrogram). - examples/finetune_rollout.py, examples/plot_lowpass_recon.py: earlier rollout-fine-tune and teacher-forced low-pass-reconstruction experiments. Result on the 500-voc data (2ms low-pass): teacher-forced reconstruction is near-perfect (pitch/amplitude/spectrogram match); the autonomous deterministic dynamics are a damped resonator (decays) whose resonance tracks the data pitch, which a small process noise sustains at the target amplitude -- far better than the non-low-pass autonomous (which diverged / had ~40x amplitude). 
NOTE: trained at a fixed kernel regularization lambda=1.2 (not yet selected on validation); a lambda sweep is the next step. Co-Authored-By: Claude Opus 4.7 (1M context) --- .gitignore | 4 + examples/finetune_rollout.py | 145 ++++++++++++++++++++++++++++++ examples/plot_lowpass_recon.py | 156 +++++++++++++++++++++++++++++++++ examples/plot_poly_compare.py | 113 ++++++++++++++++++++++++ examples/train_poly_lowpass.py | 110 +++++++++++++++++++++++ model/model.py | 75 ++++++++++++++-- train/eval.py | 104 ++++++++++++++++++++++ train/train.py | 3 + 8 files changed, 704 insertions(+), 6 deletions(-) create mode 100644 examples/finetune_rollout.py create mode 100644 examples/plot_lowpass_recon.py create mode 100644 examples/plot_poly_compare.py create mode 100644 examples/train_poly_lowpass.py diff --git a/.gitignore b/.gitignore index 7f6fee9..6691fda 100644 --- a/.gitignore +++ b/.gitignore @@ -166,4 +166,8 @@ cython_debug/ arneodo_run/ arneodo_big/ arneodo_feas/ +arneodo_ft/ +arneodo_lp/ +arneodo_sweep/ +poly_lp/ data500/ diff --git a/examples/finetune_rollout.py b/examples/finetune_rollout.py new file mode 100644 index 0000000..edfa4e5 --- /dev/null +++ b/examples/finetune_rollout.py @@ -0,0 +1,145 @@ +""" +Short rollout-training fine-tune of an ArneodoOuroboros for cold-start stability. + +Teacher-forced training only optimizes one-step ẍ prediction; the resulting free-running +(autonomous) dynamics can be unstable from a cold start (near-silence IC, near the model's +unstable fixed point at the origin). Here we fine-tune by backpropagating through a short +*autonomous rollout*: the model produces the control series alpha/beta/delta from the data, +we integrate the ODE from a cold-start IC feeding the generated state back in (a differentiable +RK4, state-clamped to avoid NaN), and add an MSE between the rolled-out waveform and the data. +A teacher-forced anchor (one-step ẍ MSE) keeps the model oscillating. 
A curriculum grows the +rollout horizon so the (initially-diverging) model can stabilize progressively. + +Run from the repo root, e.g.: + python -m examples.finetune_rollout --init arneodo_big/arneodo --data-glob './data500/gabo_p*' \ + --out-dir ./arneodo_ft --epochs 12 --batch-size 8 +""" + +import argparse +import glob +import os + +import numpy as np +import torch +from torch.optim import Adam + +from data.load_data import get_segmented_audio +from utils import deriv_approx_dy, deriv_approx_d2y +from train.train import load_model, save_model + +BOUND_X, BOUND_XP = 3.0, 50.0 # clamp the rollout state to keep losses finite (no NaN) + + +def rhs(x, xp, a, b, d, g): + g2 = g * g + dxp = g2 * a + g2 * b * x + g2 * x * x - g2 * x * x * x - g * d * xp - g * x * xp - g * x * x * xp + return xp, dxp + + +def rollout(a, b, d, g, x0, xp0, H): + """differentiable RK4 autonomous rollout (rescaled time, step ds=1). a,b,d: (B, L).""" + x, xp = x0, xp0 + xs = [x] + for k in range(H - 1): + a0, b0, d0 = a[:, k], b[:, k], d[:, k] + a1, b1, d1 = a[:, k + 1], b[:, k + 1], d[:, k + 1] + ah, bh, dh = 0.5 * (a0 + a1), 0.5 * (b0 + b1), 0.5 * (d0 + d1) + k1x, k1v = rhs(x, xp, a0, b0, d0, g) + k2x, k2v = rhs(x + 0.5 * k1x, xp + 0.5 * k1v, ah, bh, dh, g) + k3x, k3v = rhs(x + 0.5 * k2x, xp + 0.5 * k2v, ah, bh, dh, g) + k4x, k4v = rhs(x + k3x, xp + k3v, a1, b1, d1, g) + x = x + (k1x + 2 * k2x + 2 * k3x + k4x) / 6 + xp = xp + (k1v + 2 * k2v + 2 * k3v + k4v) / 6 + x = torch.clamp(x, -BOUND_X, BOUND_X) + xp = torch.clamp(xp, -BOUND_XP, BOUND_XP) + xs.append(x) + return torch.stack(xs, dim=1) # (B, H) + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--init", required=True, help="checkpoint dir to fine-tune from") + p.add_argument("--data-glob", required=True) + p.add_argument("--out-dir", default="./arneodo_ft") + p.add_argument("--n-vocs", type=int, default=120, help="cold-start segments to fine-tune on") + p.add_argument("--pre-onset-ms", type=float, 
default=20.0) + p.add_argument("--hmax", type=int, default=2000, help="max rollout horizon (samples)") + p.add_argument("--epochs", type=int, default=12) + p.add_argument("--batch-size", type=int, default=8) + p.add_argument("--lr", type=float, default=3e-4) + p.add_argument("--lam-roll", type=float, default=1.0, help="weight on (normalized) rollout loss") + p.add_argument("--seed", type=int, default=0) + args = p.parse_args() + + os.makedirs(os.path.join(os.path.abspath(args.out_dir), "arneodo"), exist_ok=True) + run_dir = os.path.join(os.path.abspath(args.out_dir), "arneodo") + + model, _, _, ep0 = load_model(args.init) + model.train() + dt = model.tau + print(f"fine-tuning from {args.init} (epoch {ep0}), gamma0={float(model.gamma):.4f}") + + # cold-start segments: analysis mode -> aud[onset - pad : offset], trimmed to hmax + dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + pad = args.pre_onset_ms / 1e3 + segs = [] + per = max(1, args.n_vocs // len(dirs)) + sr = None + for d in dirs: + ch, sr = get_segmented_audio(d, d, max_vocs=per, training=False, padding=pad, + seed=args.seed, shuffle_order=True) + for c in ch: + c = np.asarray(c).squeeze() + if len(c) >= args.hmax: + segs.append(c[: args.hmax]) + X = np.stack(segs, axis=0)[:, :, None].astype(np.float64) # (N, hmax, 1) + onset_i = int(round(pad * sr)) + print(f"{X.shape[0]} cold-start segments, hmax={args.hmax}, onset at sample {onset_i}") + + dxdt = deriv_approx_dy(X) + dx2 = deriv_approx_d2y(X) # teacher-forced target (per-sample = d2x/ds2 since tau=dt) + Xt = torch.tensor(X, dtype=torch.float32, device="cuda") + Dt = torch.tensor(dxdt, dtype=torch.float32, device="cuda") + D2 = torch.tensor(dx2, dtype=torch.float32, device="cuda") + var_x = float(Xt.var()); var_d2 = float(D2.var()) + + opt = Adam(model.parameters(), lr=args.lr) + N = Xt.shape[0] + # curriculum: grow horizon geometrically, starting BELOW the cold-start divergence onset + # (~40 samples) so early epochs get finite, 
meaningful rollout gradients. + H_sched = np.unique(np.round(np.geomspace(32, args.hmax, args.epochs)).astype(int)) + + for epoch in range(args.epochs): + H = int(H_sched[epoch]) + perm = torch.randperm(N) + tot_tf = tot_roll = 0.0 + nb = 0 + for i in range(0, N, args.batch_size): + idx = perm[i : i + args.batch_size] + x, dxd, d2 = Xt[idx], Dt[idx], D2[idx] + alpha, beta, delta, z = model._encode(x, dxd, dt) + yhat = model._rhs(alpha, beta, delta, z) # teacher-forced d2x/ds2 + L_tf = ((yhat - d2) ** 2).mean() / var_d2 + + a = alpha[:, :, 0]; b = beta[:, :, 0]; d = delta[:, :, 0]; g = model.gamma + x0 = x[:, 0, 0].detach() + xp0 = ((model.tau / dt) * dxd[:, 0, 0]).detach() + xg = rollout(a, b, d, g, x0, xp0, H) # (B, H), cold start + L_roll = ((xg - x[:, :H, 0]) ** 2).mean() / var_x + + loss = L_tf + args.lam_roll * L_roll + opt.zero_grad(); loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) + opt.step() + tot_tf += float(L_tf); tot_roll += float(L_roll); nb += 1 + print(f"[epoch {epoch+1}/{args.epochs} H={H}] relMSE tf={tot_tf/nb:.4f} roll={tot_roll/nb:.4f}", + flush=True) + + save_model(model, opt, os.path.join(run_dir, f"checkpoint_{ep0+args.epochs}.tar"), + n_layers=model.alpha_mamba.config.n_layers, d_state=model.alpha_mamba.config.d_state, + d_conv=model.alpha_mamba.config.d_conv, expand_factor=model.alpha_mamba.config.expand_factor) + print(f"saved fine-tuned model; gamma={float(model.gamma):.4f}") + + +if __name__ == "__main__": + main() diff --git a/examples/plot_lowpass_recon.py b/examples/plot_lowpass_recon.py new file mode 100644 index 0000000..61c8393 --- /dev/null +++ b/examples/plot_lowpass_recon.py @@ -0,0 +1,156 @@ +""" +Teacher-forced reconstruction from raw vs. low-pass-filtered drives. + +For a trained ArneodoOuroboros on one vocalization: + 1. encode the data -> drives alpha(t), beta(t), delta(t) (teacher-forced encoder pass) + 2. optionally low-pass the drives to their slow (control-rate) component + 3. 
recompute the second derivative ẍ from those drives at the TRUE data states x, x' + (teacher forcing -- the oscillation is carried by the true states, not the drives) + 4. double-integrate ẍ to reconstruct the waveform, and overlay vs the original + +This tests whether the carrier-frequency content of the per-sample drives is actually needed +for reconstruction, or whether slow drives + the true states suffice. + +Run from the repo root, e.g.: + python -m examples.plot_lowpass_recon --model-dir arneodo_big/arneodo \ + --data-dir data500/gabo_p0 --cutoff-hz 50 --out-dir ./arneodo_big +""" + +import argparse +import os + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +import numpy as np +import torch +from scipy.io import wavfile +from scipy.signal import butter, sosfiltfilt + +from utils import deriv_approx_dy, deriv_approx_d2y +from train.train import load_model +from train.eval import integrate_second_deriv, correct + +plt.rcParams["text.usetex"] = False + + +def lowpass(x, cutoff_hz, fs, order=4): + sos = butter(order, cutoff_hz, btype="low", fs=fs, output="sos") + return sosfiltfilt(sos, x).copy() + + +def r2(recon, orig): + sse = np.sum((recon - orig) ** 2) + sst = np.sum((orig - orig.mean()) ** 2) + return 1 - sse / sst + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--model-dir", default="arneodo_big/arneodo") + p.add_argument("--data-dir", default="data500/gabo_p0") + p.add_argument("--voc", type=int, default=0) + p.add_argument("--vocalization", type=int, default=0) + p.add_argument("--start-offset-ms", type=float, default=20.0, + help="reconstruct starting this many ms after onset (stay in sustained voc)") + p.add_argument("--n", type=int, default=4000, help="samples to reconstruct (avoid the offset decay)") + p.add_argument("--cutoff-hz", type=float, default=50.0) + p.add_argument("--out-dir", default="arneodo_big") + args = p.parse_args() + + tag = f"gabo_artificial_{args.voc}" + 
model, _, _, epoch = load_model(args.model_dir) + model.eval() + dt = model.tau + print(f"loaded {type(model).__name__} (epoch {epoch})") + + # one vocalization window + sr, audio_full = wavfile.read(os.path.join(args.data_dir, f"{tag}.wav")) + audio_full = audio_full.astype(np.float64) + onoffs = np.atleast_2d(np.loadtxt(os.path.join(args.data_dir, f"{tag}.txt"))) + on_s, off_s = onoffs[args.vocalization] + # sustained sub-window starting after onset (avoids the vocalization's decaying offset, + # where open-loop double-integration destabilizes) + start_idx = int(round(on_s * sr)) + int(round(args.start_offset_ms / 1e3 * sr)) + b = min(len(audio_full), start_idx + args.n) + x = audio_full[start_idx:b] + L = len(x) + t_ms = np.arange(L) * dt * 1e3 + + # encode -> drives + state z = [x, x'] (x' = dx/ds) + xt = torch.from_numpy(x[None, :, None]).to(torch.float32).cuda() + dy = deriv_approx_dy(x[None, :, None]) + dyt = torch.from_numpy(dy).to(torch.float32).cuda() + with torch.no_grad(): + alpha, beta, delta, z = model._encode(xt, dyt, dt) + g = float(model.gamma) + al = alpha.cpu().numpy().squeeze(); be = beta.cpu().numpy().squeeze(); de = delta.cpu().numpy().squeeze() + + def xddot(a_, b_, d_): + # ẍ (= d2x/ds2) from given drives at the TRUE states (teacher forcing) + ad, bd, dd = (torch.from_numpy(v[None, :, None]).to(torch.float32).cuda() for v in (a_, b_, d_)) + with torch.no_grad(): + return model._rhs(ad, bd, dd, z).cpu().numpy().squeeze() + + d2_raw = xddot(al, be, de) + al_lp, be_lp, de_lp = (lowpass(v, args.cutoff_hz, sr) for v in (al, be, de)) + d2_lp = xddot(al_lp, be_lp, de_lp) + + # drift-free metric: how well does ẍ (from raw vs low-pass drives) match the data's ẍ? 
+ data_d2 = deriv_approx_d2y(x[None, :, None]).squeeze() + print(f"teacher-forced ẍ R² vs data ẍ: raw drives={r2(d2_raw, data_d2):+.3f} " + f"low-pass(<{args.cutoff_hz:.0f}Hz)={r2(d2_lp, data_d2):+.3f}") + + # double-integrate ẍ (teacher-forced open loop) in rescaled time s = t/tau + s_steps = (np.arange(L) * dt) / model.tau + x0 = float(x[0]); xp0 = (model.tau / dt) * float(dy[0, 0, 0]) + ic = torch.tensor([x0, xp0], dtype=torch.float32, device="cuda") + recon_raw = integrate_second_deriv(d2_raw, ic, s_steps, method="rk4", verbose=False) + recon_lp = integrate_second_deriv(d2_lp, ic, s_steps, method="rk4", verbose=False) + orig = correct(x) # detrend the original the same way the reconstructions are + + n = min(len(orig), len(recon_raw), len(recon_lp)) + orig, recon_raw, recon_lp = orig[:n], recon_raw[:n], recon_lp[:n] + r2_raw, r2_lp = r2(recon_raw, orig), r2(recon_lp, orig) + # short-window R2 (first 10 ms) before open-loop phase drift dominates + sw = min(n, int(round(0.01 / dt))) + r2_raw_sw, r2_lp_sw = r2(recon_raw[:sw], orig[:sw]), r2(recon_lp[:sw], orig[:sw]) + print(f"reconstruction R2 vs original (full {n*dt*1e3:.0f}ms): raw={r2_raw:+.3f} lp={r2_lp:+.3f}") + print(f"reconstruction R2 vs original (first 10ms): raw={r2_raw_sw:+.3f} lp={r2_lp_sw:+.3f}") + + A = 5 * float(np.std(orig)) # y-limit at the data scale (low-pass recon diverges past this) + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(13, 7)) + tt = t_ms[:n] + ax1.plot(tt, orig, color="tab:orange", lw=1.0, label="original") + ax1.plot(tt, recon_raw, color="tab:green", lw=1.0, alpha=0.7, + label=f"recon, RAW drives (100ms R²={r2_raw:.2f})") + ax1.plot(tt, recon_lp, color="tab:blue", lw=1.0, alpha=0.8, + label=f"recon, LOW-PASS <{args.cutoff_hz:.0f}Hz drives (diverges)") + ax1.set_ylim(-A, A) + ax1.set_title("teacher-forced reconstruction (open-loop): raw drives track; low-pass drives diverge") + ax1.set_xlabel("time (ms)"); ax1.set_ylabel("a.u."); ax1.legend(fontsize=8, loc="upper right") + + # zoom 
mid-window (skip the edge-IC startup transient) + s0 = min(n - 1, int(round(0.04 / dt))) + zsl = slice(s0, min(n, s0 + int(round(0.02 / dt)))) + ax2.plot(tt[zsl], orig[zsl], color="tab:orange", lw=1.6, label="original") + ax2.plot(tt[zsl], recon_raw[zsl], color="tab:green", lw=1.4, alpha=0.8, label="raw-drive recon") + ax2.plot(tt[zsl], recon_lp[zsl], color="tab:blue", lw=1.2, alpha=0.7, label="low-pass-drive recon") + ax2.set_ylim(-A, A) + ax2.set_title(f"20 ms zoom (mid-window): raw tracks pitch/amplitude; low-pass <{args.cutoff_hz:.0f}Hz does not") + ax2.set_xlabel("time (ms)"); ax2.set_ylabel("a.u."); ax2.legend(fontsize=8, loc="upper right") + + fig.suptitle(f"teacher-forced ẍ R² vs data: raw drives {r2(d2_raw, data_d2):+.2f} | " + f"low-pass <{args.cutoff_hz:.0f}Hz {r2(d2_lp, data_d2):+.2f}", y=1.00) + + fig.tight_layout() + out_path = os.path.join(os.path.abspath(args.out_dir), "lowpass_drive_reconstruction.png") + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved figure to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/plot_poly_compare.py b/examples/plot_poly_compare.py new file mode 100644 index 0000000..87025d2 --- /dev/null +++ b/examples/plot_poly_compare.py @@ -0,0 +1,113 @@ +""" +Compare, for the polynomial Ouroboros (trained with drive low-pass), the target vocalization +against its teacher-forced and autonomous reconstructions -- waveform + spectrogram. + +3x2 grid: rows = [target, teacher-forced recon, autonomous recon]; cols = [waveform, spectrogram] + + - teacher-forced: integrate the model's predicted ẍ (computed at the TRUE states) -> integrate_model_d2 + - autonomous: integrate the ODE feeding the generated state back -> integrate_poly_autonomous + +Optionally injects additive noise into the autonomous waveform to match the target's noise floor. 
+ +Run from the repo root: + python -m examples.plot_poly_compare --model-dir poly_lp/poly --data-dir data500/gabo_p0 \ + --start-offset-ms 50 --n 4000 +""" + +import argparse +import os + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +import numpy as np +from scipy.io import wavfile +from scipy.signal import spectrogram, welch + +from train.train import load_model +from train.eval import integrate_model_d2, integrate_poly_autonomous, correct + +plt.rcParams["text.usetex"] = False + + +def peak_freq(x, sr): + f, P = welch(x - np.mean(x), fs=sr, nperseg=min(1024, len(x))) + P[0] = 0 + return float(f[np.argmax(P)]) + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--model-dir", default="poly_lp/poly") + p.add_argument("--data-dir", default="data500/gabo_p0") + p.add_argument("--voc", type=int, default=0) + p.add_argument("--vocalization", type=int, default=0) + p.add_argument("--start-offset-ms", type=float, default=50.0) + p.add_argument("--n", type=int, default=4000) + p.add_argument("--noise-sd", type=float, default=0.0, + help="additive noise on the autonomous waveform (match target noise floor)") + p.add_argument("--fmax", type=float, default=10000.0) + p.add_argument("--out-dir", default="poly_lp") + args = p.parse_args() + + tag = f"gabo_artificial_{args.voc}" + model, _, _, epoch = load_model(args.model_dir) + model.eval() + dt = model.tau + print(f"loaded {type(model).__name__} (epoch {epoch}); drive_lowpass_ms={getattr(model,'drive_lowpass_ms',0.0)}") + + sr, audio_full = wavfile.read(os.path.join(args.data_dir, f"{tag}.wav")) + audio_full = audio_full.astype(np.float64) + onoffs = np.atleast_2d(np.loadtxt(os.path.join(args.data_dir, f"{tag}.txt"))) + on_s = onoffs[args.vocalization][0] + start = int(round(on_s * sr)) + int(round(args.start_offset_ms / 1e3 * sr)) + seg = audio_full[start:start + args.n] + print(f"reconstructing {len(seg)} samples ({len(seg)*dt*1e3:.0f} ms) from 
+{args.start_offset_ms:.0f}ms after onset") + + print("teacher-forced reconstruction ...", flush=True) + tf = integrate_model_d2(model, seg, dt, smoothing=True, verbose=False) + print(f"autonomous reconstruction (process noise_sd={args.noise_sd}) ...", flush=True) + auto = integrate_poly_autonomous(model, seg, dt, noise_sd=args.noise_sd, verbose=False) + + target = correct(seg) # detrend the target like the reconstructions + nmin = min(len(target), len(tf), len(auto)) + target, tf, auto = target[:nmin], tf[:nmin], auto[:nmin] + rows = [("target", target), ("teacher-forced", tf), ("autonomous", auto)] + for name, w in rows: + fin = np.isfinite(w).all() + pf = peak_freq(w, sr) if fin else float("nan") + print(f" {name:14s}: peak {pf:7.0f} Hz range [{np.nanmin(w):.3f},{np.nanmax(w):.3f}] finite={fin}") + + t_ms = np.arange(nmin) * dt * 1e3 + A = max(np.nanstd(target), np.nanstd(tf)) * 6 # shared waveform y-scale (data-ish) + fig, axes = plt.subplots(3, 2, figsize=(15, 9)) + + def spec(x): + npg = min(256, len(x)) + f, tt, S = spectrogram(np.nan_to_num(x), fs=sr, nperseg=npg, noverlap=int(npg * 0.9)) + return f, tt * 1e3, 10 * np.log10(S + 1e-12) + + smax = max(spec(w)[2].max() for _, w in rows) + colors = {"target": "tab:orange", "teacher-forced": "tab:green", "autonomous": "tab:blue"} + for r, (name, w) in enumerate(rows): + axes[r, 0].plot(t_ms, w, color=colors[name], lw=0.8) + axes[r, 0].set_ylim(-A, A) + axes[r, 0].set_ylabel("a.u."); axes[r, 0].set_title(f"{name} waveform (peak {peak_freq(np.nan_to_num(w), sr):.0f} Hz)") + f, tt, S = spec(w) + pcm = axes[r, 1].pcolormesh(tt, f, S, shading="auto", vmin=smax - 80, vmax=smax, cmap="magma") + axes[r, 1].set_ylim(0, args.fmax); axes[r, 1].set_ylabel("freq (Hz)"); axes[r, 1].set_title(f"{name} spectrogram") + fig.colorbar(pcm, ax=axes[r, 1], label="dB") + axes[2, 0].set_xlabel("time (ms)"); axes[2, 1].set_xlabel("time (ms)") + fig.suptitle(f"poly Ouroboros (drive low-pass {getattr(model,'drive_lowpass_ms',0.0)}ms): " + 
f"target vs teacher-forced vs autonomous | {tag}", y=1.00) + fig.tight_layout() + out_path = os.path.join(os.path.abspath(args.out_dir), "poly_target_tf_autonomous.png") + fig.savefig(out_path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved figure to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/train_poly_lowpass.py b/examples/train_poly_lowpass.py new file mode 100644 index 0000000..da13a21 --- /dev/null +++ b/examples/train_poly_lowpass.py @@ -0,0 +1,110 @@ +""" +Train the ORIGINAL polynomial Ouroboros with drive low-pass "in the loop" (omega, gamma, +and the kernel weights are Gaussian low-passed at drive_lowpass_ms), monitored toward high R^2. + +Same idea the other Claude applied to ArneodoOuroboros, pushed back to the original model so we +can test whether slow drives improve autonomous reconstruction for the polynomial parameterization. + +Run from the repo root, e.g.: + python -m examples.train_poly_lowpass --data-glob './data500/gabo_p*' --out-dir ./poly_lp \ + --drive-lowpass-ms 2.0 --d-state 4 --epochs 60 --seg 10 --batch-size 8 +""" + +import argparse +import glob +import os + +import numpy as np +import torch +from torch.optim import Adam +from torch.optim.lr_scheduler import ReduceLROnPlateau + +from data.load_data import get_segmented_audio +from data.data_utils import get_loaders +from model.model import Ouroboros +from model.kernels import fullPolyModule +from train.train import train, save_model +from train.eval import eval_model_error +from utils import sse + + +def gather_loaders(data_dirs, max_vocs, context_len, batch_size, seed, n_jobs): + chunks, sr = [], None + per = max(1, max_vocs // len(data_dirs)) + for d in data_dirs: + audio, sr = get_segmented_audio(d, d, max_vocs=per, context_len=context_len, seed=seed, + training=True, extend=True, shuffle_order=True) + chunks += audio + print(f"gathered {len(chunks)} chunks from {len(data_dirs)} dir(s); sr={sr}", flush=True) + dt = 1 / sr + dls = 
get_loaders(np.stack(chunks, 0), num_workers=n_jobs, batch_size=batch_size, + train_size=0.6, cv=True, seed=seed, dt=dt) + return dls, dt + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--data-glob", required=True) + p.add_argument("--out-dir", default="./poly_lp") + p.add_argument("--max-vocs", type=int, default=100000) + p.add_argument("--context-len", type=float, default=0.25) + p.add_argument("--batch-size", type=int, default=8) + p.add_argument("--epochs", type=int, default=60) + p.add_argument("--seg", type=int, default=10) + p.add_argument("--target-r2", type=float, default=0.999) + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--n-layers", type=int, default=3) + p.add_argument("--d-state", type=int, default=4) + p.add_argument("--d-conv", type=int, default=4) + p.add_argument("--expand-factor", type=int, default=10) + p.add_argument("--n-kernels", type=int, default=15) + p.add_argument("--lam", type=float, default=1.2, help="kernel-weight regularization base") + p.add_argument("--drive-lowpass-ms", type=float, default=2.0) + p.add_argument("--seed", type=int, default=1234) + p.add_argument("--n-jobs", type=int, default=8) + args = p.parse_args() + + run_dir = os.path.join(os.path.abspath(args.out_dir), "poly") + os.makedirs(run_dir, exist_ok=True) + data_dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + assert data_dirs, f"no dirs matched {args.data_glob}" + + dls, dt = gather_loaders(data_dirs, args.max_vocs, args.context_len, args.batch_size, + args.seed, args.n_jobs) + + kernel = fullPolyModule(nTerms=args.n_kernels, device="cuda", x_dim=1, z_dim=2, + activation=lambda x: x, lam=args.lam) + model = Ouroboros(d_data=1, kernel=kernel, n_layers=args.n_layers, d_state=args.d_state, + d_conv=args.d_conv, expand_factor=args.expand_factor, tau=dt, + drive_lowpass_ms=args.drive_lowpass_ms) + n_params = sum(q.numel() for q in model.parameters()) + print(f"poly Ouroboros: 
n_kernels={args.n_kernels} d_state={args.d_state} expand={args.expand_factor} " + f"lowpass={args.drive_lowpass_ms}ms lam={args.lam} -> {n_params} params; tau={dt:.2e}", flush=True) + + opt = Adam(model.parameters(), lr=args.lr) + scheduler = ReduceLROnPlateau(opt, factor=0.5, patience=max(args.seg, 3), min_lr=1e-10) + model_info = {"n layers": args.n_layers, "d state": args.d_state, + "d conv": args.d_conv, "expand factor": args.expand_factor} + + best = -np.inf + for start in range(0, args.epochs, args.seg): + end = min(args.epochs, start + args.seg) + train(model, opt, loss_fn=lambda y, yhat: sse(yhat, y, reduction="mean"), + loaders=dls, scheduler=scheduler, nEpochs=end, val_freq=1, runDir=run_dir, + dt=dt, vis_freq=0, smoothing=False, reg_weights=True, start_epoch=start, + save_freq=max(args.seg, 1), model_info=model_info) + model.eval() + with torch.no_grad(): + (tr, te), (trsd, tesd), _ = eval_model_error(dls, model, dt=dt, comparison="test") + print(f"[epoch {end}] train R2={tr:.4f} test R2={te:.4f}", flush=True) + save_model(model, opt, os.path.join(run_dir, f"checkpoint_{end}.tar"), + n_layers=args.n_layers, d_state=args.d_state, d_conv=args.d_conv, + expand_factor=args.expand_factor) + best = max(best, te) + if te >= args.target_r2: + break + print(f"DONE. best test R2 = {best:.4f}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/model/model.py b/model/model.py index 918b85f..615d30b 100644 --- a/model/model.py +++ b/model/model.py @@ -38,6 +38,7 @@ def __init__( device: str = "cuda", tau: float = 1 / 10000, smooth_len: float = 0.001, + drive_lowpass_ms: float = 0.0, ): super().__init__() @@ -82,10 +83,34 @@ def __init__( self.tau = tau self.smooth_len = smooth_len + # if > 0, low-pass the drives omega(t), gamma(t), and the kernel weights w(t) with a + # zero-phase Gaussian (sigma = drive_lowpass_ms) -- "low-pass in the loop". Matches the + # mechanism added to ArneodoOuroboros so the drives stay slow / control-rate. 
+ self.drive_lowpass_ms = drive_lowpass_ms self.kernel = kernel self.kernel.tau = self.tau self.names = [r"$\omega$", r"$\gamma$", "weighted kernels", "states"] + def _lowpass(self, x: torch.FloatTensor, dt: float) -> torch.FloatTensor: + """centered zero-phase Gaussian low-pass along time of a (B, L, C) control series.""" + sigma = (self.drive_lowpass_ms / 1e3) / dt # samples + if sigma <= 0: + return x + radius = max(1, int(round(3 * sigma))) + t = torch.arange(-radius, radius + 1, device=x.device, dtype=x.dtype) + kern = torch.exp(-0.5 * (t / sigma) ** 2) + kern = kern / kern.sum() + C = x.shape[-1] + k = kern.view(1, 1, -1).expand(C, 1, -1) + xc = F.pad(x.transpose(1, 2), (radius, radius), mode="reflect") + return F.conv1d(xc, k, groups=C).transpose(1, 2) + + def _lowpass_weights(self, weights: torch.FloatTensor, dt: float) -> torch.FloatTensor: + """low-pass the polynomial kernel weights (B, L, P, P) along time.""" + B, L, P, P2 = weights.shape + w = self._lowpass(weights.reshape(B, L, P * P2), dt) + return w.reshape(B, L, P, P2) + def forward( self, x: torch.FloatTensor, @@ -134,11 +159,18 @@ def forward( omegaControl ).abs() # Since we take omega^2 anyway, we take the absolute value to prevent things from switching around too much gamma = self.gamma_net(gammaControl) - if smoothing: + weighted_kernels, weights = self.kernel(z, kernelControl) + if self.drive_lowpass_ms > 0: + # low-pass the drives in the loop, then recompute the nonlinearity from the + # low-passed weights so yhat is consistent with the (slow) drives + omega = self._lowpass(omega, dt) + gamma = self._lowpass(gamma, dt) + weights = self._lowpass_weights(weights, dt) + weighted_kernels = self.kernel.forward_given_weights(z, weights) + elif smoothing: # smooth our model functions, if we choose to do so. I do not. 
omega = smooth(omega, smooth_len) gamma = smooth(gamma, smooth_len) - weighted_kernels, weights = self.kernel(z, kernelControl) z1 = z[:, :, :1] z2 = z[:, :, 1:] @@ -211,13 +243,17 @@ def get_funcs( omega = self.omega_net(omegaControl).abs() gamma = self.gamma_net(gammaControl) + weighted_kernels, weights = self.kernel(z, kernelControl) - if smoothing: + if self.drive_lowpass_ms > 0: + omega = self._lowpass(omega, dt) + gamma = self._lowpass(gamma, dt) + weights = self._lowpass_weights(weights, dt) + weighted_kernels = self.kernel.forward_given_weights(z, weights) + elif smoothing: omega = smooth(omega.abs(), smooth_len) gamma = smooth(gamma, smooth_len) - weighted_kernels, weights = self.kernel(z, kernelControl) - return ( omega, gamma, @@ -430,6 +466,7 @@ def __init__( tau: float = 1 / 10000, smooth_len: float = 0.001, gamma_init: float = 1.0, + drive_lowpass_ms: float = 0.0, ): super().__init__() @@ -491,6 +528,11 @@ def __init__( self.tau = tau self.smooth_len = smooth_len + # if >0, hard-constrain the control series alpha/beta/delta to vary no faster than + # this timescale (ms) by low-pass filtering the Mamba heads' outputs (always, train + # + eval). Encodes a known physiological control-rate prior and prevents the drives + # from carrying carrier-frequency content. Gaussian sigma = drive_lowpass_ms. + self.drive_lowpass_ms = drive_lowpass_ms self.names = [r"$\alpha$", r"$\beta$", r"$\delta$", r"$\gamma$"] @property @@ -498,6 +540,22 @@ def gamma(self) -> torch.FloatTensor: """rescaled, strictly-positive time-scaling scalar g = softplus(log_gamma).""" return F.softplus(self.log_gamma) + def _lowpass(self, x: torch.FloatTensor, dt: float) -> torch.FloatTensor: + """ + centered, zero-phase Gaussian low-pass along time (sigma = self.drive_lowpass_ms), + applied to a control series x of shape (B, L, 1). Differentiable (fixed kernel); + reflection-padded to avoid edge artifacts / phase delay. 
+ """ + sigma = (self.drive_lowpass_ms / 1e3) / dt # samples + if sigma <= 0: + return x + radius = max(1, int(round(3 * sigma))) + t = torch.arange(-radius, radius + 1, device=x.device, dtype=x.dtype) + kern = torch.exp(-0.5 * (t / sigma) ** 2) + kern = (kern / kern.sum()).view(1, 1, -1) + xc = F.pad(x.transpose(1, 2), (radius, radius), mode="reflect") + return F.conv1d(xc, kern).transpose(1, 2) + def _encode( self, x: torch.FloatTensor, @@ -529,7 +587,12 @@ def _encode( beta = self.beta_net(betaControl) delta = self.delta_net(deltaControl) - if smoothing: + if self.drive_lowpass_ms > 0: + # hard timescale constraint: low-pass the drives (always, train + eval) + alpha = self._lowpass(alpha, dt) + beta = self._lowpass(beta, dt) + delta = self._lowpass(delta, dt) + elif smoothing: smooth_len = int(round(self.smooth_len / dt)) alpha = smooth(alpha, smooth_len) beta = smooth(beta, smooth_len) diff --git a/train/eval.py b/train/eval.py index 80422e6..de7c4c5 100644 --- a/train/eval.py +++ b/train/eval.py @@ -333,6 +333,110 @@ def dz_hat(s, z): return x_gen +def integrate_poly_autonomous( + model: torch.nn.Module, + audio: np.ndarray, + dt: float, + method: str = "rk4", + detrend: bool = True, + noise_sd: float = 0.0, + seed: int = 0, + verbose: bool = True, +) -> np.ndarray: + """ + fully autonomous (closed-loop) integration of a polynomial `Ouroboros`. + + Precomputes the drives omega(t), gamma(t) and the kernel weights w(t) from `audio` (these are + low-passed inside `get_funcs` if the model was trained with `drive_lowpass_ms`), then integrates + + dx/ds = x' + dx'/ds = -omega(s)^2 x - gamma(s) x' - kernel(x, x'; w(s)) + + in the model's rescaled time s = t/tau, feeding the generated state (x, x') back into the + nonlinearity each step. Only the drives (from data) and the initial condition come from outside. 
+ """ + + L = len(audio) + t_steps = np.arange(0, L * dt + dt / 2, dt)[:L] + s_steps = t_steps / model.tau + + audio_3d = audio[None, :, None] + dy = deriv_approx_dy(audio_3d) + audio_t = torch.from_numpy(audio_3d).to(torch.float32).to("cuda") + dy_t = torch.from_numpy(dy).to(torch.float32).to("cuda") + + with torch.no_grad(): + omega, gamma, _, weights, _ = model.get_funcs(audio_t, dy_t, dt) + omega = omega.detach().cpu().numpy().squeeze() + gamma = gamma.detach().cpu().numpy().squeeze() + weights = weights.detach().cpu().numpy() # (1, L, P, P) + _, _, P, P2 = weights.shape + w_flat = weights.reshape(L, P * P2) + + omega_interp = make_interp_spline(s_steps, omega) + gamma_interp = make_interp_spline(s_steps, gamma) + w_interp = make_interp_spline(s_steps, w_flat) # vector-valued spline over time + + x0 = float(audio[0]) + xp0 = (model.tau / dt) * float(dy[0, 0, 0]) + ic = torch.tensor([x0, xp0], dtype=torch.float32, device="cuda") + kernel = model.kernel + + if noise_sd > 0: + # stochastic forcing (Euler-Maruyama on the velocity) to sustain a noise-driven + # oscillation at the data amplitude. RK4 drift per sample (drives held constant over + # the 1-sample step, fine since they are low-passed), plus additive noise on x'. 
+ rng = np.random.default_rng(seed) + x, xp = x0, xp0 + xs = [x] + ww = weights.reshape(L, 1, 1, P, P2) + for k in range(L - 1): + om, ga, wk = omega[k], gamma[k], ww[k] + + def f(xx, vv): + kern = float(kernel.forward_given_weights_numpy(np.array([[[xx, vv]]]), wk).squeeze()) + return vv, -(om**2) * xx - ga * vv - kern + + k1x, k1v = f(x, xp) + k2x, k2v = f(x + 0.5 * k1x, xp + 0.5 * k1v) + k3x, k3v = f(x + 0.5 * k2x, xp + 0.5 * k2v) + k4x, k4v = f(x + k3x, xp + k3v) + x = x + (k1x + 2 * k2x + 2 * k3x + k4x) / 6 + xp = xp + (k1v + 2 * k2v + 2 * k3v + k4v) / 6 + noise_sd * rng.standard_normal() + xs.append(x) + x_gen = np.array(xs) + return correct(x_gen) if detrend else x_gen + + def dz_hat(s, z): + if verbose: + print( + f"{(s - s_steps[0]) / (s_steps[-1] - s_steps[0]) * 100:0.3f}%,", + end="\r", + ) + s_np = s.detach().cpu().numpy() + om = float(omega_interp(s_np)) + ga = float(gamma_interp(s_np)) + w = w_interp(s_np).reshape(1, 1, P, P2) + x = float(z[0]) + xp = float(z[1]) + kern = float( + kernel.forward_given_weights_numpy(np.array([[[x, xp]]]), w).squeeze() + ) + dxp = -(om**2) * x - ga * xp - kern + return torch.tensor([xp, dxp], dtype=torch.float32, device=z.device) + + eval_times = torch.from_numpy(s_steps).to(ic.device) + with torch.no_grad(): + sol = odeint_adjoint( + dz_hat, ic, eval_times, adjoint_params=(), method=method, options=dict() + ).transpose(0, 1) + + x_gen = sol[0].detach().cpu().numpy().squeeze() + if detrend: + x_gen = correct(x_gen) + return x_gen + + def eval_model_error( dls: dict, model: torch.nn.Module, dt: float, comparison: str = "val" ) -> tuple[tuple[float, float], tuple[float, float], tuple[np.ndarray, np.ndarray]]: diff --git a/train/train.py b/train/train.py index 09216ca..524a46c 100644 --- a/train/train.py +++ b/train/train.py @@ -67,6 +67,7 @@ def save_model( "d_conv": d_conv, "expand_factor": expand_factor, "parameterization": parameterization, + "drive_lowpass_ms": getattr(model, "drive_lowpass_ms", 0.0), } try: 
sd["n_kernel"] = model.kernel.nTerms @@ -124,6 +125,7 @@ def load_model( expand_factor=expand_factor, tau=sd["tau"], smooth_len=sd["smooth_len"], + drive_lowpass_ms=sd.get("drive_lowpass_ms", 0.0), ) else: try: @@ -147,6 +149,7 @@ def load_model( tau=sd["tau"], smooth_len=sd["smooth_len"], kernel=kernel, + drive_lowpass_ms=sd.get("drive_lowpass_ms", 0.0), ) except: print("no kernel in savefile!") From 78d8ff4732022d87267330edfc2fdb4b5d75fd95 Mon Sep 17 00:00:00 2001 From: John Pearson Date: Wed, 27 May 2026 19:17:02 -0400 Subject: [PATCH 05/15] Default Arneodo drive low-pass to 1 ms The drive low-pass regularizer is now on by default at 1 ms for the Arneodo parameterization: ArneodoOuroboros.drive_lowpass_ms defaults to 1.0, and train_arneodo / train_model / train_arneodo_big default the knob to 1 ms. A 1-5 ms sweep showed teacher-forced R2 is flat (~0.75) and the drives are slow throughout; 1 ms gave the best autonomous (noise-sustained) spectral fidelity and cold-start stability (all 1-5 ms checkpoints are cold-start bounded; none self-sustain deterministically, but all noise-sustain at the right pitch). Set drive_lowpass_ms=0.0 for the unregularized model (higher one-step R2, but fast drives and an unstable cold start). The polynomial Ouroboros default is left at 0.0. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/train_arneodo_big.py | 7 +++++-- model/model.py | 6 +++++- train/model_cv.py | 5 +++++ train/train_model.py | 5 +++++ 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/examples/train_arneodo_big.py b/examples/train_arneodo_big.py index f6e68c6..9616a51 100644 --- a/examples/train_arneodo_big.py +++ b/examples/train_arneodo_big.py @@ -63,6 +63,9 @@ def main(): p.add_argument("--d-state", type=int, default=16) p.add_argument("--d-conv", type=int, default=4) p.add_argument("--expand-factor", type=int, default=10) + p.add_argument("--drive-lowpass-ms", type=float, default=1.0, + help="hard low-pass alpha/beta/delta at this Gaussian timescale (ms); " + "default 1 ms (slow drives, cold-start stable). 0 disables.") p.add_argument("--seed", type=int, default=1234) p.add_argument("--n-jobs", type=int, default=4) args = p.parse_args() @@ -80,11 +83,11 @@ def main(): model = ArneodoOuroboros( d_data=1, n_layers=args.n_layers, d_state=args.d_state, d_conv=args.d_conv, - expand_factor=args.expand_factor, tau=dt, + expand_factor=args.expand_factor, tau=dt, drive_lowpass_ms=args.drive_lowpass_ms, ) n_params = sum(q.numel() for q in model.parameters()) print(f"model: n_layers={args.n_layers} d_state={args.d_state} d_conv={args.d_conv} " - f"expand={args.expand_factor} -> {n_params} params; tau={dt:.2e}") + f"expand={args.expand_factor} lowpass={args.drive_lowpass_ms}ms -> {n_params} params; tau={dt:.2e}") opt = Adam(model.parameters(), lr=args.lr) scheduler = ReduceLROnPlateau(opt, factor=0.5, patience=max(args.seg, 3), min_lr=1e-10) diff --git a/model/model.py b/model/model.py index 615d30b..d8443c6 100644 --- a/model/model.py +++ b/model/model.py @@ -466,7 +466,7 @@ def __init__( tau: float = 1 / 10000, smooth_len: float = 0.001, gamma_init: float = 1.0, - drive_lowpass_ms: float = 0.0, + drive_lowpass_ms: float = 1.0, ): super().__init__() @@ -532,6 +532,10 @@ def __init__( # this timescale (ms) by low-pass 
filtering the Mamba heads' outputs (always, train # + eval). Encodes a known physiological control-rate prior and prevents the drives # from carrying carrier-frequency content. Gaussian sigma = drive_lowpass_ms. + # Defaults to 1 ms: across a 1-5 ms sweep R2 is flat (~0.75) and the drives are slow + # either way, and 1 ms gave the best autonomous (noise-sustained) spectral fidelity + # and cold-start stability. Set 0.0 for the unregularized model (higher one-step R2, + # but fast/spiky drives and an unstable cold start). self.drive_lowpass_ms = drive_lowpass_ms self.names = [r"$\alpha$", r"$\beta$", r"$\delta$", r"$\gamma$"] diff --git a/train/model_cv.py b/train/model_cv.py index 01aad2d..401ef85 100644 --- a/train/model_cv.py +++ b/train/model_cv.py @@ -227,6 +227,7 @@ def train_arneodo( smooth_len: float = 0.001, model_path: str = "", save_freq: int = 5, + drive_lowpass_ms: float = 1.0, ) -> torch.nn.Module: """ trains a single `ArneodoOuroboros` model (the biomechanical syrinx parameterization). @@ -249,6 +250,9 @@ def train_arneodo( - smooth_len: smoothing length for model functions (not used in training) - model_path: place to save the model and training artifacts - save_freq: how often (in epochs) to checkpoint the model + - drive_lowpass_ms: hard low-pass timescale (ms) on the alpha/beta/delta drives. + Defaults to 1 ms (slow, physiological drives + cold-start-stable autonomous + dynamics, at a teacher-forced R^2 cost). Set 0.0 for the unregularized model. 
returns ----- @@ -270,6 +274,7 @@ def train_arneodo( expand_factor=expand_factor, tau=tau, smooth_len=smooth_len, + drive_lowpass_ms=drive_lowpass_ms, ) opt = Adam(model.parameters(), lr=lr) diff --git a/train/train_model.py b/train/train_model.py index 9d06bb8..e9fa4ab 100644 --- a/train/train_model.py +++ b/train/train_model.py @@ -28,6 +28,7 @@ def train_model( d_state: int = 1, d_conv: int = 4, expand_factor: int = 10, + drive_lowpass_ms: float = 1.0, ) -> torch.nn.Module: """ function for training a model. takes audio from @@ -58,6 +59,9 @@ def train_model( parallel scan allocates ~batch * npo2(2*seq) * 2*expand_factor * d_state. d_conv: width of the mamba convolutional kernel expand_factor: channel expansion from audio to mamba input + drive_lowpass_ms: (arneodo only) hard low-pass timescale (ms) on the + alpha/beta/delta drives; default 1 ms gives slow, physiological drives and + cold-start-stable autonomous dynamics. Set 0.0 for the unregularized model. returns -------- best model after hyperparameter cross-validation @@ -116,6 +120,7 @@ def train_model( tau=dt, model_path=model_dir, save_freq=save_freq, + drive_lowpass_ms=drive_lowpass_ms, ) else: best_model = model_cv_lambdas( From b0f897c84f610b7c50c9e5610d7e43e773fb6953 Mon Sep 17 00:00:00 2001 From: John Pearson Date: Wed, 27 May 2026 21:46:29 -0400 Subject: [PATCH 06/15] plot_drives: configurable --model-dir / --data-dir Lets the learned-vs-true-drives comparison target any checkpoint and gabo dataset, not just the conventional /{model/arneodo,gabo_data} layout (matches the same options on plot_autonomous.py). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/plot_drives.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/plot_drives.py b/examples/plot_drives.py index 208f6d0..67dade9 100644 --- a/examples/plot_drives.py +++ b/examples/plot_drives.py @@ -63,11 +63,13 @@ def main(): parser.add_argument("--pad-ms", type=float, default=20.0, help="window padding each side (ms)") parser.add_argument("--cutoff-hz", type=float, default=50.0, help="low-pass cutoff for the learned drives' slow component") + parser.add_argument("--model-dir", default=None, help="checkpoint dir (default /model/arneodo)") + parser.add_argument("--data-dir", default=None, help="gabo data dir (default /gabo_data)") args = parser.parse_args() out_dir = os.path.abspath(args.out_dir) - data_dir = os.path.join(out_dir, "gabo_data") - model_dir = os.path.join(out_dir, "model", "arneodo") + data_dir = os.path.abspath(args.data_dir) if args.data_dir else os.path.join(out_dir, "gabo_data") + model_dir = os.path.abspath(args.model_dir) if args.model_dir else os.path.join(out_dir, "model", "arneodo") tag = f"gabo_artificial_{args.voc}" model, _, _, epoch = load_model(model_dir) From 78aa725db824dcb003e5af33e22651ca7317246f Mon Sep 17 00:00:00 2001 From: John Pearson Date: Thu, 28 May 2026 20:29:20 -0400 Subject: [PATCH 07/15] Autonomous reconstruction: drive low-pass, seed-selection + amplitude rescale, Floquet diagnosis Investigate and address the free-amplitude problem in autonomous (closed-loop) generation of the polynomial Ouroboros, and fold the resulting recipe into the pipeline. Core: - model: drive low-pass in the loop (drive_lowpass_ms) on omega/gamma/kernel weights; optional (0,0) "alpha" forcing term (keep_const, zero-init) in fullPolyModule; train() total_loss fix for the unregularized path. 
- eval: integrate_poly_autonomous (closed-loop poly RHS), autonomy_score (phase-robust validation metric, with rescale option), generate_autonomous (deployed rescaled generation). - model_cv_lambdas: multi-seed training + selection by (rescaled) validation autonomy; keep_const; single-lambda mode (lambda is irrelevant for autonomy, the seed dominates). Pipeline (run_lambda_pipeline.py): production recipe -- fix lambda, train several seeds, select the best seed on a held-out shard by rescaled autonomy, keep the (0,0) term, and emit a rescaled autonomous reconstruction (selected_autonomous_recon.wav + selected_model.json). Diagnosis + negative results (docs/autonomous_amplitude.md): autonomous amplitude is a marginal, untrained transverse direction (Floquet exponent Lambda = closed integral of df/dx' left free by on-orbit fitting; floquet_amplitude_diagnostic.py confirms Lambda predicts rollout decay/growth at r~=1.0 and scatters by seed). Divergence penalties and noisy-rollout fine-tunes (finetune_attractor_poly.py) and spectral/envelope refine (rollout_refine.py, finetune_rollout_spectral_poly.py) do not robustly fix amplitude; output-amplitude rescaling + seed selection is the reliable recipe. 
Co-Authored-By: Claude Opus 4.7 --- .gitignore | 11 + docs/autonomous_amplitude.md | 111 ++++++++++ examples/compare_lambda_autonomous.py | 123 +++++++++++ examples/eval_alpha_autonomy.py | 69 ++++++ examples/eval_autonomy_rescale.py | 129 +++++++++++ examples/finetune_attractor_poly.py | 166 ++++++++++++++ examples/finetune_rollout_poly.py | 142 ++++++++++++ examples/finetune_rollout_spectral_poly.py | 61 ++++++ examples/floquet_amplitude_diagnostic.py | 132 +++++++++++ examples/run_alpha_multiseed.py | 158 ++++++++++++++ examples/run_lambda_pipeline.py | 130 +++++++++++ examples/run_poly_rollout_pipeline.py | 146 +++++++++++++ examples/sweep_lowpass.py | 97 ++++++++ examples/train_poly_lowpass.py | 6 +- examples/train_poly_rollout.py | 129 +++++++++++ model/kernels.py | 12 +- model/model.py | 24 +- train/eval.py | 121 ++++++++++ train/model_cv.py | 243 +++++++++------------ train/rollout_refine.py | 173 +++++++++++++++ train/train.py | 2 + 21 files changed, 2039 insertions(+), 146 deletions(-) create mode 100644 docs/autonomous_amplitude.md create mode 100644 examples/compare_lambda_autonomous.py create mode 100644 examples/eval_alpha_autonomy.py create mode 100644 examples/eval_autonomy_rescale.py create mode 100644 examples/finetune_attractor_poly.py create mode 100644 examples/finetune_rollout_poly.py create mode 100644 examples/finetune_rollout_spectral_poly.py create mode 100644 examples/floquet_amplitude_diagnostic.py create mode 100644 examples/run_alpha_multiseed.py create mode 100644 examples/run_lambda_pipeline.py create mode 100644 examples/run_poly_rollout_pipeline.py create mode 100644 examples/sweep_lowpass.py create mode 100644 examples/train_poly_rollout.py create mode 100644 train/rollout_refine.py diff --git a/.gitignore b/.gitignore index 6691fda..8aaee63 100644 --- a/.gitignore +++ b/.gitignore @@ -170,4 +170,15 @@ arneodo_ft/ arneodo_lp/ arneodo_sweep/ poly_lp/ +poly_lp_sweep/ +poly_ft_*/ +poly_pipeline/ +poly_pipeline_test/ data500/ 
+poly_alpha/ +poly_alpha_ms/ +poly_ftspec_*/ +poly_ro_pipeline/ +poly_ro_*/ +poly_ftspec_env10_*/ +poly_attr_*/ diff --git a/docs/autonomous_amplitude.md b/docs/autonomous_amplitude.md new file mode 100644 index 0000000..691b288 --- /dev/null +++ b/docs/autonomous_amplitude.md @@ -0,0 +1,111 @@ +# Autonomous reconstruction and the free-amplitude problem + +This note documents why the polynomial `Ouroboros` reconstructs vocalizations near-perfectly under +teacher forcing but produces **poorly-constrained amplitude** when run fully autonomously, what fixes +were tried, and the recipe we landed on. + +## Setup + +The model parameterizes the second derivative of audio as a driven nonlinear oscillator, + +``` +ẍ = -ω(t)² x - γ(t) ẋ - Σ_{ij} w_{ij}(t) xⁱ ẋʲ +``` + +with the drives ω(t), γ(t), w(t) produced by Mamba encoders and low-passed in the loop +(`drive_lowpass_ms`). Training is **teacher-forced**: one-step prediction of ẍ at the true states. +For deployment we also want **autonomous** generation — integrate the ODE feeding the *generated* +state back, with only the drives (precomputed from data) and the initial condition supplied externally +(`train.eval.integrate_poly_autonomous`). + +On Mindlin/gabo synthetic data, teacher-forced fit is strong (test R² ≈ 0.73 with the 1 ms drive +low-pass; ≈ 0.99 without it), and the autonomous **pitch and spectral shape** track the target well. +The **amplitude**, however, is unreliable: the free run decays, grows, or self-sustains depending on +the seed, with the overall loudness off by factors of ~2–5. + +## Diagnosis: amplitude is a marginal, untrained direction + +The training data lies on a 1-D closed orbit (the limit cycle) in the 2-D (x, ẋ) phase plane. +Pointwise ẍ-matching pins the vector field's **tangential** (along-orbit: phase/frequency) component +but says almost nothing about its **transverse** (radial/amplitude) component — the data never leaves +the orbit. Amplitude stability is entirely a transverse property. 
+ +Linearizing around the periodic orbit gives two Floquet multipliers, and they are exactly our two +failure modes: + +- One multiplier is structurally **= 1**, eigenvector along the flow → phase is marginal (this is the + phase drift that breaks pointwise rollout losses). +- The other governs amplitude: `exp(Λ)` with the area-contraction exponent + **`Λ = ∮ ∂f/∂ẋ dt`** (the net per-cycle effective damping). `Λ<0` attracting, `Λ=0` neutral + (conservative / SHO-like, amplitude free), `Λ>0` repelling. + +On-orbit fitting leaves `Λ` essentially free. Confirmed empirically with `examples/floquet_amplitude_diagnostic.py`: +across seeds, `Λ` computed along the data orbit predicts the actual rollout decay/growth at **r ≈ 1.0**, +`|Λ/cycle| ≲ 0.02` (near-neutral — *not* an attractor; a true limit cycle would be 1–2 orders of +magnitude more contracting), and `Λ` **scatters around zero with both signs across seeds**. That scatter +*is* the seed-dominated amplitude. In short: **the quantity that sets autonomous amplitude is not in the +training signal**, so it is fixed only by random initialization. + +## What we tried + +| Approach | What it constrains | Result | +|---|---|---| +| Constant (0,0) "alpha" forcing term | extra DC drive | R²-neutral; **no** robust autonomy gain (single-seed "win" was a low-pass confound; 5-seed median Δ≈0). Kept on anyway (may help other data). | +| Amplitude **rescale** (post-hoc) | output scale (gauge) | **+0.60**, reliable. Needs a reference loudness at inference. | +| **Envelope**-matching rollout refine | output loudness directly | helped 3/5 seeds, hurt 1/5; net slightly negative, seed-variable. | +| Spectral (multi-res STFT) rollout refine | spectral shape | improves shape but not amplitude; seed-fragile. | +| Pointwise rollout MSE (long horizon) | trajectory, pointwise | fails — phase drift makes MSE reward amplitude collapse. 
| +| **Λ penalty** (`floq`: make the orbit attracting) | divergence sign | converts grow→**decay**; sets stability *type*, not attractor *location*. | +| **Noise** fine-tune (denoise back to orbit) | local contraction | **collapse** (aggressive, Λ→−30) or **blow-up** (gentle, Λ→+30); BPTT through the marginal/expanding rollout is ill-conditioned. | + +Two structural lessons emerged: + +1. **Divergence sign ≠ amplitude.** Driving `Λ<0` on a non-exact orbit contracts toward the model's + *own* (smaller/zero) attractor — it does not place the attractor at the data radius. Amplitude is set + by attractor *location*, a global limit-cycle property, not by the local stability type. +2. **Gradient-based autonomous-rollout fine-tuning is ill-conditioned here.** For most seeds the rollout + is locally expanding (Λ>0), so BPTT through it has exploding gradients; large injected noise overrides + this into over-damping (collapse), small noise lets it blow up. No stable middle was found across + aggressiveness × horizon. + +The only interventions that reliably move autonomous amplitude **constrain the output amplitude +directly** (rescale; envelope), not the vector field. + +## Recipe (what's in the pipeline) + +Because the seed dominates and amplitude is a free gauge: + +1. **Select over seeds, not λ.** Fix λ, train several seeds, pick the best on a held-out validation + shard by **rescaled** autonomy (spectral shape + pitch + boundedness; amplitude gauge removed). +2. **Rescale amplitude at generation.** Run the deterministic closed-loop rollout, then match its RMS + to a reference loudness. + +Both are folded into the pipeline: + +- `train.eval.autonomy_score(..., rescale=True)` — selection metric with amplitude gauge-fixed. +- `train.eval.generate_autonomous(model, audio, dt, rescale=True, ref_rms=...)` — deployed generation. 
+- `train.model_cv.model_cv_lambdas(..., n_seeds=N, selection="autonomy", rescale_autonomy=True, + keep_const=True, lambdas=[λ])` — multi-seed selection by rescaled validation autonomy. +- `examples/run_lambda_pipeline.py` — end-to-end: file-level holdout, seed selection, and a rescaled + autonomous reconstruction (`selected_autonomous_recon.wav` + `selected_model.json`). + +``` +python -m examples.run_lambda_pipeline --data-glob 'data500/gabo_p*' --out-dir ./poly_pipeline \ + --n-epochs 50 --n-seeds 5 --lam 1.068 --drive-lowpass-ms 1.0 --d-state 4 +``` + +## Open direction + +A genuine fix would place an **attracting limit cycle at the data amplitude** — i.e. make the data +orbit a near-exact periodic solution *and* contracting. That is a global property that neither a cheap +divergence penalty nor a (fragile) rollout fine-tune achieves. A phase-robust objective that constrains +the *time-varying loudness envelope* without differentiating through an unstable rollout (e.g. matching +envelopes in a way that is decoupled from the carrier phase) is the most promising untried lever; the +envelope-refine experiments are a first, seed-variable step in that direction. + +## Reproduce the diagnostics + +- `examples/floquet_amplitude_diagnostic.py` — Λ along the data orbit vs. actual rollout decay/growth. +- `examples/eval_autonomy_rescale.py` — raw vs. rescale-bound autonomy + envelope-match metrics. +- `examples/finetune_attractor_poly.py` — the `floq` (Λ-penalty) and `noise` fine-tunes (negative results). +- `examples/finetune_rollout_spectral_poly.py` / `train/rollout_refine.py` — spectral+envelope refine. diff --git a/examples/compare_lambda_autonomous.py b/examples/compare_lambda_autonomous.py new file mode 100644 index 0000000..7a8f900 --- /dev/null +++ b/examples/compare_lambda_autonomous.py @@ -0,0 +1,123 @@ +""" +Compare lambda-sweep poly Ouroboros models on AUTONOMOUS generation (not teacher-forced R^2). 
+ +For each lambda checkpoint, runs closed-loop integration on one or more held-out vocalizations +(deterministic and noise-driven) and plots, vs lambda: + - amplitude (std): deterministic & noise-driven, vs target + - decay ratio (2nd-half / 1st-half std): >1 grows, <1 decays, ~1 self-sustained + - pitch: deterministic resonance & noise-driven, vs target + +Use --n-vocs 1 for a quick draft, larger for a robust (averaged) comparison. + +Run from the repo root: + python -m examples.compare_lambda_autonomous --sweep-dir poly_lp_sweep \ + --data-glob 'data500/gabo_p*' --n-vocs 5 --n 2500 --noise-sd 2.4e-4 +""" + +import argparse +import glob +import os + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +import numpy as np +from scipy.io import wavfile +from scipy.signal import welch + +from train.train import load_model +from train.eval import integrate_poly_autonomous, correct + +plt.rcParams["text.usetex"] = False + + +def pk(x, sr): + x = np.nan_to_num(x) + f, P = welch(x - x.mean(), fs=sr, nperseg=min(1024, len(x))) + P[0] = 0 + return float(f[np.argmax(P)]) + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--sweep-dir", default="poly_lp_sweep") + p.add_argument("--data-glob", default="data500/gabo_p*") + p.add_argument("--n-vocs", type=int, default=1) + p.add_argument("--start-offset-ms", type=float, default=50.0) + p.add_argument("--n", type=int, default=2500, help="samples to integrate per voc") + p.add_argument("--noise-sd", type=float, default=2.4e-4) + p.add_argument("--out-dir", default="poly_lp_sweep") + args = p.parse_args() + + lam_dirs = sorted(glob.glob(os.path.join(args.sweep_dir, "lam_*")), + key=lambda d: float(d.split("lam_")[-1])) + data_dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + # one held-out vocalization per data dir (gabo_artificial_0), up to n_vocs + vocs = [] + for d in data_dirs[: args.n_vocs]: + wav = os.path.join(d, "gabo_artificial_0.wav") + 
seg_txt = os.path.join(d, "gabo_artificial_0.txt") + if os.path.isfile(wav): + vocs.append((wav, seg_txt)) + print(f"{len(lam_dirs)} lambdas x {len(vocs)} vocs; n={args.n} samples", flush=True) + + def load_seg(wav, seg_txt, sr_hint=40000): + sr, af = wavfile.read(wav) + af = af.astype(float) + on = np.atleast_2d(np.loadtxt(seg_txt))[0][0] + s = int(on * sr) + int(args.start_offset_ms / 1e3 * sr) + return sr, af[s:s + args.n] + + lams, det_std, det_decay, det_res, nz_std, nz_pitch, tgt_std, tgt_pitch = ([] for _ in range(8)) + for ld in lam_dirs: + lam = float(ld.split("lam_")[-1]) + model, _, _, _ = load_model(ld); model.eval(); dt = model.tau + ds, dd, dr, ns, npi, ts, tp = [], [], [], [], [], [], [] + for wav, seg_txt in vocs: + sr, seg = load_seg(wav, seg_txt) + tgt = correct(seg) + det = integrate_poly_autonomous(model, seg, dt, noise_sd=0.0, detrend=True, verbose=False) + nz = integrate_poly_autonomous(model, seg, dt, noise_sd=args.noise_sd, seed=0, detrend=True, verbose=False) + h1, h2 = det[: len(det) // 2], det[len(det) // 2:] + ds.append(np.nanstd(det)); dd.append(np.nanstd(h2) / (np.nanstd(h1) + 1e-9)) + dr.append(pk(h1, sr)); ns.append(np.nanstd(nz)); npi.append(pk(nz, sr)) + ts.append(tgt.std()); tp.append(pk(tgt, sr)) + lams.append(lam) + det_std.append(ds); det_decay.append(dd); det_res.append(dr) + nz_std.append(ns); nz_pitch.append(npi); tgt_std.append(ts); tgt_pitch.append(tp) + print(f"lam={lam:.3f}: det std={np.mean(ds):.4f} decay={np.mean(dd):.2f} res={np.mean(dr):.0f}Hz " + f"| noise std={np.mean(ns):.4f} pitch={np.mean(npi):.0f}Hz", flush=True) + + lams = np.array(lams) + mean = lambda L: np.array([np.mean(v) for v in L]) + sem = lambda L: np.array([np.std(v) / max(1, np.sqrt(len(v))) for v in L]) + tgt_s = np.mean([np.mean(v) for v in tgt_std]); tgt_p = np.mean([np.mean(v) for v in tgt_pitch]) + + fig, (a1, a2, a3) = plt.subplots(3, 1, figsize=(9, 9), sharex=True) + a1.errorbar(lams, mean(det_std), sem(det_std), marker="o", color="tab:blue", 
label="deterministic") + a1.errorbar(lams, mean(nz_std), sem(nz_std), marker="s", color="tab:green", label=f"noise-driven (σ={args.noise_sd:g})") + a1.axhline(tgt_s, ls="--", color="tab:orange", label="target") + a1.set_ylabel("autonomous amplitude (std)"); a1.set_yscale("log"); a1.legend(fontsize=8) + a1.set_title("Autonomous generation vs kernel-weight λ (low-pass poly)") + + a2.errorbar(lams, mean(det_decay), sem(det_decay), marker="o", color="tab:blue") + a2.axhline(1.0, ls="--", color="0.5", label="self-sustained (=1)") + a2.set_ylabel("decay ratio (2nd/1st half)"); a2.legend(fontsize=8) + + a3.errorbar(lams, mean(det_res), sem(det_res), marker="o", color="tab:blue", label="det. resonance") + a3.errorbar(lams, mean(nz_pitch), sem(nz_pitch), marker="s", color="tab:green", label="noise-driven pitch") + a3.axhline(tgt_p, ls="--", color="tab:orange", label="target pitch") + a3.set_ylabel("pitch (Hz)"); a3.set_xlabel("kernel-weight λ"); a3.legend(fontsize=8) + + fig.tight_layout() + tag = "draft" if args.n_vocs == 1 else f"{args.n_vocs}voc" + out = os.path.join(os.path.abspath(args.out_dir), f"lambda_autonomous_{tag}.png") + fig.savefig(out, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved figure to {out}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/examples/eval_alpha_autonomy.py b/examples/eval_alpha_autonomy.py new file mode 100644 index 0000000..b59d50b --- /dev/null +++ b/examples/eval_alpha_autonomy.py @@ -0,0 +1,69 @@ +""" +Quick A/B: does the constant (0,0) "alpha" forcing term help AUTONOMOUS reconstruction? + +Scores the +alpha (zero-init, 1ms low-pass) poly Ouroboros against the no-alpha baseline on the +SAME held-out gabo_p9 windows, with train.eval.autonomy_score (deterministic closed-loop, log-PSD +spectral corr minus amp/pitch log-ratio penalties). Caveat: the available baseline is at 2ms low-pass, +not 1ms, so this is the no-alpha point we have, not a perfectly matched control. 
+ + python -m examples.eval_alpha_autonomy --data-dir data500/gabo_p9 --n-vocs 6 +""" + +import argparse +import glob +import os + +import numpy as np +from scipy.io import wavfile + +from train.train import load_model +from train.eval import autonomy_score + + +def load_voc_windows(data_dir, n_vocs, start_offset_ms, n): + segs = [] + for wav in sorted(glob.glob(os.path.join(data_dir, "*.wav")))[:n_vocs]: + sr, af = wavfile.read(wav) + af = af.astype(np.float64) + onoffs = np.atleast_2d(np.loadtxt(wav.replace(".wav", ".txt"))) + s = int(onoffs[0][0] * sr) + int(start_offset_ms / 1e3 * sr) + seg = af[s:s + n] + if len(seg) == n: + segs.append(seg) + return segs, sr + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--data-dir", default="data500/gabo_p9") + p.add_argument("--n-vocs", type=int, default=6) + p.add_argument("--auto-n", type=int, default=3000) + p.add_argument("--start-offset-ms", type=float, default=50.0) + p.add_argument("--alpha-dir", default="poly_alpha/poly") + p.add_argument("--baseline-dir", default="poly_pipeline/poly_lam1.068_seed0") + args = p.parse_args() + + segs, sr = load_voc_windows(args.data_dir, args.n_vocs, args.start_offset_ms, args.auto_n) + print(f"{len(segs)} held-out vocs from {args.data_dir} (n={args.auto_n}, sr={sr})", flush=True) + + runs = [("+alpha (1ms, zero-init)", args.alpha_dir), + ("no-alpha baseline (2ms)", args.baseline_dir)] + for label, d in runs: + model, _, _, _ = load_model(d) + model.eval() + dt = model.tau + keep = getattr(model, "keep_const", False) + lp = getattr(model, "drive_lowpass_ms", None) + mean_score, per_seg, bd = autonomy_score(model, segs, dt) + print(f"\n=== {label} ===", flush=True) + print(f" keep_const={keep} drive_lowpass_ms={lp}", flush=True) + print(f" autonomy score = {mean_score:+.3f}", flush=True) + print(f" spectral corr = {bd['spec_corr']:.3f}", flush=True) + print(f" amp penalty = {bd['amp_pen']:.3f}", flush=True) + print(f" pitch penalty = 
{bd['pitch_pen']:.3f}", flush=True) + print(f" bounded fraction = {bd['bounded_frac']:.2f}", flush=True) + print(f" per-voc scores = {np.round(per_seg, 3)}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/examples/eval_autonomy_rescale.py b/examples/eval_autonomy_rescale.py new file mode 100644 index 0000000..492f350 --- /dev/null +++ b/examples/eval_autonomy_rescale.py @@ -0,0 +1,129 @@ +""" +Evaluate autonomous reconstruction with the rescale baseline made explicit. + +For each labeled poly checkpoint dir, runs DETERMINISTIC autonomous integration on held-out +gabo_p9 windows and reports three numbers: + + - raw autonomy = spec_corr - amp_pen - pitch_pen (train.eval.autonomy_score's metric) + - rescale bound = spec_corr - pitch_pen (amplitude rescaled to target RMS; + amp_pen -> 0, spec_corr is scale-invariant). This is the cheap baseline the + spectral/envelope fine-tune must BEAT to be worth it. + - env match = corr + relL1 of the time-varying loudness envelope, AFTER rescaling auto to + target RMS -- i.e. the loudness SHAPE that a single global rescale CANNOT fix. + This is where a successful envelope fine-tune should show a distinctive gain. 
+ + python -m examples.eval_autonomy_rescale --data-dir data500/gabo_p9 --n-vocs 6 \ + --models "pre=poly_alpha_ms/alpha_off_seed0/poly" "post=poly_ftspec_seed0/poly" +""" + +import argparse +import glob +import os + +import numpy as np +from scipy.io import wavfile +from scipy.signal import welch + +from train.train import load_model +from train.eval import integrate_poly_autonomous, correct + + +def load_voc_windows(data_dir, n_vocs, off_ms, n): + segs = [] + for wav in sorted(glob.glob(os.path.join(data_dir, "*.wav")))[:n_vocs]: + sr, af = wavfile.read(wav) + af = af.astype(np.float64) + on = np.atleast_2d(np.loadtxt(wav.replace(".wav", ".txt")))[0][0] + s = int(on * sr) + int(off_ms / 1e3 * sr) + seg = af[s:s + n] + if len(seg) == n: + segs.append(seg) + return segs, sr + + +def envelope(x, sr, env_ms=2.0): + """Gaussian low-pass of |x| (numpy).""" + dt = 1.0 / sr + sigma = (env_ms / 1e3) / dt + radius = max(1, int(round(3 * sigma))) + t = np.arange(-radius, radius + 1) + k = np.exp(-0.5 * (t / sigma) ** 2) + k /= k.sum() + xr = np.pad(np.abs(x), radius, mode="reflect") + return np.convolve(xr, k, mode="valid") + + +def score_model(model, segs, sr, fmax=8000.0): + """ + Deterministic autonomous reconstruction metrics on held-out windows. 
Returns a dict: + raw = spec_corr - amp_pen - pitch_pen (autonomy_score metric) + rescale = spec_corr - pitch_pen (amplitude gauge-fixed; bar to beat) + spec, amp_pen, pitch_pen (means) + envcorr, envL1 (loudness-trajectory match after rescale) + """ + dt = model.tau + + def logpsd(x): + f, Pxx = welch(x - np.mean(x), fs=sr, nperseg=min(1024, len(x))) + m = f <= fmax + return np.log(Pxx[m] + 1e-20) + + def peak(x): + f, Pxx = welch(x - np.mean(x), fs=sr, nperseg=min(1024, len(x))) + Pxx[0] = 0 + return float(f[np.argmax(Pxx)]) + + raws, resc, specs, amps, pits, ecorr, eL1 = ([] for _ in range(7)) + for seg in segs: + tgt = correct(np.asarray(seg, dtype=np.float64)) + auto = integrate_poly_autonomous(model, seg, dt, noise_sd=0.0, detrend=True, verbose=False) + n = min(len(tgt), len(auto)) + tgt, auto = tgt[:n], auto[:n] + st, sa = np.nanstd(tgt), np.nanstd(auto) + if (not np.isfinite(auto).all()) or sa < 1e-9: + raws.append(-5.0); resc.append(-5.0) + continue + sc = float(np.corrcoef(logpsd(tgt), logpsd(auto))[0, 1]) + amp = abs(np.log((sa + 1e-12) / (st + 1e-12))) + pit = abs(np.log((peak(auto) + 1e-9) / (peak(tgt) + 1e-9))) + raws.append(sc - amp - pit); resc.append(sc - pit) + specs.append(sc); amps.append(amp); pits.append(pit) + ar = auto * (st / (sa + 1e-12)) + et, ea = envelope(tgt, sr), envelope(ar, sr) + ecorr.append(float(np.corrcoef(et, ea)[0, 1])) + eL1.append(float(np.mean(np.abs(ea - et)) / (np.mean(et) + 1e-9))) + return {"raw": float(np.mean(raws)), "rescale": float(np.mean(resc)), + "spec": float(np.mean(specs)) if specs else float("nan"), + "amp_pen": float(np.mean(amps)) if amps else float("nan"), + "pitch_pen": float(np.mean(pits)) if pits else float("nan"), + "envcorr": float(np.mean(ecorr)) if ecorr else float("nan"), + "envL1": float(np.mean(eL1)) if eL1 else float("nan")} + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--data-dir", default="data500/gabo_p9") + p.add_argument("--n-vocs", type=int, 
default=6) + p.add_argument("--auto-n", type=int, default=3000) + p.add_argument("--start-offset-ms", type=float, default=50.0) + p.add_argument("--fmax", type=float, default=8000.0) + p.add_argument("--models", nargs="+", required=True, help='label=dir entries') + args = p.parse_args() + + segs, sr = load_voc_windows(args.data_dir, args.n_vocs, args.start_offset_ms, args.auto_n) + print(f"{len(segs)} held-out vocs from {args.data_dir} (n={args.auto_n}, sr={sr})\n", flush=True) + + print(f"{'label':14s} {'raw':>7s} {'rescale':>8s} {'spec':>6s} {'amp_pen':>8s} " + f"{'pitch':>6s} {'envcorr':>8s} {'envL1':>7s}", flush=True) + for entry in args.models: + label, d = entry.split("=", 1) + model, _, _, _ = load_model(d) + model.eval() + r = score_model(model, segs, sr, fmax=args.fmax) + print(f"{label:14s} {r['raw']:+7.3f} {r['rescale']:+8.3f} {r['spec']:6.3f} " + f"{r['amp_pen']:8.3f} {r['pitch_pen']:6.3f} {r['envcorr']:8.3f} {r['envL1']:7.3f}", + flush=True) + + +if __name__ == "__main__": + main() diff --git a/examples/finetune_attractor_poly.py b/examples/finetune_attractor_poly.py new file mode 100644 index 0000000..872bcb1 --- /dev/null +++ b/examples/finetune_attractor_poly.py @@ -0,0 +1,166 @@ +""" +Make the polynomial Ouroboros's cycle ATTRACTING (fix the marginal-amplitude problem), two ways. + +The autonomous amplitude is poorly constrained because on-orbit acceleration-matching leaves the +transverse Floquet exponent Lambda = integral of d f/d x' essentially free (~0, marginal). Two +fine-tunes that target the transverse/off-orbit dynamics: + + --method floq : add an amplitude-STABILITY penalty (no rollout). Evaluate the divergence field + d(x,x') = d f/d x' = -gamma - sum_{p,k>=1} w_pk x^p k x'^{k-1} at the data orbit + SCALED to (1+delta)*radius and (1-delta)*radius (same drives), and push the outer + one negative (contract from outside) and the inner one positive (expand from + inside). 
This makes the data radius a stable limit cycle WITHOUT imposing global + decay (a constant-negative Lambda would just damp to zero). Loss = TF + lam_floq*pen. + + --method noise: short-horizon NOISY teacher forcing. Perturb the start state off the orbit + (isotropic Gaussian, incl. transverse), roll out H << ... ~1-2 periods (short + enough that phase drift is negligible so pointwise MSE is valid), and require the + trajectory to track the TRUE orbit. The carried-forward perturbation makes the + model learn to contract back to the data orbit (Lambda<0 at the right radius). + Loss = pointwise rollout MSE (from noisy start) + lam_tf*TF. + + python -m examples.finetune_attractor_poly --init poly_ro_pipeline/seed0/tf --method floq \ + --out-dir ./poly_attr_floq_seed0 +""" + +import argparse +import glob +import os + +import numpy as np +import torch +from torch.optim import Adam + +from utils import deriv_approx_dy, deriv_approx_d2y +from train.train import load_model, save_model +from train.rollout_refine import gather_windows + +BX, BXP = 0.5, 1.0 + + +def divergence(x, v, gamma, weights, powers): + """d f/d x' along the trajectory. 
x,v,gamma:(B,L); weights:(B,L,P,P); -> (B,L).""" + xp = x.unsqueeze(-1) ** powers # (B,L,P) x^p + vp = v.unsqueeze(-1) ** powers # (B,L,P) v^k + dvk = torch.zeros_like(vp) + dvk[..., 1:] = powers[1:] * vp[..., :-1] # k v^{k-1} + dkern = torch.einsum("blpk,blp,blk->bl", weights, xp, dvk) + return -gamma - dkern + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--init", required=True) + p.add_argument("--method", choices=["floq", "noise"], required=True) + p.add_argument("--out-dir", required=True) + p.add_argument("--data-glob", default="data500/gabo_p[0-7]") + p.add_argument("--n-windows", type=int, default=24) + p.add_argument("--l-seg", type=int, default=1500) + p.add_argument("--epochs", type=int, default=16) + p.add_argument("--batch-size", type=int, default=8) + p.add_argument("--lr", type=float, default=1e-4) + p.add_argument("--lam-tf", type=float, default=1.0) + p.add_argument("--start-offset-ms", type=float, default=50.0) + # floq + p.add_argument("--lam-floq", type=float, default=5.0) + p.add_argument("--delta", type=float, default=0.3, help="radius perturbation for stability penalty") + # noise + p.add_argument("--hmin", type=int, default=12) + p.add_argument("--hmax", type=int, default=30) + p.add_argument("--noise-frac", type=float, default=0.3, help="start-state noise as fraction of signal std") + args = p.parse_args() + + run_dir = os.path.join(os.path.abspath(args.out_dir), "poly") + os.makedirs(run_dir, exist_ok=True) + model, _, _, ep0 = load_model(args.init) + model.train() + dt = model.tau + + dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + X, _ = gather_windows(dirs, args.n_windows, args.l_seg, args.start_offset_ms) + D1 = deriv_approx_dy(X) + D2 = deriv_approx_d2y(X) + Xt = torch.tensor(X, dtype=torch.float32, device="cuda") + Dt = torch.tensor(D1, dtype=torch.float32, device="cuda") + D2t = torch.tensor(D2, dtype=torch.float32, device="cuda") + var_d2 = float(D2t.var()) + P = 
model.kernel.poly_dim + 1 + powers = torch.arange(P, device="cuda", dtype=torch.float32) + opt = Adam(model.parameters(), lr=args.lr) + N = Xt.shape[0] + H_sched = np.unique(np.round(np.geomspace(args.hmin, args.hmax, args.epochs)).astype(int)) + print(f"attractor FT [{args.method}] {args.init} (ep {ep0}); {N} windows L={args.l_seg}", flush=True) + + for epoch in range(args.epochs): + H = int(H_sched[min(epoch, len(H_sched) - 1)]) + perm = torch.randperm(N) + tot = {"tf": 0.0, "main": 0.0} + nb = 0 + for i in range(0, N, args.batch_size): + idx = perm[i:i + args.batch_size] + x, dxd, d2b = Xt[idx], Dt[idx], D2t[idx] + z2 = (model.tau / dt) * dxd + omega, gamma, wk, weights, _ = model.get_funcs(x, dxd.clone(), dt) + tf = -(omega ** 2) * x - gamma * z2 - wk + L_tf = ((tf - d2b) ** 2).mean() / var_d2 + + om, ga, w = omega[:, :, 0], gamma[:, :, 0], weights + xx, vv = x[:, :, 0], z2[:, :, 0] + + if args.method == "floq": + # divergence at the orbit scaled to (1+/-delta)*radius (same drives) + d_out = divergence((1 + args.delta) * xx, (1 + args.delta) * vv, ga, w, powers) + d_in = divergence((1 - args.delta) * xx, (1 - args.delta) * vv, ga, w, powers) + # integrate to the Floquet exponent (per-sample divergence is ~1e-3; the summed + # Lambda is O(1) and is what actually predicts amplitude change) + Lam_out = d_out.sum(dim=1) # (B,) Floquet exponent at outer radius + Lam_in = d_in.sum(dim=1) # (B,) at inner radius + # want Lam_out <= 0 (contract from outside), Lam_in >= 0 (expand from inside) + L_main = (torch.relu(Lam_out) ** 2).mean() + (torch.relu(-Lam_in) ** 2).mean() + loss = L_main * args.lam_floq + args.lam_tf * L_tf + else: # noise + B = xx.shape[0] + sx = args.noise_frac * xx.std() + sv = args.noise_frac * vv.std() + xc = (xx[:, 0] + sx * torch.randn(B, device="cuda")).detach() + xp = (vv[:, 0] + sv * torch.randn(B, device="cuda")).detach() + + def f(xa, va, k): + xpw = xa.unsqueeze(1) ** powers + xvw = va.unsqueeze(1) ** powers + kern = 
torch.einsum("bp,bk,bpk->b", xpw, xvw, w[:, k]) + return va, -(om[:, k] ** 2) * xa - ga[:, k] * va - kern + + xs = [xc] + for k in range(H - 1): + k1x, k1v = f(xc, xp, k) + k2x, k2v = f(xc + 0.5 * k1x, xp + 0.5 * k1v, k) + k3x, k3v = f(xc + 0.5 * k2x, xp + 0.5 * k2v, k) + k4x, k4v = f(xc + k3x, xp + k3v, k) + xc = xc + (k1x + 2 * k2x + 2 * k3x + k4x) / 6 + xp = xp + (k1v + 2 * k2v + 2 * k3v + k4v) / 6 + xc = BX * torch.tanh(xc / BX) + xp = BXP * torch.tanh(xp / BXP) + xs.append(xc) + xg = torch.stack(xs, dim=1) # (B,H) noisy rollout + tgt = xx[:, :H] # (B,H) true orbit + L_main = ((xg - tgt) ** 2).mean() / (xx.var() + 1e-12) + loss = L_main + args.lam_tf * L_tf + + opt.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) + opt.step() + tot["tf"] += float(L_tf); tot["main"] += float(L_main) + nb += 1 + print(f"[ep {epoch + 1}/{args.epochs} H={H}] {args.method}={tot['main']/nb:.4f} tf={tot['tf']/nb:.4f}", flush=True) + + cfg = model.omega_mamba.config + save_model(model, opt, os.path.join(run_dir, f"checkpoint_{ep0 + args.epochs}.tar"), + n_layers=cfg.n_layers, d_state=cfg.d_state, d_conv=cfg.d_conv, + expand_factor=cfg.expand_factor) + print(f"saved [{args.method}] model to {run_dir}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/examples/finetune_rollout_poly.py b/examples/finetune_rollout_poly.py new file mode 100644 index 0000000..a104331 --- /dev/null +++ b/examples/finetune_rollout_poly.py @@ -0,0 +1,142 @@ +""" +Short rollout fine-tune for the polynomial Ouroboros (controlled experiment). 
+ +Backprops through a SHORT autonomous rollout (started from the data initial state, so phase +stays aligned and the pointwise MSE pins amplitude + frequency rather than collapsing on +long-horizon phase drift), with SOFT state saturation (x <- B*tanh(x/B), so gradients flow +even if a step would diverge -- fixing the hard-clamp bug from the earlier Arneodo attempt), +plus a teacher-forced anchor and a curriculum on the rollout horizon. + +Poly RHS (rescaled time, drives low-passed by the model): d2x/ds2 = -omega^2 x - gamma x' +- sum_ij w_ij x^i x'^j, with omega(t), gamma(t), w(t) from the (differentiable) encoder. + +Run from the repo root: + python -m examples.finetune_rollout_poly --init poly_pipeline/poly_lam1.068_seed0 \ + --data-glob 'data500/gabo_p[0-7]' --out-dir ./poly_ft_seed0 +""" + +import argparse +import glob +import os + +import numpy as np +import torch +from torch.optim import Adam +from scipy.io import wavfile + +from utils import deriv_approx_dy, deriv_approx_d2y +from train.train import load_model, save_model + +BX, BXP = 0.5, 1.0 # soft-saturation bounds (>> data/limit-cycle scale; only tame divergence) + + +def gather_windows(dirs, n, L, start_off_ms): + segs, sr = [], None + for d in dirs: + for wav in sorted(glob.glob(os.path.join(d, "*.wav"))): + sr, af = wavfile.read(wav) + af = af.astype(np.float64) + on = np.atleast_2d(np.loadtxt(wav.replace(".wav", ".txt")))[0][0] + s = int(on * sr) + int(start_off_ms / 1e3 * sr) + seg = af[s:s + L] + if len(seg) == L: + segs.append(seg) + if len(segs) >= n: + return np.stack(segs)[:, :, None], sr + return np.stack(segs)[:, :, None], sr + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--init", required=True, help="poly checkpoint dir to fine-tune") + p.add_argument("--data-glob", default="data500/gabo_p[0-7]") + p.add_argument("--out-dir", required=True) + p.add_argument("--n-windows", type=int, default=48) + p.add_argument("--l-seg", type=int, default=1500, 
help="encoder context length") + p.add_argument("--hmax", type=int, default=400, help="max rollout horizon (samples)") + p.add_argument("--epochs", type=int, default=15) + p.add_argument("--batch-size", type=int, default=8) + p.add_argument("--lr", type=float, default=1e-4) + p.add_argument("--lam-tf", type=float, default=1.0) + p.add_argument("--start-offset-ms", type=float, default=50.0) + args = p.parse_args() + + run_dir = os.path.join(os.path.abspath(args.out_dir), "poly") + os.makedirs(run_dir, exist_ok=True) + model, _, _, ep0 = load_model(args.init) + model.train() + dt = model.tau + + dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + X, sr = gather_windows(dirs, args.n_windows, args.l_seg, args.start_offset_ms) + D1 = deriv_approx_dy(X) + D2 = deriv_approx_d2y(X) + Xt = torch.tensor(X, dtype=torch.float32, device="cuda") + Dt = torch.tensor(D1, dtype=torch.float32, device="cuda") + D2t = torch.tensor(D2, dtype=torch.float32, device="cuda") + var_x, var_d2 = float(Xt.var()), float(D2t.var()) + P = model.kernel.poly_dim + 1 + powers = torch.arange(P, device="cuda") + opt = Adam(model.parameters(), lr=args.lr) + N = Xt.shape[0] + H_sched = np.unique(np.round(np.geomspace(80, args.hmax, args.epochs)).astype(int)) + print(f"fine-tuning {args.init} (epoch {ep0}); {N} windows L={args.l_seg} Hmax={args.hmax}", flush=True) + + for epoch in range(args.epochs): + H = int(H_sched[min(epoch, len(H_sched) - 1)]) + perm = torch.randperm(N) + tot_r = tot_t = 0.0 + nb = 0 + for i in range(0, N, args.batch_size): + idx = perm[i:i + args.batch_size] + x, dxd, d2b = Xt[idx], Dt[idx], D2t[idx] + z2 = (model.tau / dt) * dxd + omega, gamma, wk, weights, _ = model.get_funcs(x, dxd.clone(), dt) # differentiable + tf = -(omega ** 2) * x - gamma * z2 - wk + L_tf = ((tf - d2b) ** 2).mean() / var_d2 + + om = omega[:, :, 0] + ga = gamma[:, :, 0] + w = weights # (B, L, P, P) + + def f(xx, vv, k): + xpw = xx.unsqueeze(1) ** powers + xvw = vv.unsqueeze(1) ** powers 
+ kern = torch.einsum("bp,bk,bpk->b", xpw, xvw, w[:, k]) + return vv, -(om[:, k] ** 2) * xx - ga[:, k] * vv - kern + + xc = x[:, 0, 0].detach() + xp = z2[:, 0, 0].detach() + xs = [xc] + for k in range(H - 1): + k1x, k1v = f(xc, xp, k) + k2x, k2v = f(xc + 0.5 * k1x, xp + 0.5 * k1v, k) + k3x, k3v = f(xc + 0.5 * k2x, xp + 0.5 * k2v, k) + k4x, k4v = f(xc + k3x, xp + k3v, k) + xc = xc + (k1x + 2 * k2x + 2 * k3x + k4x) / 6 + xp = xp + (k1v + 2 * k2v + 2 * k3v + k4v) / 6 + xc = BX * torch.tanh(xc / BX) + xp = BXP * torch.tanh(xp / BXP) + xs.append(xc) + xg = torch.stack(xs, dim=1) # (B, H) + L_roll = ((xg - x[:, :H, 0]) ** 2).mean() / var_x + + loss = L_roll + args.lam_tf * L_tf + opt.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) + opt.step() + tot_r += float(L_roll) + tot_t += float(L_tf) + nb += 1 + print(f"[ep {epoch + 1}/{args.epochs} H={H}] relMSE roll={tot_r / nb:.4f} tf={tot_t / nb:.4f}", flush=True) + + cfg = model.omega_mamba.config + save_model(model, opt, os.path.join(run_dir, f"checkpoint_{ep0 + args.epochs}.tar"), + n_layers=cfg.n_layers, d_state=cfg.d_state, d_conv=cfg.d_conv, + expand_factor=cfg.expand_factor) + print(f"saved fine-tuned model to {run_dir}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/examples/finetune_rollout_spectral_poly.py b/examples/finetune_rollout_spectral_poly.py new file mode 100644 index 0000000..cf21ea6 --- /dev/null +++ b/examples/finetune_rollout_spectral_poly.py @@ -0,0 +1,61 @@ +""" +Spectral + envelope rollout fine-tune for the polynomial Ouroboros (single-model entry point). + +Thin wrapper around train.rollout_refine.rollout_refine (the reusable objective): loads a trained +poly checkpoint, gathers held-out sustained windows, and applies the phase-invariant +multi-resolution-STFT + envelope rollout objective. See train/rollout_refine.py for the loss. 
+ +Run from the repo root: + python -m examples.finetune_rollout_spectral_poly --init poly_alpha_ms/alpha_off_seed0/poly \ + --data-glob 'data500/gabo_p[0-7]' --out-dir ./poly_ftspec_seed0 --lam-env 10 +""" + +import argparse +import glob +import os + +from train.train import load_model, save_model +from train.rollout_refine import gather_windows, rollout_refine + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--init", required=True, help="poly checkpoint dir to fine-tune") + p.add_argument("--data-glob", default="data500/gabo_p[0-7]") + p.add_argument("--out-dir", required=True) + p.add_argument("--n-windows", type=int, default=12) + p.add_argument("--l-seg", type=int, default=1500, help="encoder context length (>= hmax)") + p.add_argument("--hmax", type=int, default=1500, help="max rollout horizon (samples)") + p.add_argument("--hmin", type=int, default=768, help="initial rollout horizon (curriculum)") + p.add_argument("--epochs", type=int, default=8) + p.add_argument("--batch-size", type=int, default=6) + p.add_argument("--lr", type=float, default=1e-4) + p.add_argument("--lam-tf", type=float, default=1.0) + p.add_argument("--lam-spec", type=float, default=1.0) + p.add_argument("--lam-env", type=float, default=10.0) + p.add_argument("--env-ms", type=float, default=2.0) + p.add_argument("--start-offset-ms", type=float, default=50.0) + args = p.parse_args() + + run_dir = os.path.join(os.path.abspath(args.out_dir), "poly") + os.makedirs(run_dir, exist_ok=True) + model, _, _, ep0 = load_model(args.init) + dt = model.tau + + dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + X, _ = gather_windows(dirs, args.n_windows, args.l_seg, args.start_offset_ms) + print(f"spectral FT {args.init} (epoch {ep0}); {X.shape[0]} windows L={args.l_seg}", flush=True) + + _, opt = rollout_refine(model, X, dt, epochs=args.epochs, hmin=args.hmin, hmax=args.hmax, + batch_size=args.batch_size, lr=args.lr, lam_spec=args.lam_spec, + 
lam_env=args.lam_env, lam_tf=args.lam_tf, env_ms=args.env_ms) + + cfg = model.omega_mamba.config + save_model(model, opt, os.path.join(run_dir, f"checkpoint_{ep0 + args.epochs}.tar"), + n_layers=cfg.n_layers, d_state=cfg.d_state, d_conv=cfg.d_conv, + expand_factor=cfg.expand_factor) + print(f"saved fine-tuned model to {run_dir}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/examples/floquet_amplitude_diagnostic.py b/examples/floquet_amplitude_diagnostic.py new file mode 100644 index 0000000..0b4a999 --- /dev/null +++ b/examples/floquet_amplitude_diagnostic.py @@ -0,0 +1,132 @@ +""" +Amplitude-Floquet diagnostic: is the model's learned cycle attracting, neutral, or repelling? + +For a 2-D oscillator d2x/ds2 = f(x, x'; drives), the non-trivial Floquet multiplier over the orbit is +exp(Lambda) with the area-contraction exponent + + Lambda = integral over the window of d f / d x' (x' = dx/ds, the model's internal velocity) + +evaluated ALONG THE DATA ORBIT with the model's drives. Lambda<0 -> attracting (amplitude decays toward +the cycle), ~0 -> neutral (SHO-like, amplitude free), >0 -> repelling (amplitude grows). The predicted +log change in oscillation amplitude over the window is ~ Lambda/2 (area ~ amplitude^2). + +f = -omega^2 x - gamma x' - sum_{p,k} w_{pk} x^p (x')^k, so + d f / d x' = -gamma - sum_{p,k>=1} w_{pk} x^p k (x')^{k-1}. + +Prediction to test: sign(Lambda) should match the actual rollout decay/growth (2nd-half/1st-half std), +and the spread of Lambda across seeds should track the spread of autonomous amplitude. + + python -m examples.floquet_amplitude_diagnostic --data-dir data500/gabo_p9 --n-vocs 6 \ + --runs "seed0=poly_ro_pipeline/seed0/tf" ... 
+""" + +import argparse +import glob +import os + +import numpy as np +import torch +from scipy.io import wavfile +from scipy.signal import welch + +from utils import deriv_approx_dy +from train.train import load_model +from train.eval import integrate_poly_autonomous, correct + + +def load_voc_windows(data_dir, n_vocs, off_ms, n): + segs = [] + for wav in sorted(glob.glob(os.path.join(data_dir, "*.wav")))[:n_vocs]: + sr, af = wavfile.read(wav) + af = af.astype(np.float64) + on = np.atleast_2d(np.loadtxt(wav.replace(".wav", ".txt")))[0][0] + s = int(on * sr) + int(off_ms / 1e3 * sr) + seg = af[s:s + n] + if len(seg) == n: + segs.append(seg) + return segs, sr + + +def peak_freq(x, sr): + f, P = welch(x - np.mean(x), fs=sr, nperseg=min(1024, len(x))) + P[0] = 0 + return float(f[np.argmax(P)]) + + +def floquet_exponent(model, seg, dt): + """Lambda = sum_s d f/d x' along the data orbit (rescaled time, ds=1 per sample).""" + X = np.asarray(seg, dtype=np.float64)[None, :, None] + D1 = deriv_approx_dy(X) + Xt = torch.tensor(X, dtype=torch.float32, device="cuda") + Dt = torch.tensor(D1, dtype=torch.float32, device="cuda") + with torch.no_grad(): + omega, gamma, wk, weights, _ = model.get_funcs(Xt, Dt.clone(), dt) + z2 = (model.tau / dt) * Dt # internal velocity x' + x = Xt[0, :, 0] # (L,) + v = z2[0, :, 0] # (L,) + ga = gamma[0, :, 0] # (L,) + w = weights[0] # (L,P,P) indexed [l,p,k] = w_{x^p (x')^k} + P = w.shape[-1] + powers = torch.arange(P, device="cuda", dtype=torch.float32) + x_pow = x[:, None] ** powers # (L,P) x^p + v_pow = v[:, None] ** powers # (L,P) (x')^k + dvk = torch.zeros_like(v_pow) # d/dx' of (x')^k = k (x')^{k-1} + dvk[:, 1:] = powers[1:][None, :] * v_pow[:, :-1] + dkern_dv = torch.einsum("lpk,lp,lk->l", w, x_pow, dvk) + dfdv = -ga - dkern_dv # (L,) + Lam = float(dfdv.sum()) # ds = 1 per sample + mean_dfdv = float(dfdv.mean()) + return Lam, mean_dfdv + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--data-dir", 
default="data500/gabo_p9") + p.add_argument("--n-vocs", type=int, default=6) + p.add_argument("--auto-n", type=int, default=3000) + p.add_argument("--start-offset-ms", type=float, default=50.0) + p.add_argument("--runs", nargs="+", required=True, help="label=dir entries") + args = p.parse_args() + + segs, sr = load_voc_windows(args.data_dir, args.n_vocs, args.start_offset_ms, args.auto_n) + print(f"{len(segs)} held-out vocs from {args.data_dir} (n={args.auto_n}, sr={sr})\n", flush=True) + print(f"{'label':8s} {'Lambda':>8s} {'L/cycle':>8s} {'pred A ratio':>12s} " + f"{'act decay':>10s} {'act std/tgt':>12s}", flush=True) + print(f"{'':8s} {'(area)':>8s} {'':>8s} {'exp(L/2)':>12s} " + f"{'2nd/1st':>10s} {'':>12s}", flush=True) + + Ls, decays = [], [] + for entry in args.runs: + label, d = entry.split("=", 1) + model, _, _, _ = load_model(d) + model.eval() + dt = model.tau + lams, lpc, predA, decay, stdrat = [], [], [], [], [] + for seg in segs: + Lam, _ = floquet_exponent(model, seg, dt) + ncyc = peak_freq(seg, sr) * (len(seg) / sr) + auto = integrate_poly_autonomous(model, seg, dt, noise_sd=0.0, detrend=True, verbose=False) + tgt = correct(np.asarray(seg, dtype=np.float64)) + h1, h2 = auto[:len(auto) // 2], auto[len(auto) // 2:] + lams.append(Lam) + lpc.append(Lam / max(ncyc, 1e-9)) + predA.append(np.exp(np.clip(Lam / 2, -20, 20))) + decay.append(np.nanstd(h2) / (np.nanstd(h1) + 1e-12)) + stdrat.append(np.nanstd(auto) / (np.nanstd(tgt) + 1e-12)) + L = float(np.mean(lams)) + Ls.append(L); decays.append(float(np.mean(decay))) + print(f"{label:8s} {L:+8.2f} {np.mean(lpc):+8.4f} {np.mean(predA):12.3f} " + f"{np.mean(decay):10.3f} {np.mean(stdrat):12.3f}", flush=True) + + Ls, decays = np.array(Ls), np.array(decays) + if len(Ls) >= 3: + # sign agreement: Lambda<0 should give decay<1, Lambda>0 decay>1 + agree = np.mean((np.sign(Ls) == np.sign(np.log(decays + 1e-12)))) + r = float(np.corrcoef(Ls, np.log(decays + 1e-12))[0, 1]) + print(f"\nLambda vs log(decay ratio): sign 
agreement {agree*100:.0f}%, corr r={r:+.2f}", flush=True) + print(f"Lambda spread across seeds: mean {Ls.mean():+.2f} std {Ls.std():.2f} " + f"[{Ls.min():+.2f},{Ls.max():+.2f}]", flush=True) + + +if __name__ == "__main__": + main() diff --git a/examples/run_alpha_multiseed.py b/examples/run_alpha_multiseed.py new file mode 100644 index 0000000..f761738 --- /dev/null +++ b/examples/run_alpha_multiseed.py @@ -0,0 +1,158 @@ +""" +Multi-seed A/B: does the constant (0,0) "alpha" forcing term improve AUTONOMOUS reconstruction, +disentangled from the low-pass change and from seed noise? + +Trains TWO arms over the SAME seed set, both at 1ms drive low-pass (matched control): + - alpha_on : --keep-const (the (0,0) y^0*ydot^0 forcing term, zero-init, low-passed with the rest) + - alpha_off: no constant term (the standard poly model), also at 1ms + +Each run is just the tested examples.train_poly_lowpass invoked as a subprocess (same code path that +produced the single-seed +0.372), so this only orchestrates + scores. Resumable: a run whose final +checkpoint already exists is skipped. After training, scores deterministic autonomy on held-out +gabo_p9 windows (train.eval.autonomy_score) and writes a paired comparison figure + CSV. 
+ +Run from the repo root: + python -m examples.run_alpha_multiseed --train-glob 'data500/gabo_p[0-7]' --test-dir data500/gabo_p9 \ + --seeds 0 1 2 3 4 --epochs 40 --lam 1.068 --drive-lowpass-ms 1.0 --out-dir ./poly_alpha_ms +""" + +import argparse +import glob +import os +import re +import subprocess +import sys + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +import numpy as np +from scipy.io import wavfile + +from train.train import load_model +from train.eval import autonomy_score + +plt.rcParams["text.usetex"] = False +PY = sys.executable # the venv python running this driver + + +def load_voc_windows(data_dir, n_vocs, off_ms, n): + segs = [] + for wav in sorted(glob.glob(os.path.join(data_dir, "*.wav")))[:n_vocs]: + sr, af = wavfile.read(wav) + af = af.astype(np.float64) + on = np.atleast_2d(np.loadtxt(wav.replace(".wav", ".txt")))[0][0] + s = int(on * sr) + int(off_ms / 1e3 * sr) + seg = af[s:s + n] + if len(seg) == n: + segs.append(seg) + return segs + + +def train_one(arm, seed, args): + """Train one (arm, seed) via examples.train_poly_lowpass unless its final checkpoint exists.""" + out_dir = os.path.join(os.path.abspath(args.out_dir), f"{arm}_seed{seed}") + ckpt = os.path.join(out_dir, "poly", f"checkpoint_{args.epochs}.tar") + if os.path.isfile(ckpt): + print(f"[skip] {arm} seed{seed}: {ckpt} exists", flush=True) + return out_dir + cmd = [PY, "-m", "examples.train_poly_lowpass", + "--data-glob", args.train_glob, "--out-dir", out_dir, + "--lam", str(args.lam), "--drive-lowpass-ms", str(args.drive_lowpass_ms), + "--d-state", str(args.d_state), "--context-len", str(args.context_len), + "--epochs", str(args.epochs), "--seg", "10", + "--batch-size", str(args.batch_size), "--seed", str(seed)] + if arm == "alpha_on": + cmd.append("--keep-const") + print(f"\n[train] {arm} seed{seed}: {' '.join(cmd)}", flush=True) + res = subprocess.run(cmd, capture_output=True, text=True) + sys.stdout.write(res.stdout[-2000:]) + if 
res.returncode != 0: + sys.stdout.write(res.stderr[-2000:]) + raise RuntimeError(f"training failed: {arm} seed{seed}") + m = re.search(r"best test R2 = ([\-0-9.]+)", res.stdout) + print(f"[done] {arm} seed{seed}: best test R2 = {m.group(1) if m else '?'}", flush=True) + return out_dir + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--train-glob", default="data500/gabo_p[0-7]") + p.add_argument("--test-dir", default="data500/gabo_p9") + p.add_argument("--seeds", type=int, nargs="+", default=[0, 1, 2, 3, 4]) + p.add_argument("--epochs", type=int, default=40) + p.add_argument("--lam", type=float, default=1.068) + p.add_argument("--drive-lowpass-ms", type=float, default=1.0) + p.add_argument("--d-state", type=int, default=4) + p.add_argument("--context-len", type=float, default=0.1) + p.add_argument("--batch-size", type=int, default=8) + p.add_argument("--n-test-vocs", type=int, default=6) + p.add_argument("--auto-n", type=int, default=3000) + p.add_argument("--start-offset-ms", type=float, default=50.0) + p.add_argument("--out-dir", default="./poly_alpha_ms") + args = p.parse_args() + + arms = ["alpha_on", "alpha_off"] + # train all (arm, seed) sequentially (shared GPU -> no parallel) + run_dirs = {} + for seed in args.seeds: + for arm in arms: + run_dirs[(arm, seed)] = train_one(arm, seed, args) + + # score autonomy on the SAME held-out windows + segs = load_voc_windows(args.test_dir, args.n_test_vocs, args.start_offset_ms, args.auto_n) + print(f"\nscoring autonomy on {len(segs)} held-out vocs from {args.test_dir}", flush=True) + results = {arm: [] for arm in arms} + rows = [] + for seed in args.seeds: + for arm in arms: + model, _, _, _ = load_model(os.path.join(run_dirs[(arm, seed)], "poly")) + model.eval() + score, per, bd = autonomy_score(model, segs, model.tau) + results[arm].append(score) + rows.append((arm, seed, score, bd["spec_corr"], bd["amp_pen"], + bd["pitch_pen"], bd["bounded_frac"])) + print(f" {arm:9s} seed{seed}: 
autonomy={score:+.3f} spec={bd['spec_corr']:.3f} " + f"amp_pen={bd['amp_pen']:.3f} pitch_pen={bd['pitch_pen']:.3f}", flush=True) + + out = os.path.abspath(args.out_dir) + os.makedirs(out, exist_ok=True) + csv = os.path.join(out, "alpha_multiseed.csv") + with open(csv, "w") as f: + f.write("arm,seed,autonomy,spec_corr,amp_pen,pitch_pen,bounded_frac\n") + for r in rows: + f.write(f"{r[0]},{r[1]},{r[2]:.4f},{r[3]:.4f},{r[4]:.4f},{r[5]:.4f},{r[6]:.2f}\n") + + on, off = np.array(results["alpha_on"]), np.array(results["alpha_off"]) + print(f"\n=== SUMMARY over {len(args.seeds)} seeds ===", flush=True) + print(f" alpha_on : mean {on.mean():+.3f} std {on.std():.3f} [{on.min():+.3f},{on.max():+.3f}]", flush=True) + print(f" alpha_off: mean {off.mean():+.3f} std {off.std():.3f} [{off.min():+.3f},{off.max():+.3f}]", flush=True) + print(f" paired diff (on-off): mean {(on-off).mean():+.3f} per-seed {np.round(on-off,3)}", flush=True) + + # figure: paired lines + per-arm distribution + fig, (axL, axR) = plt.subplots(1, 2, figsize=(11, 5)) + for i, s in enumerate(args.seeds): + axL.plot([0, 1], [off[i], on[i]], "-o", color="0.6", zorder=1) + axL.scatter(np.zeros_like(off), off, color="tab:red", zorder=2, label="alpha_off (1ms)") + axL.scatter(np.ones_like(on), on, color="tab:blue", zorder=2, label="alpha_on (1ms)") + axL.set_xticks([0, 1]); axL.set_xticklabels(["alpha_off", "alpha_on"]) + axL.set_xlim(-0.4, 1.4); axL.set_ylabel("autonomy score (held-out gabo_p9)") + axL.set_title("paired by seed"); axL.legend(fontsize=8) + bp = axR.boxplot([off, on], labels=["alpha_off", "alpha_on"], showmeans=True, widths=0.5) + axR.scatter(np.full_like(off, 1), off, color="tab:red", alpha=0.7) + axR.scatter(np.full_like(on, 2), on, color="tab:blue", alpha=0.7) + axR.set_ylabel("autonomy score") + axR.set_title(f"distribution ({len(args.seeds)} seeds, {args.epochs}ep, 1ms low-pass)") + fig.suptitle("Constant (0,0) 'alpha' forcing: autonomous reconstruction, matched 1ms control") + 
fig.tight_layout() + figpath = os.path.join(out, "alpha_multiseed.png") + fig.savefig(figpath, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"\nSaved CSV -> {csv}", flush=True) + print(f"Saved figure -> {figpath}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/examples/run_lambda_pipeline.py b/examples/run_lambda_pipeline.py new file mode 100644 index 0000000..b634e14 --- /dev/null +++ b/examples/run_lambda_pipeline.py @@ -0,0 +1,130 @@ +""" +Production selection pipeline for the (low-pass) polynomial Ouroboros: pick the best SEED, then +amplitude-rescale at generation. + +Autonomous-reconstruction quality is dominated by the random SEED (init + batch order), not by the +kernel-weight lambda (lambda is irrelevant for autonomy; see docs/autonomous_amplitude.md). So the +recipe is: fix lambda, train several seeds, and SELECT the best seed on the VALIDATION set by +RESCALED autonomy (amplitude is a free overall-scale gauge, fixed at generation by +train.eval.generate_autonomous). The (0,0) polynomial 'alpha' term is kept on (R2-neutral here; may +help other datasets). A FILE-LEVEL holdout is used for the autonomy vocalizations: + - training chunks from most data shards, + - validation autonomy vocs from a held-out shard (seed selection), + - test autonomy vocs from another held-out shard (final report + a rescaled reconstruction wav). + +Pass --lam <=0 to instead sweep the standard 7-point lambda grid. 
+ +Run from the repo root: + python -m examples.run_lambda_pipeline --data-glob 'data500/gabo_p*' --out-dir ./poly_pipeline \ + --n-epochs 50 --n-seeds 5 --lam 1.068 --drive-lowpass-ms 1.0 --d-state 4 +""" + +import argparse +import glob +import json +import os + +import numpy as np +from scipy.io import wavfile + +from data.load_data import get_segmented_audio +from data.data_utils import get_loaders +from train.model_cv import model_cv_lambdas +from train.eval import autonomy_score, generate_autonomous + + +def load_voc_windows(data_dir, n_vocs, start_offset_ms, n): + """held-out sustained vocalization windows (start `start_offset_ms` after onset).""" + segs = [] + for wav in sorted(glob.glob(os.path.join(data_dir, "*.wav")))[:n_vocs]: + sr, af = wavfile.read(wav) + af = af.astype(np.float64) + onoffs = np.atleast_2d(np.loadtxt(wav.replace(".wav", ".txt"))) + s = int(onoffs[0][0] * sr) + int(start_offset_ms / 1e3 * sr) + seg = af[s:s + n] + if len(seg) == n: + segs.append(seg) + return segs, sr + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--data-glob", default="data500/gabo_p*") + p.add_argument("--out-dir", default="poly_pipeline") + p.add_argument("--n-epochs", type=int, default=50) + p.add_argument("--n-seeds", type=int, default=5) + p.add_argument("--lam", type=float, default=1.068, + help="fixed kernel-weight lambda (lambda is irrelevant for autonomy; the SEED " + "is what matters, so we fix lambda and select over seeds). 
Pass <=0 to " + "sweep the standard 7-point lambda grid instead.") + p.add_argument("--keep-const", action=argparse.BooleanOptionalAction, default=True, + help="keep the (0,0) polynomial 'alpha' forcing term (R2-neutral on gabo; may " + "help on other datasets)") + p.add_argument("--drive-lowpass-ms", type=float, default=1.0) + p.add_argument("--d-state", type=int, default=4) + p.add_argument("--n-kernels", type=int, default=15) + p.add_argument("--context-len", type=float, default=0.1) + p.add_argument("--batch-size", type=int, default=8) + p.add_argument("--n-val-vocs", type=int, default=3) + p.add_argument("--n-test-vocs", type=int, default=3) + p.add_argument("--auto-n", type=int, default=3000, help="autonomy window length (samples)") + p.add_argument("--start-offset-ms", type=float, default=50.0) + p.add_argument("--seed", type=int, default=1234) + p.add_argument("--n-jobs", type=int, default=8) + args = p.parse_args() + + dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + assert len(dirs) >= 3, "need >=3 data shards for train/val/test holdout" + train_dirs, val_dir, test_dir = dirs[:-2], dirs[-2], dirs[-1] + out_dir = os.path.abspath(args.out_dir) + os.makedirs(out_dir, exist_ok=True) + + # training chunks from the train shards + chunks, sr = [], None + per = 100000 // max(1, len(train_dirs)) + for d in train_dirs: + audio, sr = get_segmented_audio(d, d, max_vocs=per, context_len=args.context_len, + seed=args.seed, training=True, extend=True, shuffle_order=True) + chunks += audio + dt = 1 / sr + dls = get_loaders(np.stack(chunks, 0), num_workers=args.n_jobs, batch_size=args.batch_size, + train_size=0.6, cv=True, seed=args.seed, dt=dt) + + # held-out autonomy vocalizations (file-level holdout) + val_vocs, _ = load_voc_windows(val_dir, args.n_val_vocs, args.start_offset_ms, args.auto_n) + test_vocs, _ = load_voc_windows(test_dir, args.n_test_vocs, args.start_offset_ms, args.auto_n) + print(f"train chunks={len(chunks)} from 
{len(train_dirs)} shards | " + f"val_vocs={len(val_vocs)} from {os.path.basename(val_dir)} | " + f"test_vocs={len(test_vocs)} from {os.path.basename(test_dir)} | sr={sr}", flush=True) + + best_model = model_cv_lambdas( + dls=dls, dt=dt, n_epochs=args.n_epochs, lr=1e-3, n_kernels=args.n_kernels, + expand_factor=10, n_layers=3, d_state=args.d_state, d_conv=4, tau=dt, + model_path=out_dir, save_freq=max(args.n_epochs // 5, 1), + drive_lowpass_ms=args.drive_lowpass_ms, n_seeds=args.n_seeds, + selection="autonomy", val_vocs=val_vocs, test_vocs=test_vocs, + keep_const=args.keep_const, rescale_autonomy=True, + lambdas=None if args.lam <= 0 else [args.lam], + ) + + # DEPLOYED generation: rescaled autonomous reconstruction of the held-out test vocs. + # (Amplitude is a free gauge fixed at generation; selection above already used rescaled autonomy.) + if test_vocs: + score, _, bd = autonomy_score(best_model, test_vocs, dt, rescale=True) + recon = generate_autonomous(best_model, test_vocs[0], dt, rescale=True) + wav = (recon / (np.abs(recon).max() + 1e-12) * 0.95 * 32767).astype(np.int16) + wavfile.write(os.path.join(out_dir, "selected_autonomous_recon.wav"), int(round(1 / dt)), wav) + manifest = {"selected_lambda": float(args.lam), "n_seeds": int(args.n_seeds), + "keep_const": bool(args.keep_const), "drive_lowpass_ms": float(args.drive_lowpass_ms), + "rescaled_test_autonomy": score, "spec_corr": bd["spec_corr"], + "pitch_pen": bd["pitch_pen"], "bounded_frac": bd["bounded_frac"]} + with open(os.path.join(out_dir, "selected_model.json"), "w") as f: + json.dump(manifest, f, indent=2) + print(f"deployed (rescaled) test autonomy = {score:+.3f} (spec={bd['spec_corr']:.2f}, " + f"pitch_pen={bd['pitch_pen']:.2f}, bounded={bd['bounded_frac']:.2f})", flush=True) + print(f"wrote {out_dir}/selected_autonomous_recon.wav + selected_model.json", flush=True) + print("PIPELINE DONE", flush=True) + + +if __name__ == "__main__": + main() diff --git a/examples/run_poly_rollout_pipeline.py 
b/examples/run_poly_rollout_pipeline.py new file mode 100644 index 0000000..0e503c2 --- /dev/null +++ b/examples/run_poly_rollout_pipeline.py @@ -0,0 +1,146 @@ +""" +Multi-seed validation of the integrated TF + spectral/envelope rollout-refine pipeline. + +For each seed: trains the integrated model (examples.train_poly_rollout -> teacher-forced THEN +rollout-refine, saving tf/ and poly/ checkpoints), then scores DETERMINISTIC autonomous +reconstruction on held-out gabo_p9 windows for BOTH the TF-only and the TF+rollout model, against +the rescale baseline. Reports the per-seed and aggregate change, and a paired figure. + +Resumable: a seed whose final (poly/) checkpoint exists is not retrained. + +Run from the repo root: + python -m examples.run_poly_rollout_pipeline --train-glob 'data500/gabo_p[0-7]' \ + --test-dir data500/gabo_p9 --seeds 0 1 2 3 4 --out-dir ./poly_ro_pipeline +""" + +import argparse +import glob +import os +import subprocess +import sys + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + +from train.train import load_model +from examples.eval_autonomy_rescale import load_voc_windows, score_model + +plt.rcParams["text.usetex"] = False +PY = sys.executable + + +def latest_ckpt(d): + return sorted(glob.glob(os.path.join(d, "*.tar")), + key=lambda f: int(f.split("checkpoint_")[-1].split(".tar")[0])) + + +def train_one(seed, args): + out_dir = os.path.join(os.path.abspath(args.out_dir), f"seed{seed}") + if latest_ckpt(os.path.join(out_dir, "poly")): + print(f"[skip] seed{seed}: final checkpoint exists", flush=True) + return out_dir + cmd = [PY, "-m", "examples.train_poly_rollout", "--data-glob", args.train_glob, + "--out-dir", out_dir, "--tf-epochs", str(args.tf_epochs), + "--rollout-epochs", str(args.rollout_epochs), "--lam-env", str(args.lam_env), + "--lam-spec", str(args.lam_spec), "--lam-tf", str(args.lam_tf), + "--drive-lowpass-ms", str(args.drive_lowpass_ms), "--lam", str(args.lam), + "--d-state", 
str(args.d_state), "--context-len", str(args.context_len), + "--batch-size", str(args.batch_size), "--rollout-windows", str(args.rollout_windows), + "--rollout-hmax", str(args.rollout_hmax), "--rollout-l-seg", str(args.rollout_hmax), + "--seed", str(seed)] + print(f"\n[train] seed{seed}: {' '.join(cmd)}", flush=True) + res = subprocess.run(cmd, capture_output=True, text=True) + sys.stdout.write(res.stdout[-1500:]) + if res.returncode != 0: + sys.stdout.write(res.stderr[-2000:]) + raise RuntimeError(f"training failed: seed{seed}") + return out_dir + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--train-glob", default="data500/gabo_p[0-7]") + p.add_argument("--test-dir", default="data500/gabo_p9") + p.add_argument("--seeds", type=int, nargs="+", default=[0, 1, 2, 3, 4]) + p.add_argument("--out-dir", default="./poly_ro_pipeline") + p.add_argument("--tf-epochs", type=int, default=30) + p.add_argument("--rollout-epochs", type=int, default=8) + p.add_argument("--rollout-windows", type=int, default=12) + p.add_argument("--rollout-hmax", type=int, default=1500) + p.add_argument("--lam", type=float, default=1.068) + p.add_argument("--lam-spec", type=float, default=1.0) + p.add_argument("--lam-env", type=float, default=10.0) + p.add_argument("--lam-tf", type=float, default=1.0) + p.add_argument("--drive-lowpass-ms", type=float, default=1.0) + p.add_argument("--d-state", type=int, default=4) + p.add_argument("--context-len", type=float, default=0.1) + p.add_argument("--batch-size", type=int, default=8) + p.add_argument("--n-test-vocs", type=int, default=6) + p.add_argument("--auto-n", type=int, default=3000) + p.add_argument("--start-offset-ms", type=float, default=50.0) + args = p.parse_args() + + run_dirs = {s: train_one(s, args) for s in args.seeds} + + segs, sr = load_voc_windows(args.test_dir, args.n_test_vocs, args.start_offset_ms, args.auto_n) + print(f"\nscoring autonomy on {len(segs)} held-out vocs from {args.test_dir}\n", 
flush=True) + print(f"{'seed':6s} {'phase':6s} {'raw':>7s} {'rescale':>8s} {'spec':>6s} {'amp_pen':>8s} " + f"{'envcorr':>8s} {'envL1':>7s}", flush=True) + res = {"tf": [], "ro": []} + rows = [] + for s in args.seeds: + for phase, sub in [("tf", "tf"), ("ro", "poly")]: + ck = latest_ckpt(os.path.join(run_dirs[s], sub)) + if not ck: + print(f" seed{s} {phase}: no checkpoint", flush=True) + continue + model, _, _, _ = load_model(os.path.join(run_dirs[s], sub)) + model.eval() + r = score_model(model, segs, sr) + res[phase].append(r["raw"]) + rows.append((s, phase, r)) + print(f"seed{s:<2d} {phase:6s} {r['raw']:+7.3f} {r['rescale']:+8.3f} {r['spec']:6.3f} " + f"{r['amp_pen']:8.3f} {r['envcorr']:8.3f} {r['envL1']:7.3f}", flush=True) + + out = os.path.abspath(args.out_dir) + os.makedirs(out, exist_ok=True) + with open(os.path.join(out, "rollout_pipeline.csv"), "w") as f: + f.write("seed,phase,raw,rescale,spec,amp_pen,pitch_pen,envcorr,envL1\n") + for s, ph, r in rows: + f.write(f"{s},{ph},{r['raw']:.4f},{r['rescale']:.4f},{r['spec']:.4f},{r['amp_pen']:.4f}," + f"{r['pitch_pen']:.4f},{r['envcorr']:.4f},{r['envL1']:.4f}\n") + + tf, ro = np.array(res["tf"]), np.array(res["ro"]) + if len(tf) and len(ro) and len(tf) == len(ro): + print(f"\n=== SUMMARY over {len(tf)} seeds (autonomy raw) ===", flush=True) + print(f" TF-only : mean {tf.mean():+.3f} std {tf.std():.3f} [{tf.min():+.3f},{tf.max():+.3f}]", flush=True) + print(f" TF+rollout: mean {ro.mean():+.3f} std {ro.std():.3f} [{ro.min():+.3f},{ro.max():+.3f}]", flush=True) + print(f" paired diff (ro-tf): mean {(ro-tf).mean():+.3f} per-seed {np.round(ro-tf,3)}", flush=True) + + fig, (axL, axR) = plt.subplots(1, 2, figsize=(11, 5)) + for i in range(len(tf)): + axL.plot([0, 1], [tf[i], ro[i]], "-o", color="0.6", zorder=1) + axL.scatter(np.zeros_like(tf), tf, color="tab:red", zorder=2, label="TF-only") + axL.scatter(np.ones_like(ro), ro, color="tab:blue", zorder=2, label="TF+rollout") + axL.set_xticks([0, 1]); 
axL.set_xticklabels(["TF-only", "TF+rollout"]) + axL.set_xlim(-0.4, 1.4); axL.set_ylabel("autonomy raw (held-out gabo_p9)") + axL.set_title("paired by seed"); axL.legend(fontsize=8) + axR.boxplot([tf, ro], tick_labels=["TF-only", "TF+rollout"], showmeans=True, widths=0.5) + axR.scatter(np.ones_like(tf), tf, color="tab:red", alpha=0.7) + axR.scatter(np.full_like(ro, 2), ro, color="tab:blue", alpha=0.7) + axR.set_ylabel("autonomy raw") + axR.set_title(f"distribution ({len(tf)} seeds, lam_env={args.lam_env})") + fig.suptitle("Integrated TF + spectral/envelope rollout refinement: autonomous reconstruction") + fig.tight_layout() + figpath = os.path.join(out, "rollout_pipeline.png") + fig.savefig(figpath, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"\nSaved figure -> {figpath}", flush=True) + print(f"Saved CSV -> {os.path.join(out, 'rollout_pipeline.csv')}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/examples/sweep_lowpass.py b/examples/sweep_lowpass.py new file mode 100644 index 0000000..c8476fb --- /dev/null +++ b/examples/sweep_lowpass.py @@ -0,0 +1,97 @@ +""" +Sweep the drive low-pass timescale (ms) for ArneodoOuroboros and report the +R²-vs-smoothness tradeoff. Gathers dataloaders once, trains one model per cutoff. 
+ +Run from the repo root: + python -m examples.sweep_lowpass --data-glob './data500/gabo_*' --out-dir ./arneodo_sweep \ + --lowpass-ms 1 2 3 4 5 --epochs 30 +""" + +import argparse +import glob +import os + +import numpy as np +import torch +from torch.optim import Adam +from torch.optim.lr_scheduler import ReduceLROnPlateau +from scipy.signal import welch + +from model.model import ArneodoOuroboros +from train.train import train, save_model +from train.eval import eval_model_error +from utils import sse +from examples.train_arneodo_big import gather_loaders + + +def beta_peak_freq(model, dls, dt): + """dominant frequency (Hz) of the learned beta drive on a test batch.""" + x, dxdt, _ = next(iter(dls["test"])) + x = x.cuda().float(); dxdt = dxdt.cuda().float() + with torch.no_grad(): + _, be, _, _ = model.get_funcs(x, dxdt, dt) + be = be[0].detach().cpu().numpy().squeeze() + f, P = welch(be - be.mean(), fs=1 / dt, nperseg=min(1024, len(be))) + P[0] = 0.0 + return float(f[np.argmax(P)]) + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--data-glob", required=True) + p.add_argument("--out-dir", default="./arneodo_sweep") + p.add_argument("--lowpass-ms", type=float, nargs="+", default=[1, 2, 3, 4, 5]) + p.add_argument("--epochs", type=int, default=30) + p.add_argument("--batch-size", type=int, default=8) + p.add_argument("--max-vocs", type=int, default=100000) + p.add_argument("--context-len", type=float, default=0.25) + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--n-layers", type=int, default=3) + p.add_argument("--d-state", type=int, default=4) + p.add_argument("--d-conv", type=int, default=4) + p.add_argument("--expand-factor", type=int, default=10) + p.add_argument("--seed", type=int, default=1234) + p.add_argument("--n-jobs", type=int, default=8) + args = p.parse_args() + + data_dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + assert data_dirs, f"no dirs matched 
{args.data_glob}" + dls, dt = gather_loaders( + data_dirs, args.max_vocs, args.context_len, args.batch_size, args.seed, args.n_jobs + ) + info = {"n layers": args.n_layers, "d state": args.d_state, + "d conv": args.d_conv, "expand factor": args.expand_factor} + + rows = [] + for lp in args.lowpass_ms: + run_dir = os.path.join(os.path.abspath(args.out_dir), f"lp_{lp:g}ms", "arneodo") + os.makedirs(run_dir, exist_ok=True) + model = ArneodoOuroboros( + d_data=1, n_layers=args.n_layers, d_state=args.d_state, d_conv=args.d_conv, + expand_factor=args.expand_factor, tau=dt, drive_lowpass_ms=lp, + ) + opt = Adam(model.parameters(), lr=args.lr) + sched = ReduceLROnPlateau(opt, factor=0.5, patience=5, min_lr=1e-10) + train(model, opt, loss_fn=lambda y, yhat: sse(yhat, y, reduction="mean"), + loaders=dls, scheduler=sched, nEpochs=args.epochs, val_freq=1, runDir=run_dir, + dt=dt, vis_freq=0, smoothing=False, reg_weights=False, start_epoch=0, + save_freq=args.epochs, model_info=info) + model.eval() + with torch.no_grad(): + (tr, te), _, _ = eval_model_error(dls, model, dt=dt, comparison="test") + pf = beta_peak_freq(model, dls, dt) + save_model(model, opt, os.path.join(run_dir, f"checkpoint_{args.epochs}.tar"), + n_layers=args.n_layers, d_state=args.d_state, + d_conv=args.d_conv, expand_factor=args.expand_factor) + rows.append((lp, tr, te, pf)) + print(f">>> lowpass={lp:g}ms train R2={tr:.4f} test R2={te:.4f} beta peak={pf:.0f}Hz", flush=True) + + print("\n==== low-pass timescale sweep (cutoff ~ 1/(2*pi*ms)) ====") + print(" ms | ~cutoff Hz | train R2 | test R2 | beta-drive peak Hz") + for lp, tr, te, pf in rows: + print(f" {lp:>3g} | {1/(2*np.pi*lp/1e3):>9.0f} | {tr:>7.3f} | {te:>6.3f} | {pf:>6.0f}") + print("(reference: no low-pass -> test R2 ~0.987, beta peak ~2266 Hz)") + + +if __name__ == "__main__": + main() diff --git a/examples/train_poly_lowpass.py b/examples/train_poly_lowpass.py index da13a21..c8cbad7 100644 --- a/examples/train_poly_lowpass.py +++ 
b/examples/train_poly_lowpass.py @@ -60,9 +60,13 @@ def main(): p.add_argument("--n-kernels", type=int, default=15) p.add_argument("--lam", type=float, default=1.2, help="kernel-weight regularization base") p.add_argument("--drive-lowpass-ms", type=float, default=2.0) + p.add_argument("--keep-const", action="store_true", + help="add the constant (0,0) 'alpha' forcing term (low-passed with the other drives)") p.add_argument("--seed", type=int, default=1234) p.add_argument("--n-jobs", type=int, default=8) args = p.parse_args() + torch.manual_seed(args.seed) + np.random.seed(args.seed) run_dir = os.path.join(os.path.abspath(args.out_dir), "poly") os.makedirs(run_dir, exist_ok=True) @@ -76,7 +80,7 @@ def main(): activation=lambda x: x, lam=args.lam) model = Ouroboros(d_data=1, kernel=kernel, n_layers=args.n_layers, d_state=args.d_state, d_conv=args.d_conv, expand_factor=args.expand_factor, tau=dt, - drive_lowpass_ms=args.drive_lowpass_ms) + drive_lowpass_ms=args.drive_lowpass_ms, keep_const=args.keep_const) n_params = sum(q.numel() for q in model.parameters()) print(f"poly Ouroboros: n_kernels={args.n_kernels} d_state={args.d_state} expand={args.expand_factor} " f"lowpass={args.drive_lowpass_ms}ms lam={args.lam} -> {n_params} params; tau={dt:.2e}", flush=True) diff --git a/examples/train_poly_rollout.py b/examples/train_poly_rollout.py new file mode 100644 index 0000000..4177592 --- /dev/null +++ b/examples/train_poly_rollout.py @@ -0,0 +1,129 @@ +""" +Integrated trainer for the (low-pass) polynomial Ouroboros: teacher-forced training THEN the +phase-invariant spectral+envelope rollout-refinement phase (train.rollout_refine), in one run. + +This folds the autonomous-reconstruction objective into the training pipeline rather than running it +as a separate post-hoc fine-tune. 
Produces two checkpoints under --out-dir: + - tf/checkpoint_{tf_epochs}.tar : teacher-forced only (reference / ablation) + - poly/checkpoint_{tf_epochs + rollout_epochs}.tar : TF + rollout-refined (final) + +Run from the repo root: + python -m examples.train_poly_rollout --data-glob 'data500/gabo_p[0-7]' --out-dir ./poly_ro_seed0 \ + --tf-epochs 30 --rollout-epochs 8 --lam-env 10 --drive-lowpass-ms 1.0 --d-state 4 \ + --context-len 0.1 --batch-size 8 --seed 0 +""" + +import argparse +import glob +import os + +import numpy as np +import torch +from torch.optim import Adam +from torch.optim.lr_scheduler import ReduceLROnPlateau + +from model.model import Ouroboros +from model.kernels import fullPolyModule +from train.train import train, save_model +from train.eval import eval_model_error +from train.rollout_refine import gather_windows, rollout_refine +from utils import sse +from examples.train_poly_lowpass import gather_loaders + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--data-glob", required=True) + p.add_argument("--out-dir", required=True) + # teacher-forced phase + p.add_argument("--max-vocs", type=int, default=100000) + p.add_argument("--context-len", type=float, default=0.1) + p.add_argument("--batch-size", type=int, default=8) + p.add_argument("--tf-epochs", type=int, default=30) + p.add_argument("--seg", type=int, default=10) + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--n-layers", type=int, default=3) + p.add_argument("--d-state", type=int, default=4) + p.add_argument("--d-conv", type=int, default=4) + p.add_argument("--expand-factor", type=int, default=10) + p.add_argument("--n-kernels", type=int, default=15) + p.add_argument("--lam", type=float, default=1.068) + p.add_argument("--drive-lowpass-ms", type=float, default=1.0) + p.add_argument("--keep-const", action="store_true") + # rollout-refine phase + p.add_argument("--rollout-epochs", type=int, default=8) + p.add_argument("--rollout-hmin", type=int, default=768) +
p.add_argument("--rollout-hmax", type=int, default=1500) + p.add_argument("--rollout-l-seg", type=int, default=1500) + p.add_argument("--rollout-windows", type=int, default=12) + p.add_argument("--rollout-batch", type=int, default=6) + p.add_argument("--rollout-lr", type=float, default=1e-4) + p.add_argument("--lam-spec", type=float, default=1.0) + p.add_argument("--lam-env", type=float, default=10.0) + p.add_argument("--lam-tf", type=float, default=1.0) + p.add_argument("--env-ms", type=float, default=2.0) + p.add_argument("--start-offset-ms", type=float, default=50.0) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--n-jobs", type=int, default=8) + args = p.parse_args() + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + out_dir = os.path.abspath(args.out_dir) + tf_dir = os.path.join(out_dir, "tf") + poly_dir = os.path.join(out_dir, "poly") + os.makedirs(tf_dir, exist_ok=True) + os.makedirs(poly_dir, exist_ok=True) + data_dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + assert data_dirs, f"no dirs matched {args.data_glob}" + + dls, dt = gather_loaders(data_dirs, args.max_vocs, args.context_len, args.batch_size, + args.seed, args.n_jobs) + + kernel = fullPolyModule(nTerms=args.n_kernels, device="cuda", x_dim=1, z_dim=2, + activation=lambda x: x, lam=args.lam) + model = Ouroboros(d_data=1, kernel=kernel, n_layers=args.n_layers, d_state=args.d_state, + d_conv=args.d_conv, expand_factor=args.expand_factor, tau=dt, + drive_lowpass_ms=args.drive_lowpass_ms, keep_const=args.keep_const) + n_params = sum(q.numel() for q in model.parameters()) + print(f"poly Ouroboros: lowpass={args.drive_lowpass_ms}ms lam={args.lam} keep_const={args.keep_const} " + f"-> {n_params} params; tau={dt:.2e}", flush=True) + + # ---- teacher-forced phase ---- + opt = Adam(model.parameters(), lr=args.lr) + scheduler = ReduceLROnPlateau(opt, factor=0.5, patience=max(args.seg, 3), min_lr=1e-10) + model_info = {"n layers": args.n_layers, "d 
state": args.d_state, + "d conv": args.d_conv, "expand factor": args.expand_factor} + for start in range(0, args.tf_epochs, args.seg): + end = min(args.tf_epochs, start + args.seg) + train(model, opt, loss_fn=lambda y, yhat: sse(yhat, y, reduction="mean"), + loaders=dls, scheduler=scheduler, nEpochs=end, val_freq=1, runDir=tf_dir, # TF-phase checkpoints go under tf/; poly/ must hold only the final refined checkpoint + dt=dt, vis_freq=0, smoothing=False, reg_weights=True, start_epoch=start, + save_freq=max(args.seg, 1), model_info=model_info) + model.eval() + with torch.no_grad(): + (tr, te), _, _ = eval_model_error(dls, model, dt=dt, comparison="test") + print(f"[TF done] tf_epochs={args.tf_epochs} train R2={tr:.4f} test R2={te:.4f}", flush=True) + save_model(model, opt, os.path.join(tf_dir, f"checkpoint_{args.tf_epochs}.tar"), + n_layers=args.n_layers, d_state=args.d_state, d_conv=args.d_conv, + expand_factor=args.expand_factor) + + # ---- rollout-refine phase ---- + X, _ = gather_windows(data_dirs, args.rollout_windows, args.rollout_l_seg, args.start_offset_ms) + print(f"rollout windows: {X.shape[0]} x {X.shape[1]} samples", flush=True) + _, ro_opt = rollout_refine(model, X, dt, epochs=args.rollout_epochs, hmin=args.rollout_hmin, + hmax=args.rollout_hmax, batch_size=args.rollout_batch, lr=args.rollout_lr, + lam_spec=args.lam_spec, lam_env=args.lam_env, lam_tf=args.lam_tf, + env_ms=args.env_ms) + model.eval() + with torch.no_grad(): + (tr2, te2), _, _ = eval_model_error(dls, model, dt=dt, comparison="test") + print(f"[refine done] test R2 {te:.4f} -> {te2:.4f}", flush=True) + save_model(model, ro_opt, os.path.join(poly_dir, f"checkpoint_{args.tf_epochs + args.rollout_epochs}.tar"), + n_layers=args.n_layers, d_state=args.d_state, d_conv=args.d_conv, + expand_factor=args.expand_factor) + print(f"DONE seed{args.seed}: tf R2={te:.4f} final R2={te2:.4f} ({tf_dir} | {poly_dir})", flush=True) + + +if __name__ == "__main__": + main() diff --git a/model/kernels.py b/model/kernels.py index 781763f..2ae1bb2 100644 --- a/model/kernels.py +++
b/model/kernels.py @@ -62,6 +62,9 @@ def __init__(self, nTerms, device, x_dim, z_dim, lam=0.01, activation=lambda x: self.weights = nn.Linear(self.n, (self.poly_dim + 1) ** 2).to(self.device) self.powers = torch.arange(0, self.poly_dim + 1, device=self.device) self.lam = lam + # if True, keep the constant (0,0) term (the "alpha"-like forcing y^0*ydot^0). + # The linear (1,0)/(0,1) terms are always zeroed (omega/gamma handle those). + self.keep_const = False def forward( self, x: torch.FloatTensor, z: torch.FloatTensor @@ -85,7 +88,8 @@ def forward( weights = weights.view(B, L, self.poly_dim + 1, self.poly_dim + 1) ### constant term - weights[:, :, 0, 0] = weights[:, :, 0, 0] * 0 + if not self.keep_const: + weights[:, :, 0, 0] = weights[:, :, 0, 0] * 0 ### y, ydot terms weights[:, :, 1, 0] = weights[:, :, 1, 0] * 0 weights[:, :, 0, 1] = weights[:, :, 0, 1] * 0 @@ -124,7 +128,8 @@ def forward_given_weights( B, L, d = x.shape weights = weights.view(B, L, self.poly_dim + 1, self.poly_dim + 1) ### constant term - weights[:, :, 0, 0] = weights[:, :, 0, 0] * 0 + if not self.keep_const: + weights[:, :, 0, 0] = weights[:, :, 0, 0] * 0 ### y, ydot terms weights[:, :, 1, 0] = weights[:, :, 1, 0] * 0 weights[:, :, 0, 1] = weights[:, :, 0, 1] * 0 @@ -163,7 +168,8 @@ def forward_given_weights_numpy( weights = np.reshape(weights, (B, L, self.poly_dim + 1, self.poly_dim + 1)) # constant term - weights[:, :, 0, 0] = weights[:, :, 0, 0] * 0 + if not self.keep_const: + weights[:, :, 0, 0] = weights[:, :, 0, 0] * 0 ### y, ydot terms weights[:, :, 1, 0] = weights[:, :, 1, 0] * 0 weights[:, :, 0, 1] = weights[:, :, 0, 1] * 0 diff --git a/model/model.py b/model/model.py index d8443c6..1b29c28 100644 --- a/model/model.py +++ b/model/model.py @@ -39,6 +39,7 @@ def __init__( tau: float = 1 / 10000, smooth_len: float = 0.001, drive_lowpass_ms: float = 0.0, + keep_const: bool = False, ): super().__init__() @@ -89,14 +90,29 @@ def __init__( self.drive_lowpass_ms = drive_lowpass_ms self.kernel 
= kernel self.kernel.tau = self.tau + # if True, ADD the constant (0,0) "alpha"-like forcing term (y^0*ydot^0) to the kernel + # instead of zeroing it; it is low-passed at drive_lowpass_ms like the other drives. + self.keep_const = keep_const + if keep_const: + self.kernel.keep_const = True + # zero-init the (0,0) output of the kernel weight head so the alpha forcing starts at + # 0 and is learned gently (otherwise the random constant disrupts early training). + with torch.no_grad(): + self.kernel.weights.weight[0].zero_() + self.kernel.weights.bias[0].zero_() self.names = [r"$\omega$", r"$\gamma$", "weighted kernels", "states"] - def _lowpass(self, x: torch.FloatTensor, dt: float) -> torch.FloatTensor: - """centered zero-phase Gaussian low-pass along time of a (B, L, C) control series.""" - sigma = (self.drive_lowpass_ms / 1e3) / dt # samples + def _lowpass(self, x: torch.FloatTensor, dt: float, lp_ms: float = None) -> torch.FloatTensor: + """centered zero-phase Gaussian low-pass along time of a (B, L, C) control series. + lp_ms overrides the timescale (defaults to self.drive_lowpass_ms). The kernel radius is + capped at L-1 so very slow (large-sigma) low-passes work on short segments (reduce to + ~a segment-wide average).""" + ms = self.drive_lowpass_ms if lp_ms is None else lp_ms + sigma = (ms / 1e3) / dt # samples if sigma <= 0: return x - radius = max(1, int(round(3 * sigma))) + L = x.shape[1] + radius = max(1, min(int(round(3 * sigma)), L - 1)) t = torch.arange(-radius, radius + 1, device=x.device, dtype=x.dtype) kern = torch.exp(-0.5 * (t / sigma) ** 2) kern = kern / kern.sum() diff --git a/train/eval.py b/train/eval.py index de7c4c5..0964e24 100644 --- a/train/eval.py +++ b/train/eval.py @@ -10,6 +10,7 @@ from torchdiffeq import odeint_adjoint from scipy.interpolate import make_interp_spline +from scipy.signal import welch """ tools for evaluating model performance. 
covers both regular evaluation and model integration @@ -437,6 +438,126 @@ def dz_hat(s, z): return x_gen +def autonomy_score( + model: torch.nn.Module, + segments: list, + dt: float, + fmax: float = 8000.0, + w_amp: float = 1.0, + w_pitch: float = 1.0, + diverge_score: float = -5.0, + method: str = "rk4", + rescale: bool = False, +) -> tuple: + """ + validation metric for AUTONOMOUS reconstruction quality (model-selection criterion). + + For each pre-windowed (sustained) vocalization segment, run the model's DETERMINISTIC + autonomous integration and score it against the target by phase-robust spectral (log-PSD) + correlation, minus log-ratio penalties on amplitude and pitch: + + score = spectral_corr(autonomous, target) + - w_amp * |log(std_auto / std_target)| + - w_pitch * |log(pitch_auto / pitch_target)| + + A divergent / collapsed rollout (non-finite or ~zero) gets `diverge_score`. Works for both + the polynomial (`integrate_poly_autonomous`) and Arneodo (`integrate_model_autonomous`) models. + + If `rescale=True`, the autonomous output's RMS is matched to the target's before scoring (the + deployed recipe -- amplitude is an arbitrary overall-scale gauge fixed at generation by + `generate_autonomous`). This zeroes `amp_pen` for bounded rollouts, so selection then turns on + the genuinely-constrained quantities (spectral shape + pitch + boundedness). A divergent rollout + is still detected on the RAW output and gets `diverge_score` (collapse is not rescaled away). 
+ + returns + ----- + - mean score over segments + - per-segment scores (list) + - breakdown dict (mean spectral corr, amp penalty, pitch penalty, bounded fraction) + """ + fs = 1.0 / dt + + def _logpsd(x): + f, P = welch(x - np.mean(x), fs=fs, nperseg=min(1024, len(x))) + m = f <= fmax + return np.log(P[m] + 1e-20) + + def _peak(x): + f, P = welch(x - np.mean(x), fs=fs, nperseg=min(1024, len(x))) + P[0] = 0 + return float(f[np.argmax(P)]) + + is_poly = hasattr(model, "kernel") + scores, specs, amps, pits, bounded = [], [], [], [], [] + for seg in segments: + seg = np.asarray(seg, dtype=np.float64) + tgt = correct(seg) + if is_poly: + auto = integrate_poly_autonomous(model, seg, dt, method=method, noise_sd=0.0, + detrend=True, verbose=False) + else: + auto = integrate_model_autonomous(model, seg, dt, method=method, detrend=True, + verbose=False) + n = min(len(tgt), len(auto)) + tgt_n, auto_n = tgt[:n], auto[:n] + if (not np.isfinite(auto_n).all()) or np.nanstd(auto_n) < 1e-9: + scores.append(diverge_score) + bounded.append(0.0) + specs.append(np.nan); amps.append(np.nan); pits.append(np.nan) + continue + bounded.append(1.0) + if rescale: # gauge-fix amplitude to the target RMS (the deployed recipe) + auto_n = auto_n * (np.nanstd(tgt_n) / (np.nanstd(auto_n) + 1e-12)) + sc = float(np.corrcoef(_logpsd(tgt_n), _logpsd(auto_n))[0, 1]) + amp = abs(np.log((np.nanstd(auto_n) + 1e-12) / (np.nanstd(tgt_n) + 1e-12))) + pit = abs(np.log((_peak(auto_n) + 1e-9) / (_peak(tgt_n) + 1e-9))) + scores.append(sc - w_amp * amp - w_pitch * pit) + specs.append(sc); amps.append(amp); pits.append(pit) + + breakdown = { + "spec_corr": float(np.nanmean(specs)) if len(specs) else float("nan"), + "amp_pen": float(np.nanmean(amps)) if len(amps) else float("nan"), + "pitch_pen": float(np.nanmean(pits)) if len(pits) else float("nan"), + "bounded_frac": float(np.mean(bounded)) if len(bounded) else 0.0, + } + return float(np.mean(scores)), scores, breakdown + + +def generate_autonomous( + model: 
torch.nn.Module, + audio: np.ndarray, + dt: float, + rescale: bool = True, + ref_rms: float = None, + method: str = "rk4", + detrend: bool = True, + verbose: bool = False, +) -> np.ndarray: + """ + DEPLOYED autonomous generation: closed-loop integration + amplitude rescaling. + + Autonomous amplitude is a poorly-constrained, marginal direction (the transverse Floquet + exponent is left free by on-orbit teacher-forced fitting), so the raw free-running amplitude + decays/grows/varies by seed. The overall scale is an arbitrary audio-unit gauge, so we fix it + at generation: run the deterministic closed-loop rollout, then (if `rescale`) match its RMS to + a reference -- `ref_rms` if given, else the input window's own (detrended) RMS for reconstruction. + + Works for poly (`integrate_poly_autonomous`) and Arneodo (`integrate_model_autonomous`) models. + Returns the (rescaled) generated waveform; a divergent/collapsed rollout is returned unscaled. + """ + is_poly = hasattr(model, "kernel") + if is_poly: + auto = integrate_poly_autonomous(model, audio, dt, method=method, detrend=detrend, + noise_sd=0.0, verbose=verbose) + else: + auto = integrate_model_autonomous(model, audio, dt, method=method, detrend=detrend, + verbose=verbose) + if rescale and np.isfinite(auto).all() and np.nanstd(auto) > 1e-9: + target = ref_rms if ref_rms is not None else float(np.nanstd(correct(np.asarray(audio, dtype=np.float64)))) + auto = auto * (target / (np.nanstd(auto) + 1e-12)) + return auto + + def eval_model_error( dls: dict, model: torch.nn.Module, dt: float, comparison: str = "val" ) -> tuple[tuple[float, float], tuple[float, float], tuple[np.ndarray, np.ndarray]]: diff --git a/train/model_cv.py b/train/model_cv.py index 401ef85..9e52007 100644 --- a/train/model_cv.py +++ b/train/model_cv.py @@ -3,12 +3,13 @@ from model.model import Ouroboros, ArneodoOuroboros from utils import sse from visualization.model_vis import loss_plot -from train.eval import eval_model_error +from train.eval 
import eval_model_error, autonomy_score from torch.optim import Adam from torch.optim.lr_scheduler import ReduceLROnPlateau import os import glob +import gc import pandas as pd import numpy as np import matplotlib.pyplot as plt @@ -31,6 +32,14 @@ def model_cv_lambdas( smooth_len: float = 0.001, model_path: str = "", save_freq: int = 5, + drive_lowpass_ms: float = 0.0, + n_seeds: int = 1, + selection: str = "r2", + val_vocs: list = None, + test_vocs: list = None, + keep_const: bool = False, + rescale_autonomy: bool = False, + lambdas: list = None, ) -> torch.nn.Module: """ This function trains models and cross-validates across regularization strengths. @@ -67,151 +76,109 @@ def model_cv_lambdas( "expand factor": expand_factor, } - min_lambda = 1.01 - max_lambda = 10 ** (4 / (2 * n_kernels)) - - lambdas = np.linspace(min_lambda, max_lambda, 7) - # - lam_train_cv_err = [] - lam_test_cv_err = [] - - lam_train_cv_sd = [] - lam_test_cv_sd = [] - - lam_train_cv_r2 = [] - lam_test_cv_r2 = [] - - for ii, lam in enumerate(lambdas): - print(f"Regularizing with lambda={lam}") - - kernel = fullPolyModule( - nTerms=n_kernels, - device="cuda", - x_dim=1, - z_dim=2, - activation=lambda x: x, - lam=lam, - ) - reg_weights = True - full_model_poly = Ouroboros( - d_data=1, - n_layers=n_layers, - d_state=d_state, - d_conv=d_conv, - expand_factor=expand_factor, - tau=tau, - smooth_len=smooth_len, - kernel=kernel, - ) - - full_opt_poly = Adam(full_model_poly.parameters(), lr=lr) - full_scheduler_poly = ReduceLROnPlateau( - full_opt_poly, factor=0.5, patience=max(n_epochs // 25, 2), min_lr=1e-10 - ) - model_path_full_poly = ( - model_path + f"/kernelborous_poly_end_to_end_lambda_{lam}" + # default: the standard 7-point lambda grid. Pass `lambdas=[x]` for a single fixed lambda + # (e.g. the seed-selection production recipe, where lambda is known to be irrelevant for + # autonomy and only the seed matters). 
+ if lambdas is None: + min_lambda = 1.01 + max_lambda = 10 ** (4 / (2 * n_kernels)) + lambdas = np.linspace(min_lambda, max_lambda, 7) + else: + lambdas = np.asarray(lambdas, dtype=float) + + if selection == "autonomy" and not val_vocs: + raise ValueError( + "selection='autonomy' requires val_vocs (held-out vocalization segments)" ) - save_loc_poly = model_path_full_poly + f"/checkpoint_{n_epochs}.tar" - save_files = glob.glob(os.path.join(model_path_full_poly, "*.tar")) - start_epoch = 0 - if len(save_files) > 0: - full_model_poly, full_opt_poly, full_scheduler_poly, start_epoch = ( - load_model(model_path_full_poly) + # train n_seeds models per lambda; record the validation metric(s) for each. Resumable: + # a (lambda, seed) with a complete checkpoint is loaded instead of retrained. + records = [] + for lam in lambdas: + for seed in range(n_seeds): + torch.manual_seed(seed) + np.random.seed(seed) + kernel = fullPolyModule( + nTerms=n_kernels, device="cuda", x_dim=1, z_dim=2, + activation=lambda x: x, lam=lam, ) - - if start_epoch < n_epochs: - tl, vl, full_model_poly, full_opt_poly = train( - full_model_poly, - full_opt_poly, - loss_fn=lambda y, yhat: sse(yhat, y, reduction="mean"), - loaders=dls, - scheduler=full_scheduler_poly, - nEpochs=n_epochs, - val_freq=1, - runDir=model_path_full_poly, - dt=dt, - vis_freq=max(n_epochs // 10, 1), - smoothing=False, - reg_weights=reg_weights, - start_epoch=start_epoch, - save_freq=save_freq, - model_info=model_info, + model = Ouroboros( + d_data=1, n_layers=n_layers, d_state=d_state, d_conv=d_conv, + expand_factor=expand_factor, tau=tau, smooth_len=smooth_len, + kernel=kernel, drive_lowpass_ms=drive_lowpass_ms, keep_const=keep_const, ) - - loss_plot(tl, vl, save_loc=model_path_full_poly, show=False) - - save_model( - full_model_poly, - full_opt_poly, - save_loc_poly, - n_layers=n_layers, - d_state=d_state, - expand_factor=expand_factor, - d_conv=d_conv, - ) - - full_model_poly.eval() - with torch.no_grad(): - (train_mu, 
test_mu), (train_sd, test_sd), (train_r2, test_r2) = ( - eval_model_error(dls, full_model_poly, dt=dt) + opt = Adam(model.parameters(), lr=lr) + sched = ReduceLROnPlateau( + opt, factor=0.5, patience=max(n_epochs // 25, 2), min_lr=1e-10 ) - lam_train_cv_err.append(train_mu) - lam_test_cv_err.append(test_mu) - - lam_train_cv_sd.append(train_sd) - lam_test_cv_sd.append(test_sd) - - lam_train_cv_r2.append(train_r2) - lam_test_cv_r2.append(test_r2) - - splits = ["train"] * len(lam_train_cv_err) + ["val"] * len(lam_test_cv_err) - lambdas_stacked = np.round(np.hstack([lambdas, lambdas]), 3) - errs = np.hstack([lam_train_cv_err, lam_test_cv_err]) - df = pd.DataFrame({"lam": lambdas_stacked, "split": splits, "R2": errs}) - - min_err_ind = np.argmax(lam_test_cv_err) # argmax, since 'err' is actually r2 - print(f"best R2 alpha for {n_kernels} kernels: {lambdas[min_err_ind]}") - ax = plt.gca() - - sns.boxplot( - data=df, - x="lam", - y="R2", - hue="split", - hue_order=["train", "test"], - ax=ax, - gap=0.1, - ) - - ax.set_xlabel("Polynomial degree penalty") - ax.set_ylabel(r"$R^2$") - ylim = ax.get_ylim() - ylim = (min(ylim[0], 0), max(ylim[-1], 1.01)) - ax.set_ylim(ylim) - ax.legend() - plt.savefig( - os.path.join(model_path, "train_test_error_kernel_poly_nkernels_30.svg") - ) + run_dir = os.path.join(model_path, f"poly_lam{lam:.3f}_seed{seed}") + save_loc = os.path.join(run_dir, f"checkpoint_{n_epochs}.tar") + + start_epoch = 0 + if glob.glob(os.path.join(run_dir, "*.tar")): + model, opt, sched, start_epoch = load_model(run_dir) # resume + model.kernel.lam = float(lam) # load_model hardcodes lam=1; restore it + if start_epoch < n_epochs: + train( + model, opt, + loss_fn=lambda y, yhat: sse(yhat, y, reduction="mean"), + loaders=dls, scheduler=sched, nEpochs=n_epochs, val_freq=1, + runDir=run_dir, dt=dt, vis_freq=0, smoothing=False, + reg_weights=True, start_epoch=start_epoch, save_freq=save_freq, + model_info=model_info, + ) + save_model(model, opt, save_loc, 
n_layers=n_layers, d_state=d_state, + expand_factor=expand_factor, d_conv=d_conv) + + model.eval() + with torch.no_grad(): + (_, val_r2), _, _ = eval_model_error(dls, model, dt=dt, comparison="val") + val_auto = np.nan + if selection == "autonomy": + val_auto, _, bd = autonomy_score(model, val_vocs, dt, rescale=rescale_autonomy) + print(f"lam={lam:.3f} seed={seed}: val R2={val_r2:.4f} val autonomy={val_auto:+.4f} " + f"(spec={bd['spec_corr']:.2f} amp_pen={bd['amp_pen']:.2f} " + f"pitch_pen={bd['pitch_pen']:.2f} bounded={bd['bounded_frac']:.2f})", flush=True) + else: + print(f"lam={lam:.3f} seed={seed}: val R2={val_r2:.4f}", flush=True) + records.append({"lambda": lam, "seed": seed, "val_r2": val_r2, + "val_autonomy": val_auto, "ckpt": run_dir}) + del model, opt, kernel + gc.collect() + torch.cuda.empty_cache() + + df = pd.DataFrame(records) + metric = "val_autonomy" if selection == "autonomy" else "val_r2" + per_lam = df.groupby("lambda")[metric].agg(["mean", "std"]) + best_lambda = float(per_lam["mean"].idxmax()) + print(f"\n=== lambda selection by {selection} (mean over {n_seeds} seed(s)) ===", flush=True) + for lam, row in per_lam.iterrows(): + mark = " <-- selected" if abs(lam - best_lambda) < 1e-9 else "" + sd = 0.0 if np.isnan(row["std"]) else row["std"] + print(f" lambda={lam:.3f}: {metric}={row['mean']:+.4f} +- {sd:.4f}{mark}", flush=True) + df.to_csv(os.path.join(model_path, "lambda_seed_cv.csv"), index=False) + + plt.figure() + plt.errorbar(per_lam.index, per_lam["mean"], per_lam["std"].fillna(0.0), marker="o") + plt.axvline(best_lambda, ls="--", color="0.6") + plt.xlabel(r"kernel-weight $\lambda$") + plt.ylabel(f"validation {metric}") + plt.title(f"lambda selection by {selection} ({n_seeds} seeds)") + plt.savefig(os.path.join(model_path, f"lambda_selection_{selection}.svg")) plt.close() - model_path_best = ( - model_path + f"/kernelborous_poly_end_to_end_lambda_{lambdas[min_err_ind]}" - ) - - full_model_poly, full_opt_poly, full_scheduler_poly, _ = 
load_model(model_path_best) - full_model_poly.eval() + # best model = best seed at the selected lambda (by the same validation metric) + cand = df[np.isclose(df["lambda"], best_lambda)].sort_values(metric) + best_model, _, _, _ = load_model(cand.iloc[-1]["ckpt"]) + best_model.eval() with torch.no_grad(): - (train_mu, test_mu), (train_sd, test_sd), (train_r2, test_r2) = ( - eval_model_error(dls, full_model_poly, dt=dt, comparison="test") - ) - - data_df = pd.DataFrame( - {"lambdas": lambdas, "train MSE": lam_train_cv_err, "test MSE": lam_test_cv_err} - ) - data_df.to_csv(os.path.join(model_path, "cv_errs.csv")) - - return full_model_poly + (_, test_r2), _, _ = eval_model_error(dls, best_model, dt=dt, comparison="test") + test_auto = np.nan + if selection == "autonomy" and test_vocs: + test_auto, _, _ = autonomy_score(best_model, test_vocs, dt, rescale=rescale_autonomy) + print(f"BEST lambda={best_lambda:.3f}: test R2={test_r2:.4f} test autonomy={test_auto:+.4f}", + flush=True) + return best_model def train_arneodo( diff --git a/train/rollout_refine.py b/train/rollout_refine.py new file mode 100644 index 0000000..451de84 --- /dev/null +++ b/train/rollout_refine.py @@ -0,0 +1,173 @@ +""" +Rollout refinement objective for the polynomial Ouroboros -- a reusable training-pipeline component. + +Phase-invariant objective on the soft-saturated differentiable RK4 AUTONOMOUS rollout, so a +phase-drifted-but-correct-content free run is scored well and the optimizer fixes frequency content ++ loudness instead of collapsing amplitude (the failure mode of pointwise rollout MSE on a drifting +oscillator): + + L = lam_spec * L_mrstft + lam_env * L_env + lam_tf * L_tf + + - L_mrstft : multi-resolution STFT MAGNITUDE loss (spectral convergence + log-magnitude L1). + Magnitude discards phase -> tolerant of carrier drift; penalizes a wrong pitch/harmonic + stack that persists across frames. + - L_env : Gaussian-low-passed |x| envelope L1 (per-voc normalized). 
The envelope IS the + instantaneous amplitude, so REQUIRING IT TO MATCH pins the autonomous loudness/scale -- + but it must be weighted up (lam_env >> lam_spec) or the spectral term swamps it and the + global amplitude drifts free. lam_env~10 fixes amplitude across seeds without regressing + already-good models (validated 2026-05-28). + - L_tf : original one-step teacher-forced anchor (preserve the well-fit dynamics / R^2). + +Backprops through the rollout with a curriculum on the horizon; because the loss is phase-invariant +the horizon can be long. See examples/finetune_rollout_spectral_poly.py (single-model entry) and +examples/train_poly_rollout.py (integrated TF + refine trainer). +""" + +import glob +import os + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.io import wavfile + +from utils import deriv_approx_dy, deriv_approx_d2y + +BX, BXP = 0.5, 1.0 # soft-saturation bounds (>> data/limit-cycle scale; only tame divergence) +DEFAULT_CONFIGS = ((256, 64), (512, 128), (1024, 256)) + + +def gather_windows(dirs, n, L, start_off_ms): + """held-out sustained vocalization windows of length L, starting start_off_ms after onset.""" + segs, sr = [], None + for d in dirs: + for wav in sorted(glob.glob(os.path.join(d, "*.wav"))): + sr, af = wavfile.read(wav) + af = af.astype(np.float64) + on = np.atleast_2d(np.loadtxt(wav.replace(".wav", ".txt")))[0][0] + s = int(on * sr) + int(start_off_ms / 1e3 * sr) + seg = af[s:s + L] + if len(seg) == L: + segs.append(seg) + if len(segs) >= n: + return np.stack(segs)[:, :, None], sr + return np.stack(segs)[:, :, None], sr + + +def stft_mag(x, n_fft, hop): + win = torch.hann_window(n_fft, device=x.device, dtype=x.dtype) + S = torch.stft(x, n_fft=n_fft, hop_length=hop, win_length=n_fft, window=win, + center=True, return_complex=True) + return S.abs() # (B, F, T) + + +def mrstft_loss(xg, tgt, configs=DEFAULT_CONFIGS, eps=1e-5): + """multi-resolution STFT magnitude loss: spectral convergence + 
log-magnitude L1.""" + total = 0.0 + for n_fft, hop in configs: + A = stft_mag(xg, n_fft, hop) + G = stft_mag(tgt, n_fft, hop) + sc = torch.norm(G - A, dim=(-2, -1)) / (torch.norm(G, dim=(-2, -1)) + 1e-8) + logm = (torch.log(G + eps) - torch.log(A + eps)).abs().mean(dim=(-2, -1)) + total = total + (sc + logm).mean() + return total / len(configs) + + +def gaussian_envelope(x, dt, env_ms): + """Gaussian low-pass of |x| (removes the carrier, keeps the loudness modulation). x: (B,H).""" + sigma = (env_ms / 1e3) / dt + radius = max(1, int(round(3 * sigma))) + t = torch.arange(-radius, radius + 1, device=x.device, dtype=x.dtype) + k = torch.exp(-0.5 * (t / sigma) ** 2) + k = k / k.sum() + xr = F.pad(x.abs().unsqueeze(1), (radius, radius), mode="reflect") + return F.conv1d(xr, k.view(1, 1, -1)).squeeze(1) # (B,H) + + +def env_loss(xg, tgt, dt, env_ms, eps=1e-6): + ea = gaussian_envelope(xg, dt, env_ms) + eg = gaussian_envelope(tgt, dt, env_ms) + return ((ea - eg).abs().mean(dim=1) / (eg.mean(dim=1) + eps)).mean() + + +def rollout_refine(model, X, dt, *, epochs=8, hmin=768, hmax=1500, batch_size=6, lr=1e-4, + lam_spec=1.0, lam_env=10.0, lam_tf=1.0, clip=5.0, env_ms=2.0, + configs=DEFAULT_CONFIGS, device="cuda", verbose=True): + """ + In-place rollout refinement of a trained polynomial Ouroboros. + + X: numpy (N, L, 1) sustained vocalization windows (L >= hmax). Modifies `model` in place. + Returns (history, opt): per-epoch loss history (list of dicts) and the Adam optimizer (so the + caller can persist it via train.train.save_model). 
+ """ + assert hmax <= X.shape[1], "hmax must be <= window length L" + model.train() + D1 = deriv_approx_dy(X) + D2 = deriv_approx_d2y(X) + Xt = torch.tensor(X, dtype=torch.float32, device=device) + Dt = torch.tensor(D1, dtype=torch.float32, device=device) + D2t = torch.tensor(D2, dtype=torch.float32, device=device) + var_d2 = float(D2t.var()) + P = model.kernel.poly_dim + 1 + powers = torch.arange(P, device=device) + opt = torch.optim.Adam(model.parameters(), lr=lr) + N = Xt.shape[0] + H_sched = np.unique(np.round(np.geomspace(hmin, hmax, epochs)).astype(int)) + if verbose: + print(f"rollout refine: {N} windows L={X.shape[1]} H {hmin}->{hmax} | " + f"lam tf={lam_tf} spec={lam_spec} env={lam_env} env_ms={env_ms} epochs={epochs}", flush=True) + + history = [] + for epoch in range(epochs): + H = int(H_sched[min(epoch, len(H_sched) - 1)]) + perm = torch.randperm(N) + tot = {"spec": 0.0, "env": 0.0, "tf": 0.0} + nb = 0 + for i in range(0, N, batch_size): + idx = perm[i:i + batch_size] + x, dxd, d2b = Xt[idx], Dt[idx], D2t[idx] + z2 = (model.tau / dt) * dxd + omega, gamma, wk, weights, _ = model.get_funcs(x, dxd.clone(), dt) # differentiable + tf = -(omega ** 2) * x - gamma * z2 - wk + L_tf = ((tf - d2b) ** 2).mean() / var_d2 + + om, ga, w = omega[:, :, 0], gamma[:, :, 0], weights # w: (B,L,P,P) + + def f(xx, vv, k): + xpw = xx.unsqueeze(1) ** powers + xvw = vv.unsqueeze(1) ** powers + kern = torch.einsum("bp,bk,bpk->b", xpw, xvw, w[:, k]) + return vv, -(om[:, k] ** 2) * xx - ga[:, k] * vv - kern + + xc = x[:, 0, 0].detach() + xp = z2[:, 0, 0].detach() + xs = [xc] + for k in range(H - 1): + k1x, k1v = f(xc, xp, k) + k2x, k2v = f(xc + 0.5 * k1x, xp + 0.5 * k1v, k) + k3x, k3v = f(xc + 0.5 * k2x, xp + 0.5 * k2v, k) + k4x, k4v = f(xc + k3x, xp + k3v, k) + xc = xc + (k1x + 2 * k2x + 2 * k3x + k4x) / 6 + xp = xp + (k1v + 2 * k2v + 2 * k3v + k4v) / 6 + xc = BX * torch.tanh(xc / BX) + xp = BXP * torch.tanh(xp / BXP) + xs.append(xc) + xg = torch.stack(xs, dim=1) # (B,H) 
autonomous + tgt = x[:, :H, 0] # (B,H) target, same time span + + L_spec = mrstft_loss(xg, tgt, configs) + L_env = env_loss(xg, tgt, dt, env_ms) + loss = lam_spec * L_spec + lam_env * L_env + lam_tf * L_tf + opt.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), clip) + opt.step() + tot["spec"] += float(L_spec); tot["env"] += float(L_env); tot["tf"] += float(L_tf) + nb += 1 + rec = {"epoch": epoch + 1, "H": H, "spec": tot["spec"] / nb, + "env": tot["env"] / nb, "tf": tot["tf"] / nb} + history.append(rec) + if verbose: + print(f"[refine ep {epoch + 1}/{epochs} H={H}] spec={rec['spec']:.4f} " + f"env={rec['env']:.4f} tf={rec['tf']:.4f}", flush=True) + return history, opt diff --git a/train/train.py b/train/train.py index 524a46c..2f2482b 100644 --- a/train/train.py +++ b/train/train.py @@ -68,6 +68,7 @@ def save_model( "expand_factor": expand_factor, "parameterization": parameterization, "drive_lowpass_ms": getattr(model, "drive_lowpass_ms", 0.0), + "keep_const": getattr(model, "keep_const", False), } try: sd["n_kernel"] = model.kernel.nTerms @@ -150,6 +151,7 @@ def load_model( smooth_len=sd["smooth_len"], kernel=kernel, drive_lowpass_ms=sd.get("drive_lowpass_ms", 0.0), + keep_const=sd.get("keep_const", False), ) except: print("no kernel in savefile!") From f9eb1fd0288f8dea66920ddb3c9df48c564a4963 Mon Sep 17 00:00:00 2001 From: John Pearson Date: Thu, 28 May 2026 21:00:39 -0400 Subject: [PATCH 08/15] gitignore: ignore pipeline verify/full run output dirs Co-Authored-By: Claude Opus 4.7 --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 8aaee63..79eedf2 100644 --- a/.gitignore +++ b/.gitignore @@ -182,3 +182,5 @@ poly_ro_pipeline/ poly_ro_*/ poly_ftspec_env10_*/ poly_attr_*/ +poly_pipeline_verify/ +poly_pipeline_full/ From aa5cdf52c29b4cc48fc496188a8df3826008366b Mon Sep 17 00:00:00 2001 From: John Pearson Date: Fri, 29 May 2026 04:22:16 -0400 Subject: [PATCH 09/15] run_lambda_pipeline: 
record the auto-selected lambda in the manifest when sweeping When --lam<=0 sweeps the grid, derive selected_lambda from the CV results (argmax mean validation autonomy) instead of writing the <=0 sweep sentinel. Co-Authored-By: Claude Opus 4.7 --- examples/run_lambda_pipeline.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/run_lambda_pipeline.py b/examples/run_lambda_pipeline.py index b634e14..2695e4c 100644 --- a/examples/run_lambda_pipeline.py +++ b/examples/run_lambda_pipeline.py @@ -25,6 +25,7 @@ import os import numpy as np +import pandas as pd from scipy.io import wavfile from data.load_data import get_segmented_audio @@ -110,11 +111,18 @@ def main(): # DEPLOYED generation: rescaled autonomous reconstruction of the held-out test vocs. # (Amplitude is a free gauge fixed at generation; selection above already used rescaled autonomy.) if test_vocs: + # actual selected lambda: argmax of mean validation autonomy (= args.lam when fixed; the + # auto-selected grid value when swept, rather than the <=0 sweep sentinel) + sel_lambda = float(args.lam) + csv_path = os.path.join(out_dir, "lambda_seed_cv.csv") + if args.lam <= 0 and os.path.isfile(csv_path): + dfc = pd.read_csv(csv_path) + sel_lambda = float(dfc.groupby("lambda")["val_autonomy"].mean().idxmax()) score, _, bd = autonomy_score(best_model, test_vocs, dt, rescale=True) recon = generate_autonomous(best_model, test_vocs[0], dt, rescale=True) wav = (recon / (np.abs(recon).max() + 1e-12) * 0.95 * 32767).astype(np.int16) wavfile.write(os.path.join(out_dir, "selected_autonomous_recon.wav"), int(round(1 / dt)), wav) - manifest = {"selected_lambda": float(args.lam), "n_seeds": int(args.n_seeds), + manifest = {"selected_lambda": sel_lambda, "n_seeds": int(args.n_seeds), "keep_const": bool(args.keep_const), "drive_lowpass_ms": float(args.drive_lowpass_ms), "rescaled_test_autonomy": score, "spec_corr": bd["spec_corr"], "pitch_pen": bd["pitch_pen"], "bounded_frac": 
bd["bounded_frac"]} From db6200e7f0c8c0c4b9017cb9c0482f62e77dfdc3 Mon Sep 17 00:00:00 2001 From: John Pearson Date: Fri, 29 May 2026 10:18:41 -0400 Subject: [PATCH 10/15] docs: fold full-scale grid run + seed-culling findings into autonomous_amplitude writeup - Validation at scale: 7 lambda x {3,8} seeds x 50ep on the 500-voc set re-confirms seed >> lambda with the production rescaled metric (per-lambda mean +0.548 +-0.0008; seed spread 0.51->0.59); deployed held-out test autonomy +0.63-0.65 (best-of-3/8), bounded, right-pitch. - Seed culling: early-checkpoint rescaled autonomy predicts the final seed ranking. Epoch 10 (20% budget) is too early (rho 0.48, misses the winner); by epoch 20 (40% budget) top-1/top-2 are locked (rho 0.86). Recipe: train all seeds to ~40% budget, keep top 1-2, finish only those. - Adds examples/seed_cull_test.py. Co-Authored-By: Claude Opus 4.7 --- .gitignore | 1 + docs/autonomous_amplitude.md | 30 ++++++++++++ examples/seed_cull_test.py | 95 ++++++++++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+) create mode 100644 examples/seed_cull_test.py diff --git a/.gitignore b/.gitignore index 79eedf2..c62b631 100644 --- a/.gitignore +++ b/.gitignore @@ -184,3 +184,4 @@ poly_ftspec_env10_*/ poly_attr_*/ poly_pipeline_verify/ poly_pipeline_full/ +poly_seedcull/ diff --git a/docs/autonomous_amplitude.md b/docs/autonomous_amplitude.md index 691b288..67d1d83 100644 --- a/docs/autonomous_amplitude.md +++ b/docs/autonomous_amplitude.md @@ -94,6 +94,36 @@ python -m examples.run_lambda_pipeline --data-glob 'data500/gabo_p*' --out-dir . --n-epochs 50 --n-seeds 5 --lam 1.068 --drive-lowpass-ms 1.0 --d-state 4 ``` +## Validation at scale, and seed culling + +**Full grid run** (`examples/run_lambda_pipeline.py`, 7 λ × {3, 8} seeds × 50 ep, the 500-voc `data500` +set, 1 ms low-pass, `keep_const`) re-confirms the diagnosis on the full dataset with the production +*rescaled* metric. 
Rescaled validation autonomy is **flat across λ** (per-λ mean +0.548 ± 0.0008) and +**seed-structured** (seed spread 0.51 → 0.59); the across-seed std (≈ 0.04) is ~20× the across-λ +variation (≈ 0.002), and `amp_pen = 0.00` everywhere (rescaling working). The pipeline selects the best +seed and the deployed (rescaled) reconstruction reaches **test autonomy +0.63–0.65** on the held-out +shard (best-of-3 +0.626, best-of-8 +0.653), spectral corr ≈ 0.68, pitch within ~5%, fully bounded — +i.e. more seeds → a better best, as expected when the seed is the lever. + +**Seed culling — can we pick the winner early?** Since training all seeds fully is the cost, we asked +whether an *early* checkpoint's rescaled validation autonomy predicts the final seed ranking +(`examples/seed_cull_test.py`, 8 seeds at λ = 1.068). Spearman vs the final (epoch-50) ranking: + +| early epoch | Spearman ρ | top-1 hit | top-2 overlap | +|---|---|---|---| +| 10 (20% budget) | 0.48 | ✗ | 1/2 | +| 20 (40% budget) | 0.86 | ✓ | 2/2 | +| 30 | 0.76 | ✓ | 2/2 | +| 40 | 0.98 | ✓ | 2/2 | + +**Epoch 10 is too early** — the epoch-10 leader finished 6th of 8, and the eventual winner was only 2nd +at epoch 10. **By epoch 20 the top-1 and top-2 seeds are already correct** (the mid-pack keeps +reshuffling, so the *full* ranking only settles by ~epoch 40). Practical schedule: **train all seeds to +~40% of the budget, keep the top 1–2 by rescaled validation autonomy, and finish only those** — roughly +halving the seed-search cost. (Caveat: one dataset, one λ, 8 seeds; the 40%-budget threshold is +config-specific. Note teacher-forced R² is useless for this — it is seed-invariant; only the autonomous +rollout discriminates, and there is no cheaper on-orbit surrogate, since amplitude/shape live off-orbit.) + ## Open direction A genuine fix would place an **attracting limit cycle at the data amplitude** — i.e. 
make the data diff --git a/examples/seed_cull_test.py b/examples/seed_cull_test.py new file mode 100644 index 0000000..7349776 --- /dev/null +++ b/examples/seed_cull_test.py @@ -0,0 +1,95 @@ +""" +Does EARLY-checkpoint (rescaled) autonomy predict the FINAL seed ranking? If so, the seed search can +train many seeds briefly, cull to the top, and only finish the winner. + +For each seed dir poly_lam{lam:.3f}_seed{seed}/ with intermediate checkpoints, scores rescaled validation +autonomy (train.eval.autonomy_score(rescale=True) -- the deployed selection metric) at each checkpoint +epoch, then reports, for each early epoch k vs the final epoch: + - Spearman rank correlation across seeds (does early ranking match final ranking?) + - top-1 hit (does the early-best seed == final-best seed?) and top-2 overlap. + + python -m examples.seed_cull_test --run-dir poly_seedcull --lam 1.068 --n-seeds 8 \ + --epochs 10 20 30 40 50 --val-dir data500/gabo_p8 +""" + +import argparse +import glob +import os +import tempfile + +import numpy as np +from scipy.io import wavfile +from scipy.stats import spearmanr + +from train.train import load_model +from train.eval import autonomy_score + + +def load_voc_windows(data_dir, n, off_ms, L): + segs = [] + for wav in sorted(glob.glob(os.path.join(data_dir, "*.wav")))[:n]: + sr, af = wavfile.read(wav) + af = af.astype(np.float64) + on = np.atleast_2d(np.loadtxt(wav.replace(".wav", ".txt")))[0][0] + s = int(on * sr) + int(off_ms / 1e3 * sr) + seg = af[s:s + L] + if len(seg) == L: + segs.append(seg) + return segs, sr + + +def load_ckpt(seed_dir, epoch): + """load a SPECIFIC checkpoint (load_model takes the highest in a dir, so isolate via a temp dir).""" + src = os.path.abspath(os.path.join(seed_dir, f"checkpoint_{epoch}.tar")) + td = tempfile.mkdtemp() + os.symlink(src, os.path.join(td, f"checkpoint_{epoch}.tar")) + model, _, _, _ = load_model(td) + return model + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--run-dir", 
default="poly_seedcull") + p.add_argument("--lam", type=float, default=1.068) + p.add_argument("--n-seeds", type=int, default=8) + p.add_argument("--epochs", type=int, nargs="+", default=[10, 20, 30, 40, 50]) + p.add_argument("--val-dir", default="data500/gabo_p8") + p.add_argument("--n-vocs", type=int, default=5) + p.add_argument("--start-offset-ms", type=float, default=50.0) + p.add_argument("--auto-n", type=int, default=3000) + args = p.parse_args() + + segs, sr = load_voc_windows(args.val_dir, args.n_vocs, args.start_offset_ms, args.auto_n) + print(f"{len(segs)} val vocs from {args.val_dir}; seeds 0..{args.n_seeds-1} at lam={args.lam:.3f}\n", flush=True) + + table = {} # epoch -> np.array of rescaled autonomy per seed + for ep in args.epochs: + row = [] + for s in range(args.n_seeds): + d = os.path.join(args.run_dir, f"poly_lam{args.lam:.3f}_seed{s}") + ck = os.path.join(d, f"checkpoint_{ep}.tar") + if not os.path.isfile(ck): + row.append(np.nan); continue + model = load_ckpt(d, ep); model.eval() + sc, _, _ = autonomy_score(model, segs, model.tau, rescale=True) + row.append(sc) + table[ep] = np.array(row) + print(f"epoch {ep:>3d}: " + " ".join(f"s{i}={v:+.3f}" for i, v in enumerate(row)), flush=True) + + final_ep = args.epochs[-1] + final = table[final_ep] + order_final = np.argsort(-final) + print(f"\nfinal (epoch {final_ep}) seed ranking (best->worst): {list(order_final)}", flush=True) + print(f"\n{'early ep':>8s} {'Spearman rho':>13s} {'top-1 hit':>10s} {'top-2 overlap':>14s}", flush=True) + for ep in args.epochs[:-1]: + early = table[ep] + ok = np.isfinite(early) & np.isfinite(final) + rho, _ = spearmanr(early[ok], final[ok]) + order_early = np.argsort(-np.where(np.isfinite(early), early, -np.inf)) + top1 = order_early[0] == order_final[0] + top2 = len(set(order_early[:2]) & set(order_final[:2])) + print(f"{ep:>8d} {rho:>13.3f} {str(bool(top1)):>10s} {f'{top2}/2':>14s}", flush=True) + + +if __name__ == "__main__": + main() From 
7ae6f283f2ae3c35d3a1c6c2e872a60bae7da8a1 Mon Sep 17 00:00:00 2001 From: John Pearson Date: Fri, 29 May 2026 10:56:52 -0400 Subject: [PATCH 11/15] Add seed culling as a pipeline option (train all briefly, finish only the top seeds) model_cv_lambdas(cull_frac, cull_keep): when cull_frac>0 and selecting by autonomy, train every (lambda,seed) run to cull_frac*n_epochs, rank by rescaled validation autonomy, and finish only the top cull_keep; select the best finished run. The autonomy ranking settles well before R^2 plateaus (epoch ~40% in the 8-seed test: Spearman 0.86, top-1/top-2 correct; epoch 20% is too early), so this roughly halves the seed search. cull_frac=0 (default) preserves the existing full-train behavior. Refactors the per-run train/score into helpers; exposes the selected lambda on the returned model. run_lambda_pipeline: --cull-frac/--cull-keep flags; manifest uses the authoritative selected lambda. Documents the option in docs/autonomous_amplitude.md. Co-Authored-By: Claude Opus 4.7 --- docs/autonomous_amplitude.md | 10 +- examples/run_lambda_pipeline.py | 24 +++-- train/model_cv.py | 183 +++++++++++++++++++------------- 3 files changed, 135 insertions(+), 82 deletions(-) diff --git a/docs/autonomous_amplitude.md b/docs/autonomous_amplitude.md index 67d1d83..3a54af5 100644 --- a/docs/autonomous_amplitude.md +++ b/docs/autonomous_amplitude.md @@ -88,10 +88,15 @@ Both are folded into the pipeline: keep_const=True, lambdas=[λ])` — multi-seed selection by rescaled validation autonomy. - `examples/run_lambda_pipeline.py` — end-to-end: file-level holdout, seed selection, and a rescaled autonomous reconstruction (`selected_autonomous_recon.wav` + `selected_model.json`). +- **Seed culling** (`model_cv_lambdas(cull_frac=…, cull_keep=…)`, exposed as `--cull-frac/--cull-keep`): + train all seeds to `cull_frac` of the budget, finish only the top `cull_keep` by rescaled validation + autonomy — see below. 
``` python -m examples.run_lambda_pipeline --data-glob 'data500/gabo_p*' --out-dir ./poly_pipeline \ --n-epochs 50 --n-seeds 5 --lam 1.068 --drive-lowpass-ms 1.0 --d-state 4 +# with seed culling — train 8 seeds, finish the best 2 after 40% of the budget: +python -m examples.run_lambda_pipeline --n-seeds 8 --cull-frac 0.4 --cull-keep 2 ``` ## Validation at scale, and seed culling @@ -120,8 +125,9 @@ whether an *early* checkpoint's rescaled validation autonomy predicts the final at epoch 10. **By epoch 20 the top-1 and top-2 seeds are already correct** (the mid-pack keeps reshuffling, so the *full* ranking only settles by ~epoch 40). Practical schedule: **train all seeds to ~40% of the budget, keep the top 1–2 by rescaled validation autonomy, and finish only those** — roughly -halving the seed-search cost. (Caveat: one dataset, one λ, 8 seeds; the 40%-budget threshold is -config-specific. Note teacher-forced R² is useless for this — it is seed-invariant; only the autonomous +halving the seed-search cost. This is built into the pipeline (`--cull-frac 0.4 --cull-keep 2`, i.e. +`model_cv_lambdas(cull_frac=0.4, cull_keep=2)`). (Caveat: one dataset, one λ, 8 seeds; the 40%-budget +threshold is config-specific. Note teacher-forced R² is useless for this — it is seed-invariant; only the autonomous rollout discriminates, and there is no cheaper on-orbit surrogate, since amplitude/shape live off-orbit.) ## Open direction diff --git a/examples/run_lambda_pipeline.py b/examples/run_lambda_pipeline.py index 2695e4c..5fa2200 100644 --- a/examples/run_lambda_pipeline.py +++ b/examples/run_lambda_pipeline.py @@ -12,11 +12,16 @@ - validation autonomy vocs from a held-out shard (seed selection), - test autonomy vocs from another held-out shard (final report + a rescaled reconstruction wav). -Pass --lam <=0 to instead sweep the standard 7-point lambda grid. +Pass --lam <=0 to instead sweep the standard 7-point lambda grid. Pass --cull-frac (e.g. 
0.4) to +enable SEED CULLING: train all seeds to that fraction of the budget, then finish only the top +--cull-keep by rescaled validation autonomy (~halves the seed-search cost; see docs). Run from the repo root: python -m examples.run_lambda_pipeline --data-glob 'data500/gabo_p*' --out-dir ./poly_pipeline \ --n-epochs 50 --n-seeds 5 --lam 1.068 --drive-lowpass-ms 1.0 --d-state 4 + # with seed culling (train 8 seeds, finish the best 2 after 40% of epochs): + python -m examples.run_lambda_pipeline --data-glob 'data500/gabo_p*' --n-seeds 8 \ + --cull-frac 0.4 --cull-keep 2 """ import argparse @@ -25,7 +30,6 @@ import os import numpy as np -import pandas as pd from scipy.io import wavfile from data.load_data import get_segmented_audio @@ -54,6 +58,12 @@ def main(): p.add_argument("--out-dir", default="poly_pipeline") p.add_argument("--n-epochs", type=int, default=50) p.add_argument("--n-seeds", type=int, default=5) + p.add_argument("--cull-frac", type=float, default=0.0, + help="seed culling: train all runs to this fraction of n-epochs, then finish only " + "the top --cull-keep by rescaled validation autonomy (0 = train all fully). " + "~0.4 is a good default; the autonomy ranking settles by ~40%% of the budget.") + p.add_argument("--cull-keep", type=int, default=2, + help="number of top runs to finish when culling") p.add_argument("--lam", type=float, default=1.068, help="fixed kernel-weight lambda (lambda is irrelevant for autonomy; the SEED " "is what matters, so we fix lambda and select over seeds). Pass <=0 to " @@ -106,18 +116,14 @@ def main(): selection="autonomy", val_vocs=val_vocs, test_vocs=test_vocs, keep_const=args.keep_const, rescale_autonomy=True, lambdas=None if args.lam <= 0 else [args.lam], + cull_frac=args.cull_frac, cull_keep=args.cull_keep, ) # DEPLOYED generation: rescaled autonomous reconstruction of the held-out test vocs. # (Amplitude is a free gauge fixed at generation; selection above already used rescaled autonomy.) 
if test_vocs: - # actual selected lambda: argmax of mean validation autonomy (= args.lam when fixed; the - # auto-selected grid value when swept, rather than the <=0 sweep sentinel) - sel_lambda = float(args.lam) - csv_path = os.path.join(out_dir, "lambda_seed_cv.csv") - if args.lam <= 0 and os.path.isfile(csv_path): - dfc = pd.read_csv(csv_path) - sel_lambda = float(dfc.groupby("lambda")["val_autonomy"].mean().idxmax()) + # authoritative selected lambda from the selector (handles fixed-lambda, swept-grid, and culled) + sel_lambda = float(getattr(best_model, "_selected_lambda", args.lam)) score, _, bd = autonomy_score(best_model, test_vocs, dt, rescale=True) recon = generate_autonomous(best_model, test_vocs[0], dt, rescale=True) wav = (recon / (np.abs(recon).max() + 1e-12) * 0.95 * 32767).astype(np.int16) diff --git a/train/model_cv.py b/train/model_cv.py index 9e52007..f021e19 100644 --- a/train/model_cv.py +++ b/train/model_cv.py @@ -40,6 +40,8 @@ def model_cv_lambdas( keep_const: bool = False, rescale_autonomy: bool = False, lambdas: list = None, + cull_frac: float = 0.0, + cull_keep: int = 2, ) -> torch.nn.Module: """ This function trains models and cross-validates across regularization strengths. @@ -48,6 +50,12 @@ def model_cv_lambdas( on the most complex term is 10**4). Saves these in a larger folder, alongside train stats, training plots, etc. + Seed culling (cull_frac>0, selection='autonomy'): the autonomous-reconstruction quality is set by + the random seed and its ranking settles well before R^2 plateaus, so instead of training every + run fully, train all runs to `cull_frac` of n_epochs, rank by rescaled validation autonomy, and + finish only the top `cull_keep`. The best finished run is selected. Roughly halves the seed search + (see docs/autonomous_amplitude.md). With cull_frac=0 (default) every run is trained fully. 
+ inputs ----- - dls: dictionary of dataloaders, train and test @@ -91,85 +99,117 @@ def model_cv_lambdas( "selection='autonomy' requires val_vocs (held-out vocalization segments)" ) - # train n_seeds models per lambda; record the validation metric(s) for each. Resumable: - # a (lambda, seed) with a complete checkpoint is loaded instead of retrained. - records = [] - for lam in lambdas: - for seed in range(n_seeds): - torch.manual_seed(seed) - np.random.seed(seed) - kernel = fullPolyModule( - nTerms=n_kernels, device="cuda", x_dim=1, z_dim=2, - activation=lambda x: x, lam=lam, - ) - model = Ouroboros( - d_data=1, n_layers=n_layers, d_state=d_state, d_conv=d_conv, - expand_factor=expand_factor, tau=tau, smooth_len=smooth_len, - kernel=kernel, drive_lowpass_ms=drive_lowpass_ms, keep_const=keep_const, - ) - opt = Adam(model.parameters(), lr=lr) - sched = ReduceLROnPlateau( - opt, factor=0.5, patience=max(n_epochs // 25, 2), min_lr=1e-10 - ) - run_dir = os.path.join(model_path, f"poly_lam{lam:.3f}_seed{seed}") - save_loc = os.path.join(run_dir, f"checkpoint_{n_epochs}.tar") + # --- per-(lambda, seed) training, resumable (a run with a checkpoint >= target is not retrained) --- + def _train_to(lam, seed, target): + """build/resume (lam, seed) and train to `target` epochs; return (model, run_dir).""" + torch.manual_seed(seed) + np.random.seed(seed) + kernel = fullPolyModule(nTerms=n_kernels, device="cuda", x_dim=1, z_dim=2, + activation=lambda x: x, lam=lam) + model = Ouroboros(d_data=1, n_layers=n_layers, d_state=d_state, d_conv=d_conv, + expand_factor=expand_factor, tau=tau, smooth_len=smooth_len, + kernel=kernel, drive_lowpass_ms=drive_lowpass_ms, keep_const=keep_const) + opt = Adam(model.parameters(), lr=lr) + sched = ReduceLROnPlateau(opt, factor=0.5, patience=max(n_epochs // 25, 2), min_lr=1e-10) + run_dir = os.path.join(model_path, f"poly_lam{lam:.3f}_seed{seed}") + start_epoch = 0 + if glob.glob(os.path.join(run_dir, "*.tar")): + model, opt, sched, start_epoch = 
load_model(run_dir) # resume + model.kernel.lam = float(lam) # load_model hardcodes lam=1; restore it + if start_epoch < target: + train(model, opt, loss_fn=lambda y, yhat: sse(yhat, y, reduction="mean"), + loaders=dls, scheduler=sched, nEpochs=target, val_freq=1, runDir=run_dir, + dt=dt, vis_freq=0, smoothing=False, reg_weights=True, start_epoch=start_epoch, + save_freq=save_freq, model_info=model_info) + save_model(model, opt, os.path.join(run_dir, f"checkpoint_{target}.tar"), + n_layers=n_layers, d_state=d_state, expand_factor=expand_factor, d_conv=d_conv) + return model, run_dir - start_epoch = 0 - if glob.glob(os.path.join(run_dir, "*.tar")): - model, opt, sched, start_epoch = load_model(run_dir) # resume - model.kernel.lam = float(lam) # load_model hardcodes lam=1; restore it - if start_epoch < n_epochs: - train( - model, opt, - loss_fn=lambda y, yhat: sse(yhat, y, reduction="mean"), - loaders=dls, scheduler=sched, nEpochs=n_epochs, val_freq=1, - runDir=run_dir, dt=dt, vis_freq=0, smoothing=False, - reg_weights=True, start_epoch=start_epoch, save_freq=save_freq, - model_info=model_info, - ) - save_model(model, opt, save_loc, n_layers=n_layers, d_state=d_state, - expand_factor=expand_factor, d_conv=d_conv) + def _score(model): + model.eval() + with torch.no_grad(): + (_, vr2), _, _ = eval_model_error(dls, model, dt=dt, comparison="val") + if selection == "autonomy": + va, _, bd = autonomy_score(model, val_vocs, dt, rescale=rescale_autonomy) + else: + va, bd = np.nan, None + return vr2, va, bd - model.eval() - with torch.no_grad(): - (_, val_r2), _, _ = eval_model_error(dls, model, dt=dt, comparison="val") - val_auto = np.nan - if selection == "autonomy": - val_auto, _, bd = autonomy_score(model, val_vocs, dt, rescale=rescale_autonomy) - print(f"lam={lam:.3f} seed={seed}: val R2={val_r2:.4f} val autonomy={val_auto:+.4f} " - f"(spec={bd['spec_corr']:.2f} amp_pen={bd['amp_pen']:.2f} " - f"pitch_pen={bd['pitch_pen']:.2f} bounded={bd['bounded_frac']:.2f})", 
flush=True) - else: - print(f"lam={lam:.3f} seed={seed}: val R2={val_r2:.4f}", flush=True) - records.append({"lambda": lam, "seed": seed, "val_r2": val_r2, - "val_autonomy": val_auto, "ckpt": run_dir}) - del model, opt, kernel - gc.collect() - torch.cuda.empty_cache() + n_jobs = n_seeds * len(lambdas) + do_cull = (0.0 < cull_frac < 1.0) and selection == "autonomy" and cull_keep < n_jobs + records = [] + if do_cull: + # SEED CULLING: train every run to a cull epoch, rank by rescaled val autonomy, then finish + # only the top `cull_keep`. The autonomy ranking is ~settled well before R^2 plateaus + # (see docs/autonomous_amplitude.md), so this finds the best seed at a fraction of the cost. + cull_epoch = max(1, int(round(cull_frac * n_epochs))) + print(f"\n=== seed culling: train all {n_jobs} run(s) to epoch {cull_epoch} " + f"({cull_frac:.0%} of {n_epochs}), then finish top {cull_keep} by rescaled val autonomy ===", + flush=True) + ranked = [] + for lam in lambdas: + for seed in range(n_seeds): + model, run_dir = _train_to(lam, seed, cull_epoch) + _, va, _ = _score(model) + ranked.append((va, lam, seed)) + print(f" [cull@{cull_epoch}] lam={lam:.3f} seed={seed}: val autonomy={va:+.4f}", flush=True) + del model; gc.collect(); torch.cuda.empty_cache() + ranked.sort(key=lambda r: (r[0] if np.isfinite(r[0]) else -np.inf), reverse=True) + keep, culled = ranked[:cull_keep], ranked[cull_keep:] + print(" keep: " + ", ".join(f"lam{l:.3f}/seed{s}({a:+.3f})" for a, l, s in keep), flush=True) + print(" cull: " + ", ".join(f"lam{l:.3f}/seed{s}({a:+.3f})" for a, l, s in culled), flush=True) + for _, lam, seed in keep: + model, run_dir = _train_to(lam, seed, n_epochs) + vr2, va, _ = _score(model) + print(f" [final] lam={lam:.3f} seed={seed}: val R2={vr2:.4f} val autonomy={va:+.4f}", flush=True) + records.append({"lambda": lam, "seed": seed, "val_r2": vr2, "val_autonomy": va, "ckpt": run_dir}) + del model; gc.collect(); torch.cuda.empty_cache() + else: + for lam in lambdas: + for 
seed in range(n_seeds): + model, run_dir = _train_to(lam, seed, n_epochs) + vr2, va, bd = _score(model) + if selection == "autonomy": + print(f"lam={lam:.3f} seed={seed}: val R2={vr2:.4f} val autonomy={va:+.4f} " + f"(spec={bd['spec_corr']:.2f} amp_pen={bd['amp_pen']:.2f} " + f"pitch_pen={bd['pitch_pen']:.2f} bounded={bd['bounded_frac']:.2f})", flush=True) + else: + print(f"lam={lam:.3f} seed={seed}: val R2={vr2:.4f}", flush=True) + records.append({"lambda": lam, "seed": seed, "val_r2": vr2, "val_autonomy": va, "ckpt": run_dir}) + del model; gc.collect(); torch.cuda.empty_cache() df = pd.DataFrame(records) metric = "val_autonomy" if selection == "autonomy" else "val_r2" - per_lam = df.groupby("lambda")[metric].agg(["mean", "std"]) - best_lambda = float(per_lam["mean"].idxmax()) - print(f"\n=== lambda selection by {selection} (mean over {n_seeds} seed(s)) ===", flush=True) - for lam, row in per_lam.iterrows(): - mark = " <-- selected" if abs(lam - best_lambda) < 1e-9 else "" - sd = 0.0 if np.isnan(row["std"]) else row["std"] - print(f" lambda={lam:.3f}: {metric}={row['mean']:+.4f} +- {sd:.4f}{mark}", flush=True) df.to_csv(os.path.join(model_path, "lambda_seed_cv.csv"), index=False) - plt.figure() - plt.errorbar(per_lam.index, per_lam["mean"], per_lam["std"].fillna(0.0), marker="o") - plt.axvline(best_lambda, ls="--", color="0.6") - plt.xlabel(r"kernel-weight $\lambda$") - plt.ylabel(f"validation {metric}") - plt.title(f"lambda selection by {selection} ({n_seeds} seeds)") - plt.savefig(os.path.join(model_path, f"lambda_selection_{selection}.svg")) - plt.close() + if do_cull: + # only the kept runs are finished, so select the single best finished run (a per-lambda + # mean is not meaningful when most seeds were culled). 
+ best_row = df.sort_values(metric).iloc[-1] + best_lambda = float(best_row["lambda"]) + best_ckpt = best_row["ckpt"] + print(f"\n=== culled selection: best finished run lam={best_lambda:.3f} " + f"seed={int(best_row['seed'])} {metric}={best_row[metric]:+.4f} ===", flush=True) + else: + per_lam = df.groupby("lambda")[metric].agg(["mean", "std"]) + best_lambda = float(per_lam["mean"].idxmax()) + print(f"\n=== lambda selection by {selection} (mean over {n_seeds} seed(s)) ===", flush=True) + for lam, row in per_lam.iterrows(): + mark = " <-- selected" if abs(lam - best_lambda) < 1e-9 else "" + sd = 0.0 if np.isnan(row["std"]) else row["std"] + print(f" lambda={lam:.3f}: {metric}={row['mean']:+.4f} +- {sd:.4f}{mark}", flush=True) + plt.figure() + plt.errorbar(per_lam.index, per_lam["mean"], per_lam["std"].fillna(0.0), marker="o") + plt.axvline(best_lambda, ls="--", color="0.6") + plt.xlabel(r"kernel-weight $\lambda$") + plt.ylabel(f"validation {metric}") + plt.title(f"lambda selection by {selection} ({n_seeds} seeds)") + plt.savefig(os.path.join(model_path, f"lambda_selection_{selection}.svg")) + plt.close() + # best model = best seed at the selected lambda (by the same validation metric) + best_ckpt = df[np.isclose(df["lambda"], best_lambda)].sort_values(metric).iloc[-1]["ckpt"] - # best model = best seed at the selected lambda (by the same validation metric) - cand = df[np.isclose(df["lambda"], best_lambda)].sort_values(metric) - best_model, _, _, _ = load_model(cand.iloc[-1]["ckpt"]) + best_model, _, _, _ = load_model(best_ckpt) best_model.eval() with torch.no_grad(): (_, test_r2), _, _ = eval_model_error(dls, best_model, dt=dt, comparison="test") @@ -178,6 +218,7 @@ def model_cv_lambdas( test_auto, _, _ = autonomy_score(best_model, test_vocs, dt, rescale=rescale_autonomy) print(f"BEST lambda={best_lambda:.3f}: test R2={test_r2:.4f} test autonomy={test_auto:+.4f}", flush=True) + best_model._selected_lambda = best_lambda # authoritative selected lambda for 
callers/manifests return best_model From 0199378e8e270f213bfb763e1ef8aed36a7b6596 Mon Sep 17 00:00:00 2001 From: John Pearson Date: Sat, 30 May 2026 17:57:33 -0400 Subject: [PATCH 12/15] docs+diag: small-k rollout plan and per-voc amplitude scan scripts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit docs/small_k_rollout_plan.md: design doc for the next attempt at the cold-start free-amplitude problem (see docs/autonomous_amplitude.md). Two complementary changes: (1) small-k rollout consistency in the train loss (utils.euler_step_k, k << one carrier cycle), (2) selection by cold-start raw autonomy (no rescale, IC at silence). Includes the experimental matrix (4 rows × 16 seeds) and the AGC runtime fallback (Appendix A.3) for if the trained model can't carry amplitude on its own. examples/scan_seed_amp_{coldstart,stability}.py: per-voc rescale-factor scans across the existing 8-seed pool. coldstart variant integrates from silence IC on lead-in+voc windows (the situation finchsim deployment faces); stability variant uses mid-voc start. Both compute std-ratio rescale, CV, max/min, and spec_corr. Referenced by §6 of the plan as the binding-criterion diagnostic. Co-Authored-By: Claude Opus 4.7 --- docs/small_k_rollout_plan.md | 363 ++++++++++++++++++++++++++++ examples/scan_seed_amp_coldstart.py | 117 +++++++++ examples/scan_seed_amp_stability.py | 119 +++++++++ 3 files changed, 599 insertions(+) create mode 100644 docs/small_k_rollout_plan.md create mode 100644 examples/scan_seed_amp_coldstart.py create mode 100644 examples/scan_seed_amp_stability.py diff --git a/docs/small_k_rollout_plan.md b/docs/small_k_rollout_plan.md new file mode 100644 index 0000000..0633817 --- /dev/null +++ b/docs/small_k_rollout_plan.md @@ -0,0 +1,363 @@ +# Plan: small-k rollout in training + cold-start raw-autonomy seed selection + +**Status:** Draft handoff. 
The current ouroboros pipeline trains with a single-step +ẍ-prediction loss and selects seeds by **rescaled** (amplitude-gauge-removed) autonomy. +That combination has been shown — by experiment in this branch (see +`docs/autonomous_amplitude.md`) and by direct measurement in finchsim — to leave +autonomous amplitude marginal in a way that breaks the deployed downstream synthesis. +This plan proposes the two complementary changes that line up with the unresolved +failure mode, and a concrete experimental matrix. + +The complementary work in `finchsim` is described at the end (Appendix B): the synthesis +side is already implemented and validated. What is missing is a model whose drives carry +amplitude information from a cold start; this plan is how to get one. + +--- + +## 1. Why this exists + +The deployed pipeline is `finchsim`'s **brainstem → adapter → ouroboros control signals +→ synthesis ODE → audio**. The synthesis ODE is the full driven polynomial form + +``` +dy/ds = v +dv/ds = -ω(t)² y - γ(t) v - Σ_{p,k} W[p,k] y^p v^k (W[1,0]=W[0,1]=0; W[0,0] kept) +``` + +stepped with fixed-step RK4 in rescaled time `s = t/τ`, **starting from rest** at +`y(0) ≈ 0`. The drives ω, γ, W come from a calibrated linear adapter mapping brainstem +LP-filtered spike rates to ouroboros control channels. + +The problem (proven empirically in May 2026 across the existing 8-seed cull at λ=1.068): +**the raw autonomous amplitude of the trained model is wildly voc-dependent from a cold +start.** Concretely, with the seed currently in deployment (`poly_seedcull/poly_lam1.068_seed6`): + +- Cold-start raw rescale factor `r(v) = std(correct(target)) / std(correct(raw_auto))` + has **range 633× across 40 vocs**, median 32, CV 151% + (`examples/scan_seed_amp_coldstart.py`). +- Other seeds in the pool either blow up cold-start (Λ>0, seeds 2/3/4) or are similarly + wild (seeds 0/1/5/6) — except seed 7 at CV 53% / max-min 48×, still not deployable. 
+- With the seed's true drives fed in, the finchsim integration test produces + `std=0.0000` (silence): the deterministic poly ODE decays from rest, and the + raw amplitude is too voc-variable for a single shipped rescale constant to fix. + +The Floquet diagnostic in `docs/autonomous_amplitude.md` already explains why: the +trained model is **near-neutral** (Λ/cycle ≈ 0), so amplitude is set by the IC + the +drive trajectory **per vocalization**, not by a contracting attractor — and the current +training objective doesn't pressure the drives to encode amplitude (one-step ẍ matching +on the limit cycle is invariant to it). + +## 2. What's been tried, and what hasn't + +See `docs/autonomous_amplitude.md` for the full table. Briefly, what failed and what +the failure means for the design of the next attempt: + +| Approach | Horizon | Outcome | Why it failed | +|---|---|---|---| +| Pointwise rollout MSE (post-hoc FT) | 80–400 samples | Autonomy worse; amp_pen 0.47→1.59 | One carrier cycle ≈ 16 samples; phase drift past that makes pointwise MSE reward amplitude collapse. | +| Spectral + envelope rollout FT (`rollout_refine`) | 768–1500 | Seed-variable; mean down, var up | Phase-invariant but BPTT through unstable rollout, gradient pathology. | +| Λ-penalty (`floq`) | – | Decays to wrong attractor | Stability **type** ≠ attractor **location**. | +| Short noise FT | 12–24 | Λ→+30 blow-up | BPTT through expanding rollout. | +| Selection by **rescaled** autonomy | – | Picks Λ≈0 seed → cold-start unstable | Removes amplitude from the criterion; lets drives stay amp-uninformative. | + +**Not yet tried and theoretically well-placed:** +1. **k-step rollout *in training* (not as FT)** with **very small k**, k ≪ one carrier + cycle, so pointwise (y, dy) MSE is a valid signal **and** the model never sees the + long-horizon BPTT pathology. 
The primitive (`utils.euler_step_k`) is already in the + repo but **not wired into the training loss** — it's a relic of the predecessor + `mdmarti/ouro_clean`. +2. **Selection on cold-start raw autonomy** (no rescale, IC at silence start), which + makes the selection criterion actually test what the deployed pipeline needs: + drives carry amplitude from t=0. + +## 3. Proposal + +### 3.1 Training-side change: small-k rollout consistency in the loss + +Wire `utils.euler_step_k(y, dy, d2y, dt, k)` into the train loop as an additional +loss term, summed with the existing one-step ẍ-prediction MSE: + +``` +L = L_one_step + λ_k · L_k_step +L_k_step = MSE(y, ŷ_k) + MSE(dy, d̂y_k) # pointwise over all (B, L−k, ·) +``` + +`euler_step_k` returns `((y_out, yhat_out), (dy_out, dyhat_out))`, stacking the k +intermediate predictions on the last dim, so summing over k inside the MSE is "all +intermediate steps" — the default and the desired pressure (consistency at *every* +step from 1..k, not just the kth). + +**Phase drift consideration.** At sr=40 kHz and a typical syllable carrier near +~2.5 kHz, one cycle is ~16 samples. Choose k ≪ 16. The matrix below tests k ∈ {2, 4, 8}. + +**Curriculum?** Probably no curriculum is needed at these horizons; the rollout is +short enough that gradients are well-conditioned from epoch 1. Start with **fixed k**; +if loss balance turns out fragile, fall back to a short curriculum (k=2 for first ~10 +epochs, then k=target). + +**Loss weight.** Suggest sweeping λ_k ∈ {0.1, 0.3, 1.0}. The one-step loss must remain +the dominant signal; the k-step term is a consistency *pressure*, not a replacement. + +### 3.2 Selection-side change: cold-start raw autonomy + +Extend `train.eval.autonomy_score` (or write a sibling) to score from **cold start**: + +- IC at the start of the silence lead-in: `y(0) = audio[0]` (≈ 0), `v(0) = (τ/dt)·dxdt[0]`. + This is the convention used by `train.eval.integrate_poly_autonomous` already. 
+- **Window = silence_pad + vocalization** (the `make_paired_data_v2.py` convention, + with SILENCE_PAD=2000 for 50 ms lead-in at 40 kHz). +- **No rescaling** (`rescale=False`), so amplitude is part of the score. +- Score = `spec_corr − w_amp·|log(std_auto/std_tgt)| − w_pitch·|log(pitch_auto/pitch_tgt)|`, + diverge → `diverge_score`. + +This is exactly the form of `autonomy_score(..., rescale=False)`; the change is +**windowing** (full window incl. lead-in, not mid-voc), and using this metric in +`train.model_cv.model_cv_lambdas(selection="autonomy")` as the cull criterion. + +The diagnostic at `examples/scan_seed_amp_coldstart.py` already implements the +windowing and IC; the metric just needs to be folded into the pipeline alongside it. + +### 3.3 Seed-pool change + +Train **≥16 fresh seeds at λ=1.068** (the existing seedcull's chosen λ; this is +known irrelevant for autonomy, so don't bother sweeping). Compute cold-start raw +autonomy at epoch 20 and at epoch 50. + +Per `examples/seed_cull_test.py`, **rankings stabilize by epoch 20** for the +rescaled metric (Spearman ρ ≈ 0.86). Re-run that test on the **cold-start raw** +metric to validate the same early-cull schedule (it may or may not be as stable); +default to: train all to ep 20, cull to top-3, finish those to ep 50. + +## 4. Concrete steps + +### Step 1 — Wire `euler_step_k` into the train loop + +**Files:** +- `utils.py:45` — `euler_step_k(y, dy, d2y, dt, k=1)`, already returns the (y, dy) + ground truth + predictions for steps 1..k. +- `train/train.py` — current loss is the one-step ẍ MSE inside the train step. Add + the k-step term as a sibling. The model already outputs `yhat` (= τ²·d²y/dt²), so + compute `d2y_phys = yhat / τ²` and pass to `euler_step_k`. + +Add CLI: +- `--k-rollout` (int, default 0 = off; ≥2 enables) +- `--lambda-k` (float, default 0.3) + +Don't change the default behavior — keep one-step-only as the default so the existing +pipeline is undisturbed. 
+ +### Step 2 — Add cold-start raw autonomy as a selection metric + +**Files:** +- `train/eval.py` — `autonomy_score` already takes `rescale: bool`. Add a + parameter `cold_start: bool = False` that, when True, expects segments to **include + the silence lead-in** and does NOT trim — uses the supplied window as-is. (The + current convention is mid-voc start +50 ms.) +- `train/model_cv.py` — `model_cv_lambdas(..., rescale_autonomy=False, + cold_start_autonomy=True, val_vocs=..., ...)`. Pass-through. + +Score signature stays the same: `(score, per_seg_scores, breakdown)`. + +### Step 3 — Multi-seed training run + +**Files:** +- `examples/run_lambda_pipeline.py` — already supports + `--n-seeds`, `--lam`, `--keep-const`, `--drive-lowpass-ms`. Add `--k-rollout + --lambda-k --cold-start-selection`. With `--lam 1.068 --n-seeds 16`, this trains + 16 seeds and selects by the chosen metric. +- Train to epoch 20 first (`--n-epochs 20`); save checkpoints at 10 and 20. +- Score all 16 at epoch 20 on cold-start raw autonomy over the val shard + (data500/gabo_p8). Cull to top-3. +- Resume the top-3 to epoch 50 (`--n-epochs 50`, the run is resumable per λ-seed + per `train.model_cv`). + +### Step 4 — Final evaluation + +For each of the top-3 finished seeds: +- Held-out cold-start raw autonomy on data500/gabo_p9 (per-voc spread; want CV<30%). +- Per-voc rescale factor distribution (`examples/scan_seed_amp_coldstart.py`). +- spec_corr, pitch_pen, bounded_frac. + +Pick the **winner** as the seed with the best joint criterion. Ship it. + +### Step 5 — Validate in finchsim + +(See Appendix B for current finchsim state.) +- Re-export paired data with the winning seed + (`~/ouroboros_smoke/make_calib_data.py --ckpt `). +- Re-run the adapter pipeline: + `scripts/gen_adapter_features.py` → `scripts/fit_adapter.py` (the per-pool muscle + filter at 3 ms/40 ms is already wired up). +- Re-run `scripts/simulate_hvc_to_audio.py`. 
**Success criterion: generated audio is + not silent and has energy in the 1–8 kHz band** (see `outputs/hvc_to_audio_comparison.png`). +- If yes: ship the single shipped rescale constant (median of the winning seed's + per-voc factors) with the model. No runtime AGC, no gamma-centering. +- If no: see §6. + +## 5. Experimental matrix + +Each row = one full pipeline (train 16 seeds × cull → 3 → finish → score). +Total ~16 trainings × 4 rows = 64 model fits. + +| run | k_rollout | λ_k | selection | rationale | +|---|---|---|---|---| +| baseline | off | – | cold-start raw | isolate selection-only contribution | +| k2 | 2 | 0.3 | cold-start raw | minimal multi-step pressure | +| k4 | 4 | 0.3 | cold-start raw | quarter-cycle horizon | +| k8 | 8 | 0.3 | cold-start raw | half-cycle horizon; near phase-drift onset | + +If `baseline` alone clears success criteria, the selection change was enough. If `k8` +clears it but smaller k don't, that's evidence the multi-step pressure was needed. +Sweep λ_k only if a k value looks promising but is dominated by the one-step loss. + +**Compute estimate.** At λ=1.068, d_state=4, 50 vocs, 1 ms low-pass, 50 epochs: +~15–30 min per seed depending on GPU contention (per `poly-lowpass-autonomous` +memory). 64 seeds × 25 min = ~26 GPU-hours. With the ep20 → ep50 cull, halve to +~13 hours. Parallelize across rows if multiple GPUs are available. + +## 6. 
Success criteria, failure plan + +**Per seed (held-out, cold-start, 10+ vocs from `data500/gabo_p9` with 50 ms lead-in):** +- bounded_frac = 1.0 (no NaN/Inf during integration) +- spec_corr ≥ 0.55 (the current pool's best cold-start spec_corr is seed 0 at 0.736; + 0.55 is a modest bar) +- pitch_pen ≤ 0.10 (≈ 10% pitch error) +- **per-voc rescale-factor CV ≤ 30%** (this is the binding constraint; + the current best — seed 7 — is at 53%) +- amp_pen ≤ 0.5 (`|log(std_auto/std_tgt)|` ≤ 0.5 ⇔ within 1.65× of correct + amplitude with a single shipped constant) + +**Deployment criterion (finchsim integration test):** +- `simulate_hvc_to_audio.py` produces generated audio with `std > 1e-4` and + spectral energy concentrated in 1–8 kHz (per the existing comparison plot). + +**Failure plan.** If no seed across all 4 experimental rows meets the per-voc CV +criterion: +1. This is empirical evidence the marginal-amplitude problem is **structural** beyond + selection + small-k pressure. Document the result in + `docs/autonomous_amplitude.md`. +2. Fall back to runtime AGC in `finchsim/ouroboros_ode.py` (see §A.3 in + Appendix B), shipping seed 7 (best amp stability we have). + +## 7. Pitfalls to avoid (from prior attempts) + +- **Don't** use horizons that span or exceed one carrier cycle (~16 samples at this + sr). Phase drift past that makes pointwise MSE pathological — it rewards amplitude + collapse to minimize the phase-drifted squared error. +- **Don't** use rollout as a post-hoc FT on a converged TF model with these dynamics. + The model is near-marginal; BPTT through an expanding rollout has exploding + gradients, and large injected noise overrides to over-contraction. Rollout pressure + belongs *during* TF training, with short k. +- **Don't** pursue Λ-penalty fixes — they change stability *type*, not attractor + *location* (see "Both Λ-targeting fixes FAIL" in `docs/autonomous_amplitude.md`). 
+- **Don't** select by rescaled autonomy if the deployment target is real-time +cold-start synthesis. That's the wrong metric for the wrong endpoint. + +## 8. File pointers + +- **Primitive (existing):** `utils.py:45` — `euler_step_k`. +- **Train loop:** `train/train.py` — where the single-step ẍ MSE lives. +- **Score:** `train/eval.py` — `autonomy_score`, `integrate_poly_autonomous`, + `generate_autonomous`. +- **Selection pipeline:** `train/model_cv.py` — `model_cv_lambdas` (multi-seed, + selection metric pass-through). +- **Entrypoint:** `examples/run_lambda_pipeline.py`. +- **Seed cull primitive:** `examples/seed_cull_test.py` (currently uses rescaled + metric; reuse the early-checkpoint timing analysis with the cold-start metric). +- **Diagnostics (already in repo as of this draft):** + - `examples/scan_seed_amp_stability.py` — per-voc rescale CV, mid-voc. + - `examples/scan_seed_amp_coldstart.py` — per-voc rescale CV, cold start. +- **Background docs:** + - `docs/autonomous_amplitude.md` — the full free-amplitude analysis + Floquet diagnostic. + - `docs/ode_solver_advice.md` (in finchsim) — DC drift / streaming integrator design. + - `docs/RA_to_ouroboros_report.md` (in finchsim) — biology + adapter design. +- **Project memory:** + `~/.claude/projects/-home-pearson-code-ouroboros/memory/poly-lowpass-autonomous.md` — + detailed history of the rescale + rollout fine-tune attempts in this branch. + +## 9. Git hygiene + +- Branch off the current ouroboros HEAD (currently `arneodo-parameterization` or + whatever's current — check `git log -1`). +- Don't merge into main until the finchsim integration test (§4 Step 5; deployment criterion in §6) passes. +- Commit messages should end with `Co-Authored-By: Claude Opus 4.7 `. + +--- + +## Appendix A: known good values (for sanity-checking new seeds) + +From the existing 8-seed pool at λ=1.068 (`poly_seedcull/`): + +**Mid-voc rescaled autonomy** (current selection metric; held-out test on gabo_p9): +- best-of-8: seed6 = +0.65 (the shipped seed).
Selection record: + `poly_seedcull/lambda_seed_cv.csv` and `selected_model.json`. + +**Mid-voc per-voc rescale factor** (10 vocs, gabo_p9, start +50 ms): +- seed 7: **CV 18%**, max/min 1.6×, spec_corr 0.636, median 1.26 +- seed 5: CV 23%, max/min 1.8×, spec_corr 0.631 +- seed 6: CV 24%, max/min 2.0×, spec_corr 0.660 ← current shipped seed +- seed 3: CV 26%, max/min 2.2×, spec_corr 0.583 + +**Cold-start per-voc rescale factor** (10 vocs, 50 ms lead-in + voc): +- seed 7: CV 53%, max/min 48×, spec_corr 0.504 ← best of the bunch, still not deployable +- seed 0: CV 80%, max/min 9×, spec_corr 0.736 ← best cold-start spec +- seed 1: CV 78%, max/min 181×, spec_corr 0.438 +- seed 6: CV 144%, max/min 431×, spec_corr 0.494 ← current shipped seed (fails) +- seeds 2/3/4: blow up (non-finite rollout) + +The plan succeeds when a new seed's **cold-start** numbers look like +seed 7's mid-voc numbers: CV under 30%, max/min under 3×, spec_corr above 0.55. + +## Appendix B: state of the finchsim side (as of this handoff) + +The synthesis-side work is **done and validated**, on branch +`finchsim/ouroboros-poly-synthesis` (off updated `main`, base ~9118523): + +1. `ouroboros_ode.py` rewritten — `build_synthesis_ode(mode="poly")` (default) + integrates the **full polynomial autonomous ODE** using ω, γ, and all 256 + `poly_coeffs` from the adapter's `u` channels. RK4, drives held per tick. + Mask: (1,0)/(0,1) zeroed, (0,0) kept (matches `keep_const=True`). Validation + via `scripts/validate_poly_synthesis.py`: corr **0.999** vs ouroboros + `integrate_poly_autonomous` on all 10 v2 paired-data renditions (pitch exact, + amplitude matched after the same RMS rescale). Brian wiring smoke test: + corr 0.996 vs the numpy loop at the expected 1-tick latch. +2. A gentle output one-pole DC blocker is **kept available** (`dc_block=True`, + τ=30 ms) per the user; with a stable seed it's near-no-op. +3. 
**Per-pool muscle filter wired up** in both `adapter.py` and + `scripts/gen_adapter_features.py`: `xii_vs/xii_ad: 3 ms` (Adam & Elemans + syringeal twitch), `ram/pam: 40 ms` (respiratory). Map lives in + `adapter.MUSCLE_TAU_MS_BY_POOL`, imported by the feature script so the + filters can never drift apart. +4. `data/paired_data.npz` in finchsim is the **40-rendition drives-only** + export (50 ms lead-in) generated by `~/ouroboros_smoke/make_calib_data.py` + from the current shipped seed. **Re-run that script with the new seed's + checkpoint** to refresh `data/paired_data.npz` before re-running the + calibration. +5. The poly synthesis is signature-compatible with the existing caller + `scripts/simulate_hvc_to_audio.py` (mode defaults to poly; no caller changes + needed). Re-run the integration test after the new adapter weights are fit. + +**Outstanding finchsim concern:** the per-pool filter helped γ (mid-voc held-out +R² 0.484 → 0.509, clears the README ≥0.5 bar) but didn't move ω (~0.38); +the poly held-out R² is ~0.06 regardless of data/filter — the 256 poly channels are +genuinely not linearly predictable from 4 brainstem pools. That's a limitation but +**not the audibility blocker**: LASSO falls back to each channel's mean, so the +nonlinear limit-cycle terms are still present at deployment. The audibility blocker +is solely the free-amplitude problem this plan addresses. + +### A.3 If the plan fails — runtime fallback + +If §5 fails (no seed reaches the cold-start criterion), add an AGC to +`finchsim/ouroboros_ode.py`: + +```python +# inside the network_operation, after computing the next (y, v) +ms2[0] = (1 - α) * ms2[0] + α * y * y # EMA of y², τ_agc ≈ 30–100 ms +gain = target_rms / max(sqrt(ms2[0]), gate_floor) +G_ode.y[0] = y * gain +``` + +Knobs: `target_rms` (≈ 0.01, the v2 audio scale), `τ_agc` (30–100 ms), +`gate_floor` (noise gate so silence stays silent). This is the streaming analog of +`generate_autonomous(rescale=True)`. 
It's seed-agnostic — works with any seed that +integrates finite, including the current seed 6. diff --git a/examples/scan_seed_amp_coldstart.py b/examples/scan_seed_amp_coldstart.py new file mode 100644 index 0000000..19e9af5 --- /dev/null +++ b/examples/scan_seed_amp_coldstart.py @@ -0,0 +1,117 @@ +"""Cold-start amplitude stability scan for 8 seedcull seeds. + +Same diagnostic as scan_seed_amp_stability.py but using the full window +[silence_lead_in + vocalization] from a near-rest IC -- the situation finchsim's +simulate_hvc_to_audio actually faces (silence -> ignition -> sustained oscillation). +Reports per-seed CV of the rescale factor, range, median, and spec_corr. + +A seed that is amplitude-stable here would let you ship one global rescale and +nothing else for the streaming synthesis. +""" +import os, sys, glob +sys.path.insert(0, "/home/pearson/code/ouroboros") + +import numpy as np +import torch +from scipy.io import wavfile +from scipy.signal import butter, sosfiltfilt, welch + +from utils import deriv_approx_dy +from train.train import load_model + + +def _poly_rk4_step(y, v, omega, gamma, W, powers, h=1.0): + def rhs(yy, vv): + kern = float((yy ** powers) @ (W @ (vv ** powers))) + return vv, -(omega * omega) * yy - gamma * vv - kern + k1y, k1v = rhs(y, v) + k2y, k2v = rhs(y + 0.5 * h * k1y, v + 0.5 * h * k1v) + k3y, k3v = rhs(y + 0.5 * h * k2y, v + 0.5 * h * k2v) + k4y, k4v = rhs(y + h * k3y, v + h * k3v) + y = y + (h / 6.0) * (k1y + 2 * k2y + 2 * k3y + k4y) + v = v + (h / 6.0) * (k1v + 2 * k2v + 2 * k3v + k4v) + return y, v + + +CKPT_BASE = "/home/pearson/code/ouroboros/poly_seedcull" +DATA_DIR = "/home/pearson/code/ouroboros/data500/gabo_p9" +N_VOCS = 10 +SILENCE_PAD = 2000 # 50 ms lead-in (matches finchsim/data/paired_data.npz) +powers = np.arange(16) + + +def correct(x): + fs = len(x); sos = butter(5, 100/(0.5*fs), btype="low", output="sos") + return x - sosfiltfilt(sos, x) + + +def logpsd(x, fs): + f, P = welch(x - x.mean(), fs=fs, 
nperseg=min(1024, len(x))) + m = f <= 8000 + return np.log(P[m] + 1e-20) + + +# Full-window cold-start segments: SILENCE_PAD lead-in + vocalization +wavs = sorted(glob.glob(os.path.join(DATA_DIR, "*.wav")))[:N_VOCS] +segs = [] +for w in wavs: + sr, aud = wavfile.read(w); aud = aud.astype(np.float64) + onoffs = np.atleast_2d(np.loadtxt(w.replace(".wav", ".txt"))) + on_i, off_i = int(round(onoffs[0][0] * sr)), int(round(onoffs[0][1] * sr)) + start = max(0, on_i - SILENCE_PAD) + segs.append(aud[start:off_i]) +dt = 1.0 / sr +L = min(len(s) for s in segs) +segs = [s[:L] for s in segs] +print(f"{len(segs)} held-out vocs, L={L} ({L*dt*1e3:.0f} ms; lead-in {SILENCE_PAD*dt*1e3:.0f} ms), " + f"COLD START (IC=audio[0]≈silence)\n", flush=True) + +print(f"{'seed':>4} {'n':>3} {'CV%':>6} {'median':>10} {'min':>10} {'max':>10} " + f"{'max/min':>8} {'spec_corr':>10}", flush=True) + +for seed in range(8): + cd = f"{CKPT_BASE}/poly_lam1.068_seed{seed}" + if not os.path.isdir(cd): + print(f"{seed:>4} (no dir)"); continue + model, _, _, _ = load_model(cd); model.eval() + tau = float(model.tau) + + factors, corrs = [], [] + for audio in segs: + x = audio[None, :, None] + dy = deriv_approx_dy(x) + xt = torch.from_numpy(x).to(torch.float32).cuda() + dt_t = torch.from_numpy(dy).to(torch.float32).cuda() + with torch.no_grad(): + om, ga, _, wts, _ = model.get_funcs(xt, dt_t.clone(), dt, smoothing=False) + om = om.cpu().numpy()[0, :, 0] + ga = ga.cpu().numpy()[0, :, 0] + co = wts.cpu().numpy()[0] # (L, 16, 16) + + # cold start: IC at window start (= silence, near 0); matches generate_autonomous + y, v = float(audio[0]), float((tau / dt) * dy[0, 0, 0]) + raw = np.empty(L); raw[0] = y; bad = False + for t in range(L - 1): + W = co[t].copy(); W[1, 0] = 0; W[0, 1] = 0 + y, v = _poly_rk4_step(y, v, float(om[t]), float(ga[t]), W, powers, 1.0) + if not (np.isfinite(y) and np.isfinite(v)): + bad = True; break + raw[t + 1] = y + if bad: + continue + ca, cr = correct(audio), correct(raw) + s_t, s_a = 
float(np.std(ca)), float(np.std(cr)) + if s_a < 1e-12: + continue + r = s_t / s_a + factors.append(r) + cr_rs = cr * r + cc = float(np.corrcoef(logpsd(ca, 1/dt), logpsd(cr_rs, 1/dt))[0, 1]) + corrs.append(cc) + if not factors: + print(f"{seed:>4} 0 (no usable vocs)"); continue + arr = np.array(factors) + cv = arr.std() / arr.mean() * 100 + print(f"{seed:>4} {len(factors):>3d} {cv:>5.1f}% {np.median(arr):>10.3f} " + f"{arr.min():>10.3f} {arr.max():>10.3f} {arr.max()/arr.min():>7.1f}x " + f"{np.mean(corrs):>10.3f}", flush=True) diff --git a/examples/scan_seed_amp_stability.py b/examples/scan_seed_amp_stability.py new file mode 100644 index 0000000..49be181 --- /dev/null +++ b/examples/scan_seed_amp_stability.py @@ -0,0 +1,119 @@ +"""Scan 8 seedcull seeds (λ=1.068) for amplitude stability across vocs. + +For each seed, run get_funcs + my numpy poly-RK4 (NO rescale) on N held-out +mid-vocalization windows from data500/gabo_p9 (same start-offset / length the +autonomy_score uses). Per voc, compute r(v) = std(correct(input))/std(correct(raw)). +Report per-seed CV, median, min/max, max/min ratio, plus mean spectral (log-PSD) +correlation after rescale -- so we don't pick an amplitude-stable but spectrally-wrong +seed. Low CV (< ~20%) and decent spec_corr (> 0.5) => one shipped scalar would work. 
+""" +import os, sys, glob +sys.path.insert(0, "/home/pearson/code/ouroboros") +sys.path.insert(0, "/home/pearson/code/finchsim") + +import numpy as np +import torch +from scipy.io import wavfile +from scipy.signal import butter, sosfiltfilt, welch + +from utils import deriv_approx_dy +from train.train import load_model + + +def _poly_rk4_step(y, v, omega, gamma, W, powers, h=1.0): + """Inlined copy of ouroboros_ode._poly_rk4_step (numpy, no Brian dep).""" + def rhs(yy, vv): + kern = float((yy ** powers) @ (W @ (vv ** powers))) + return vv, -(omega * omega) * yy - gamma * vv - kern + k1y, k1v = rhs(y, v) + k2y, k2v = rhs(y + 0.5 * h * k1y, v + 0.5 * h * k1v) + k3y, k3v = rhs(y + 0.5 * h * k2y, v + 0.5 * h * k2v) + k4y, k4v = rhs(y + h * k3y, v + h * k3v) + y = y + (h / 6.0) * (k1y + 2 * k2y + 2 * k3y + k4y) + v = v + (h / 6.0) * (k1v + 2 * k2v + 2 * k3v + k4v) + return y, v + +CKPT_BASE = "/home/pearson/code/ouroboros/poly_seedcull" +DATA_DIR = "/home/pearson/code/ouroboros/data500/gabo_p9" +N_VOCS = 10 +START_OFFSET_MS = 50.0 # match autonomy_score convention +AUTO_N = 3000 # sustained-voc window length +powers = np.arange(16) + + +def correct(x): + fs = len(x); sos = butter(5, 100/(0.5*fs), btype="low", output="sos") + return x - sosfiltfilt(sos, x) + + +def logpsd(x, fs): + f, P = welch(x - x.mean(), fs=fs, nperseg=min(1024, len(x))) + m = f <= 8000 + return np.log(P[m] + 1e-20) + + +# Load N held-out mid-voc segments (uniform across seeds) +wavs = sorted(glob.glob(os.path.join(DATA_DIR, "*.wav")))[:N_VOCS] +segs = [] +for w in wavs: + sr, aud = wavfile.read(w); aud = aud.astype(np.float64) + on_s = np.atleast_2d(np.loadtxt(w.replace(".wav", ".txt")))[0][0] + s = int(round(on_s * sr)) + int(round(START_OFFSET_MS / 1000 * sr)) + seg = aud[s:s + AUTO_N] + if len(seg) == AUTO_N: + segs.append(seg) +dt = 1.0 / sr +print(f"{len(segs)} held-out vocs, L={AUTO_N} (start +{START_OFFSET_MS:.0f} ms from onset), sr={sr}\n", + flush=True) + +header = (f"{'seed':>4} {'n':>3} 
{'CV%':>6} {'median':>10} {'min':>10} {'max':>10} " + f"{'max/min':>8} {'spec_corr':>10}") +print(header, flush=True) + +for seed in range(8): + cd = f"{CKPT_BASE}/poly_lam1.068_seed{seed}" + if not os.path.isdir(cd): + print(f"{seed:>4} (no dir)"); continue + model, _, _, _ = load_model(cd); model.eval() + tau = float(model.tau) + + factors, corrs = [], [] + for audio in segs: + x = audio[None, :, None] + dy = deriv_approx_dy(x) + xt = torch.from_numpy(x).to(torch.float32).cuda() + dt_t = torch.from_numpy(dy).to(torch.float32).cuda() + with torch.no_grad(): + om, ga, _, wts, _ = model.get_funcs(xt, dt_t.clone(), dt, smoothing=False) + om = om.cpu().numpy()[0, :, 0] + ga = ga.cpu().numpy()[0, :, 0] + co = wts.cpu().numpy()[0] # (L, 16, 16) + + y, v = float(audio[0]), float((tau / dt) * dy[0, 0, 0]) + raw = np.empty(AUTO_N); raw[0] = y + bad = False + for t in range(AUTO_N - 1): + W = co[t].copy(); W[1, 0] = 0; W[0, 1] = 0 + y, v = _poly_rk4_step(y, v, float(om[t]), float(ga[t]), W, powers, 1.0) + if not (np.isfinite(y) and np.isfinite(v)): + bad = True; break + raw[t + 1] = y + if bad: + continue + ca, cr = correct(audio), correct(raw) + s_t, s_a = float(np.std(ca)), float(np.std(cr)) + if s_a < 1e-12: + continue + r = s_t / s_a + factors.append(r) + # spectral correlation AFTER rescaling raw to target RMS (per-voc, like generate_autonomous) + cr_rs = cr * r + cc = float(np.corrcoef(logpsd(ca, 1/dt), logpsd(cr_rs, 1/dt))[0, 1]) + corrs.append(cc) + if not factors: + print(f"{seed:>4} 0 (no usable vocs)"); continue + arr = np.array(factors) + cv = arr.std() / arr.mean() * 100 + print(f"{seed:>4} {len(factors):>3d} {cv:>5.1f}% {np.median(arr):>10.3f} " + f"{arr.min():>10.3f} {arr.max():>10.3f} {arr.max()/arr.min():>7.1f}x " + f"{np.mean(corrs):>10.3f}", flush=True) From 1685e17eefb8fdc70a53411182871b14d4604b09 Mon Sep 17 00:00:00 2001 From: John Pearson Date: Sat, 30 May 2026 17:58:11 -0400 Subject: [PATCH 13/15] small-k rollout in train loss + cold-start raw 
autonomy selection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Executes docs/small_k_rollout_plan.md §3-4. Adds a short-horizon Euler-step (y, dy/dt) consistency loss term to training (utils.euler_step_k was already in the repo as a relic of mdmarti/ouro_clean; now wired in) and a cold-start mode for the autonomy selection metric. train/train.py: - k_rollout (int, default 0 = off): when >=2, roll the model's predicted second derivative forward k samples per train step and add MSE on (y, dy) at every intermediate step 1..k. k must be << one carrier cycle (~16 samples at sr=40 kHz) so pointwise MSE doesn't reward amplitude collapse via phase drift. - lambda_k (float, default 0.3): weight on the rollout term. - k_rollout_units ('rescaled' | 'physical', default 'rescaled'): time-unit system the rollout is computed in. 'rescaled' uses post-mutation dxdt (= dy/ds), yhat (= d2y/ds^2), and ds = dt/tau, keeping the two MSE terms commensurable with the single-step loss. 'physical' (d2y_phys = yhat/tau^2, dy_phys = dxdt/dt, step = dt) preserves the May-30 §5 matrix runs -- in that mode the chain-rule 1/tau^2 amplification makes the k-step term dominate the gradient on yhat, so the May-30 matrix was effectively a k-step-dominated objective with single-step as a small regularizer. Empirically the 'physical' mode produced better cold-start autonomy despite being unbalanced; default left at 'rescaled' as the well-conditioned starting point for future work. See the docstring for the unit-balance analysis and the project memory's small-k-rollout-matrix entry. - Captures dxdt.clone() before model.forward (which mutates it in-place dxdt *= tau; dxdt /= dt) only when needed by physical-units rollout. - Logs Loss/k_rollout to TensorBoard. 
train/eval.py:autonomy_score: - cold_start (bool, default False): contract marker that the caller's segments include a silence lead-in (the SILENCE_PAD=2000 / 50 ms convention from make_paired_data_v2.py and scan_seed_amp_coldstart.py). The metric body is unchanged -- integrate_{poly,model}_autonomous already uses audio[0] as IC, which equals near-silence for lead-in windows -- but the breakdown dict now records 'cold_start' and 'rescale' keys for downstream logging. Typically paired with rescale=False (amplitude must contribute to the score). train/model_cv.py:model_cv_lambdas: - Pass-through for k_rollout, lambda_k, k_rollout_units, cold_start_autonomy. - Cull-log metric tag updated to one of {raw, rescaled, cold-start-raw} (was hardcoded 'rescaled'). examples/run_lambda_pipeline.py: - --k-rollout / --lambda-k / --k-rollout-units {rescaled,physical}. - --cold-start-selection (BooleanOptionalAction): switches val/test voc windowing to lead-in + full voc (load_voc_windows_coldstart, trims to common min length), and sets rescale_autonomy=False so the binding deployment criterion is what's selected on. - --silence-pad-samples (default 2000): cold-start lead-in length. - Status line + selected_model.json gain selection metric + k-rollout config + coldstart_raw_test_autonomy / coldstart_amp_pen when cold-start selection is in use. .gitignore: poly_smallk_*/ and poly_smallk_logs/. May-30 §5 matrix on data500/gabo_p{0..9} at lam=1.068, n_seeds=16, cull→3 at ep20 → finish ep50:

  row        cold-start raw test autonomy  amp_pen  spec_corr  pitch_pen
  baseline   -5.48                         5.74     0.70       0.44
  k2         -4.41                         1.34     0.02       3.10
  k4         -3.78                         3.70     0.56       0.63
  k8 (best)  -3.37                         3.99     0.71       0.09

§6 binding criterion (amp_pen <= 0.5) not met by any row; small-k pressure improves spec_corr/pitch_pen at the margins but the cold-start amplitude direction is structural (Floquet near-zero). Per the plan's failure path, proceeded with the AGC fallback in finchsim.
See project memory small-k-rollout-matrix.md for the full outcome + post-matrix follow-ups. Co-Authored-By: Claude Opus 4.7 --- .gitignore | 2 + examples/run_lambda_pipeline.py | 123 +++++++++++++++++++++++++++++--- train/eval.py | 17 ++++- train/model_cv.py | 42 ++++++++--- train/train.py | 69 +++++++++++++++++- 5 files changed, 230 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index c62b631..d869db3 100644 --- a/.gitignore +++ b/.gitignore @@ -185,3 +185,5 @@ poly_attr_*/ poly_pipeline_verify/ poly_pipeline_full/ poly_seedcull/ +poly_smallk_*/ +poly_smallk_logs/ diff --git a/examples/run_lambda_pipeline.py b/examples/run_lambda_pipeline.py index 5fa2200..16f73ef 100644 --- a/examples/run_lambda_pipeline.py +++ b/examples/run_lambda_pipeline.py @@ -14,7 +14,17 @@ Pass --lam <=0 to instead sweep the standard 7-point lambda grid. Pass --cull-frac (e.g. 0.4) to enable SEED CULLING: train all seeds to that fraction of the budget, then finish only the top ---cull-keep by rescaled validation autonomy (~halves the seed-search cost; see docs). +--cull-keep by the configured validation autonomy metric (~halves the seed-search cost; see docs). + +Pass --cold-start-selection to switch the selection metric to COLD-START RAW autonomy: val/test +vocs are loaded with the silence lead-in (silence_pad + voc), the integration IC is at the lead-in +start (≈ silence), and amplitude is left in the score (no rescaling). This matches the situation +finchsim's real-time synthesis faces and is the binding criterion for the deployed pipeline; see +docs/small_k_rollout_plan.md §3.2. Without this flag, selection uses mid-voc rescaled autonomy +(the legacy recipe). + +Pass --k-rollout 2/4/8 (and optionally --lambda-k) to enable the small-k Euler-step consistency +loss during training (docs/small_k_rollout_plan.md §3.1). Default 0 = off (legacy training). 
Run from the repo root: python -m examples.run_lambda_pipeline --data-glob 'data500/gabo_p*' --out-dir ./poly_pipeline \ @@ -22,6 +32,14 @@ # with seed culling (train 8 seeds, finish the best 2 after 40% of epochs): python -m examples.run_lambda_pipeline --data-glob 'data500/gabo_p*' --n-seeds 8 \ --cull-frac 0.4 --cull-keep 2 + # plan §5 row (cold-start selection, no k-rollout): + python -m examples.run_lambda_pipeline --data-glob 'data500/gabo_p*' \ + --out-dir ./poly_pipeline_baseline --n-epochs 50 --n-seeds 16 --lam 1.068 \ + --cull-frac 0.4 --cull-keep 3 --cold-start-selection + # plan §5 row k=4 (cold-start selection + small-k consistency): + python -m examples.run_lambda_pipeline --data-glob 'data500/gabo_p*' \ + --out-dir ./poly_pipeline_k4 --n-epochs 50 --n-seeds 16 --lam 1.068 \ + --cull-frac 0.4 --cull-keep 3 --cold-start-selection --k-rollout 4 --lambda-k 0.3 """ import argparse @@ -52,6 +70,30 @@ def load_voc_windows(data_dir, n_vocs, start_offset_ms, n): return segs, sr +def load_voc_windows_coldstart(data_dir, n_vocs, silence_pad_samples): + """held-out cold-start windows: `silence_pad_samples` lead-in + full vocalization. + + Matches the convention used by examples/scan_seed_amp_coldstart.py and + make_paired_data_v2.py (50 ms lead-in at 40 kHz = 2000 samples). Per-voc lengths + vary, so we trim to the shortest common length so all segments stack uniformly. 
+ """ + raw = [] + sr = None + for wav in sorted(glob.glob(os.path.join(data_dir, "*.wav")))[:n_vocs]: + sr, af = wavfile.read(wav) + af = af.astype(np.float64) + onoffs = np.atleast_2d(np.loadtxt(wav.replace(".wav", ".txt"))) + on_i = int(round(onoffs[0][0] * sr)) + off_i = int(round(onoffs[0][1] * sr)) + start = max(0, on_i - silence_pad_samples) + raw.append(af[start:off_i]) + if not raw: + return [], sr + L = min(len(s) for s in raw) + segs = [s[:L] for s in raw] + return segs, sr + + def main(): p = argparse.ArgumentParser(description=__doc__) p.add_argument("--data-glob", default="data500/gabo_p*") @@ -78,8 +120,29 @@ def main(): p.add_argument("--batch-size", type=int, default=8) p.add_argument("--n-val-vocs", type=int, default=3) p.add_argument("--n-test-vocs", type=int, default=3) - p.add_argument("--auto-n", type=int, default=3000, help="autonomy window length (samples)") - p.add_argument("--start-offset-ms", type=float, default=50.0) + p.add_argument("--auto-n", type=int, default=3000, help="autonomy window length (samples); " + "ignored when --cold-start-selection is set (cold-start uses lead-in + full voc).") + p.add_argument("--start-offset-ms", type=float, default=50.0, + help="mid-voc start offset (legacy / non-cold-start selection only).") + p.add_argument("--cold-start-selection", action=argparse.BooleanOptionalAction, default=False, + help="select seeds by COLD-START RAW autonomy (lead-in + full voc, IC at " + "silence, no rescaling) instead of the legacy mid-voc rescaled metric. " + "This is the binding criterion for finchsim real-time synthesis. " + "See docs/small_k_rollout_plan.md §3.2.") + p.add_argument("--silence-pad-samples", type=int, default=2000, + help="silence lead-in length in samples for --cold-start-selection. 
" + "Default 2000 (= 50 ms at 40 kHz; matches make_paired_data_v2.py).") + p.add_argument("--k-rollout", type=int, default=0, + help=">=2 enables the small-k Euler-step rollout consistency loss " + "during training (docs/small_k_rollout_plan.md §3.1). Must be much " + "less than one carrier cycle (~16 samples at sr=40 kHz). Default 0 = off.") + p.add_argument("--lambda-k", type=float, default=0.3, + help="weight on the k-step rollout loss (ignored if --k-rollout < 2).") + p.add_argument("--k-rollout-units", choices=["rescaled", "physical"], default="rescaled", + help="time-unit system for the k-step rollout (see train.train.train). " + "'rescaled' (default) keeps the two MSE terms commensurable and is " + "the recommended choice. 'physical' reproduces the broken-scale " + "May-30 matrix run.") p.add_argument("--seed", type=int, default=1234) p.add_argument("--n-jobs", type=int, default=8) args = p.parse_args() @@ -101,12 +164,31 @@ def main(): dls = get_loaders(np.stack(chunks, 0), num_workers=args.n_jobs, batch_size=args.batch_size, train_size=0.6, cv=True, seed=args.seed, dt=dt) - # held-out autonomy vocalizations (file-level holdout) - val_vocs, _ = load_voc_windows(val_dir, args.n_val_vocs, args.start_offset_ms, args.auto_n) - test_vocs, _ = load_voc_windows(test_dir, args.n_test_vocs, args.start_offset_ms, args.auto_n) + # held-out autonomy vocalizations (file-level holdout). Window depends on the + # selection metric: cold-start uses lead-in + full voc (variable length, trimmed + # to common min); mid-voc uses a fixed-length window starting `start_offset_ms` + # after onset (the legacy recipe). + if args.cold_start_selection: + val_vocs, _ = load_voc_windows_coldstart(val_dir, args.n_val_vocs, + args.silence_pad_samples) + test_vocs, _ = load_voc_windows_coldstart(test_dir, args.n_test_vocs, + args.silence_pad_samples) + # cold-start mode wants amplitude IN the score (see docs/small_k_rollout_plan.md §3.2). 
+ rescale_for_selection = False + else: + val_vocs, _ = load_voc_windows(val_dir, args.n_val_vocs, args.start_offset_ms, args.auto_n) + test_vocs, _ = load_voc_windows(test_dir, args.n_test_vocs, args.start_offset_ms, + args.auto_n) + # legacy mid-voc recipe: amplitude is gauge-fixed at generation, so rescale away here + # and let selection turn on spec/pitch/boundedness. + rescale_for_selection = True + voc_L = len(val_vocs[0]) if val_vocs else 0 print(f"train chunks={len(chunks)} from {len(train_dirs)} shards | " - f"val_vocs={len(val_vocs)} from {os.path.basename(val_dir)} | " - f"test_vocs={len(test_vocs)} from {os.path.basename(test_dir)} | sr={sr}", flush=True) + f"val_vocs={len(val_vocs)} from {os.path.basename(val_dir)} (L={voc_L}) | " + f"test_vocs={len(test_vocs)} from {os.path.basename(test_dir)} | sr={sr} | " + f"selection={'cold-start-raw' if args.cold_start_selection else 'mid-voc-rescaled'} | " + f"k_rollout={args.k_rollout} (lambda_k={args.lambda_k}, " + f"units={args.k_rollout_units})", flush=True) best_model = model_cv_lambdas( dls=dls, dt=dt, n_epochs=args.n_epochs, lr=1e-3, n_kernels=args.n_kernels, @@ -114,24 +196,43 @@ def main(): model_path=out_dir, save_freq=max(args.n_epochs // 5, 1), drive_lowpass_ms=args.drive_lowpass_ms, n_seeds=args.n_seeds, selection="autonomy", val_vocs=val_vocs, test_vocs=test_vocs, - keep_const=args.keep_const, rescale_autonomy=True, + keep_const=args.keep_const, rescale_autonomy=rescale_for_selection, lambdas=None if args.lam <= 0 else [args.lam], cull_frac=args.cull_frac, cull_keep=args.cull_keep, + k_rollout=args.k_rollout, lambda_k=args.lambda_k, + k_rollout_units=args.k_rollout_units, + cold_start_autonomy=args.cold_start_selection, ) # DEPLOYED generation: rescaled autonomous reconstruction of the held-out test vocs. - # (Amplitude is a free gauge fixed at generation; selection above already used rescaled autonomy.) + # (Amplitude is a free gauge fixed at generation by generate_autonomous.) 
When the + # selection metric was cold-start raw, also report the cold-start raw test number -- + # this is the binding metric for finchsim deployment (docs/small_k_rollout_plan.md §6). if test_vocs: # authoritative selected lambda from the selector (handles fixed-lambda, swept-grid, and culled) sel_lambda = float(getattr(best_model, "_selected_lambda", args.lam)) - score, _, bd = autonomy_score(best_model, test_vocs, dt, rescale=True) + score, _, bd = autonomy_score(best_model, test_vocs, dt, rescale=True, + cold_start=args.cold_start_selection) recon = generate_autonomous(best_model, test_vocs[0], dt, rescale=True) wav = (recon / (np.abs(recon).max() + 1e-12) * 0.95 * 32767).astype(np.int16) wavfile.write(os.path.join(out_dir, "selected_autonomous_recon.wav"), int(round(1 / dt)), wav) manifest = {"selected_lambda": sel_lambda, "n_seeds": int(args.n_seeds), "keep_const": bool(args.keep_const), "drive_lowpass_ms": float(args.drive_lowpass_ms), + "cold_start_selection": bool(args.cold_start_selection), + "k_rollout": int(args.k_rollout), "lambda_k": float(args.lambda_k), + "k_rollout_units": str(args.k_rollout_units), "rescaled_test_autonomy": score, "spec_corr": bd["spec_corr"], "pitch_pen": bd["pitch_pen"], "bounded_frac": bd["bounded_frac"]} + if args.cold_start_selection: + # also record the cold-start RAW score: the no-rescale number that selection used, + # the metric finchsim deployment actually has to clear. 
+ raw_score, _, raw_bd = autonomy_score(best_model, test_vocs, dt, rescale=False, + cold_start=True) + manifest["coldstart_raw_test_autonomy"] = raw_score + manifest["coldstart_amp_pen"] = raw_bd["amp_pen"] + print(f"COLD-START RAW test autonomy = {raw_score:+.3f} (amp_pen={raw_bd['amp_pen']:.3f}, " + f"spec={raw_bd['spec_corr']:.2f}, pitch_pen={raw_bd['pitch_pen']:.2f}, " + f"bounded={raw_bd['bounded_frac']:.2f})", flush=True) with open(os.path.join(out_dir, "selected_model.json"), "w") as f: json.dump(manifest, f, indent=2) print(f"deployed (rescaled) test autonomy = {score:+.3f} (spec={bd['spec_corr']:.2f}, " diff --git a/train/eval.py b/train/eval.py index 0964e24..dc934e3 100644 --- a/train/eval.py +++ b/train/eval.py @@ -448,6 +448,7 @@ def autonomy_score( diverge_score: float = -5.0, method: str = "rk4", rescale: bool = False, + cold_start: bool = False, ) -> tuple: """ validation metric for AUTONOMOUS reconstruction quality (model-selection criterion). @@ -469,11 +470,23 @@ def autonomy_score( the genuinely-constrained quantities (spectral shape + pitch + boundedness). A divergent rollout is still detected on the RAW output and gets `diverge_score` (collapse is not rescaled away). + If `cold_start=True`, the caller is asserting that each `segments[i]` already includes a silence + lead-in (the SILENCE_PAD=2000-sample / 50-ms convention used by + `examples/scan_seed_amp_coldstart.py` and `make_paired_data_v2.py`). The integration IC is the + first sample of the segment (already the behavior of `integrate_{poly,model}_autonomous`), + which then equals near-silence and exercises the ignition path -- the situation finchsim's + real-time synthesis actually faces. Mechanically the metric is unchanged; this flag exists + so callers can request the cold-start scoring contract and so the chosen mode is recorded in + the breakdown. 
Typically paired with `rescale=False` (amplitude must contribute to the score, + because a single shipped rescale constant cannot fix voc-variable cold-start amplitude). + See docs/small_k_rollout_plan.md §3.2. + returns ----- - mean score over segments - per-segment scores (list) - - breakdown dict (mean spectral corr, amp penalty, pitch penalty, bounded fraction) + - breakdown dict (mean spectral corr, amp penalty, pitch penalty, bounded fraction, + plus the `cold_start` flag for downstream logging) """ fs = 1.0 / dt @@ -519,6 +532,8 @@ def _peak(x): "amp_pen": float(np.nanmean(amps)) if len(amps) else float("nan"), "pitch_pen": float(np.nanmean(pits)) if len(pits) else float("nan"), "bounded_frac": float(np.mean(bounded)) if len(bounded) else 0.0, + "cold_start": bool(cold_start), + "rescale": bool(rescale), } return float(np.mean(scores)), scores, breakdown diff --git a/train/model_cv.py b/train/model_cv.py index f021e19..ba6b8db 100644 --- a/train/model_cv.py +++ b/train/model_cv.py @@ -42,6 +42,10 @@ def model_cv_lambdas( lambdas: list = None, cull_frac: float = 0.0, cull_keep: int = 2, + k_rollout: int = 0, + lambda_k: float = 0.3, + k_rollout_units: str = "rescaled", + cold_start_autonomy: bool = False, ) -> torch.nn.Module: """ This function trains models and cross-validates across regularization strengths. @@ -52,9 +56,21 @@ def model_cv_lambdas( Seed culling (cull_frac>0, selection='autonomy'): the autonomous-reconstruction quality is set by the random seed and its ranking settles well before R^2 plateaus, so instead of training every - run fully, train all runs to `cull_frac` of n_epochs, rank by rescaled validation autonomy, and - finish only the top `cull_keep`. The best finished run is selected. Roughly halves the seed search - (see docs/autonomous_amplitude.md). With cull_frac=0 (default) every run is trained fully. 
+ run fully, train all runs to `cull_frac` of n_epochs, rank by the configured validation + autonomy metric (rescale_autonomy/cold_start_autonomy), and finish only the top `cull_keep`. + The best finished run is selected. Roughly halves the seed search (see + docs/autonomous_amplitude.md). With cull_frac=0 (default) every run is trained fully. + + Small-k rollout consistency (k_rollout>=2): add a short-horizon Euler-step (y, dy/dt) + consistency term to the train objective, weight lambda_k (see docs/small_k_rollout_plan.md + and train.train.train). Default 0 disables it and preserves the legacy single-step + objective. k must be << one carrier cycle so phase drift doesn't pathologize the MSE. + + Cold-start selection (cold_start_autonomy=True): pass cold_start=True through to + autonomy_score, asserting that val_vocs/test_vocs already include the silence lead-in + (see docs/small_k_rollout_plan.md §3.2). The caller (e.g. run_lambda_pipeline) is + responsible for windowing -- this is purely a pass-through. Typically paired with + rescale_autonomy=False (amplitude variability is what we are trying to fix). 
inputs ----- @@ -120,7 +136,9 @@ def _train_to(lam, seed, target): train(model, opt, loss_fn=lambda y, yhat: sse(yhat, y, reduction="mean"), loaders=dls, scheduler=sched, nEpochs=target, val_freq=1, runDir=run_dir, dt=dt, vis_freq=0, smoothing=False, reg_weights=True, start_epoch=start_epoch, - save_freq=save_freq, model_info=model_info) + save_freq=save_freq, model_info=model_info, + k_rollout=k_rollout, lambda_k=lambda_k, + k_rollout_units=k_rollout_units) save_model(model, opt, os.path.join(run_dir, f"checkpoint_{target}.tar"), n_layers=n_layers, d_state=d_state, expand_factor=expand_factor, d_conv=d_conv) return model, run_dir @@ -130,7 +148,8 @@ def _score(model): with torch.no_grad(): (_, vr2), _, _ = eval_model_error(dls, model, dt=dt, comparison="val") if selection == "autonomy": - va, _, bd = autonomy_score(model, val_vocs, dt, rescale=rescale_autonomy) + va, _, bd = autonomy_score(model, val_vocs, dt, rescale=rescale_autonomy, + cold_start=cold_start_autonomy) else: va, bd = np.nan, None return vr2, va, bd @@ -139,12 +158,16 @@ def _score(model): do_cull = (0.0 < cull_frac < 1.0) and selection == "autonomy" and cull_keep < n_jobs records = [] if do_cull: - # SEED CULLING: train every run to a cull epoch, rank by rescaled val autonomy, then finish - # only the top `cull_keep`. The autonomy ranking is ~settled well before R^2 plateaus + # SEED CULLING: train every run to a cull epoch, rank by the configured val autonomy + # metric (see rescale_autonomy / cold_start_autonomy), then finish only the top + # `cull_keep`. The autonomy ranking is ~settled well before R^2 plateaus # (see docs/autonomous_amplitude.md), so this finds the best seed at a fraction of the cost. 
cull_epoch = max(1, int(round(cull_frac * n_epochs))) + metric_tag = ("cold-start-raw" if cold_start_autonomy + else ("rescaled" if rescale_autonomy else "raw")) print(f"\n=== seed culling: train all {n_jobs} run(s) to epoch {cull_epoch} " - f"({cull_frac:.0%} of {n_epochs}), then finish top {cull_keep} by rescaled val autonomy ===", + f"({cull_frac:.0%} of {n_epochs}), then finish top {cull_keep} by " + f"{metric_tag} val autonomy ===", flush=True) ranked = [] for lam in lambdas: @@ -215,7 +238,8 @@ def _score(model): (_, test_r2), _, _ = eval_model_error(dls, best_model, dt=dt, comparison="test") test_auto = np.nan if selection == "autonomy" and test_vocs: - test_auto, _, _ = autonomy_score(best_model, test_vocs, dt, rescale=rescale_autonomy) + test_auto, _, _ = autonomy_score(best_model, test_vocs, dt, rescale=rescale_autonomy, + cold_start=cold_start_autonomy) print(f"BEST lambda={best_lambda:.3f}: test R2={test_r2:.4f} test autonomy={test_auto:+.4f}", flush=True) best_model._selected_lambda = best_lambda # authoritative selected lambda for callers/manifests diff --git a/train/train.py b/train/train.py index 2f2482b..a3d75c7 100644 --- a/train/train.py +++ b/train/train.py @@ -2,7 +2,7 @@ import torch import numpy as np from tqdm import tqdm -from utils import sst, sse +from utils import sst, sse, euler_step_k import matplotlib.pyplot as plt import os import glob @@ -182,6 +182,9 @@ def train( start_epoch=0, model_info={}, save_freq=0, + k_rollout=0, + lambda_k=0.3, + k_rollout_units="rescaled", ) -> Tuple[ list[float], list[Tuple[int, float, float]], nn.Module, torch.optim.Optimizer ]: @@ -207,6 +210,31 @@ def train( this might be higher - model_info: dictionary of model structure specification. used for saving models - save_freq: how often (in epochs) to save your model + - k_rollout: if >=2, add a small-k Euler-step consistency loss to the train + objective (see `utils.euler_step_k`). 
The model's predicted second + derivative is rolled forward `k_rollout` samples and required to + reproduce the ground-truth (y, dy/dt) trajectory at every intermediate + step. k must be << one carrier cycle (~16 samples at sr=40 kHz) so + pointwise MSE doesn't reward amplitude collapse via phase drift. + See docs/small_k_rollout_plan.md. Default 0 disables the term and + preserves the legacy single-step ẍ-MSE training objective. + - lambda_k: relative weight on the k-step rollout loss + (`total = single_step + lambda_k * k_rollout`). Ignored if + k_rollout<2. + - k_rollout_units: "rescaled" (default) or "physical". Selects the + time-unit system the k-step rollout is computed in: + * "rescaled": state = (y, dy/ds), step = ds = dt/τ. y and dy/ds + have comparable scales (dy/ds ~ τ·2π·f·y; with τ=1/40000 and + f≈2.5 kHz, ratio ~0.4), so MSE(y)+MSE(dy/ds) is + well-conditioned and commensurable with the single-step ẍ MSE + (also in rescaled units). No τ² division, no dxdt clone -- + the model's post-mutation dxdt is already dy/ds and yhat is + already d²y/ds². This is the recommended choice. + * "physical": state = (y, dy/dt), step = dt. dy/dt is ~10⁴× + larger than y, so MSE(dy/dt) dominates MSE(y) by ~10⁸ and the + whole k-step term is on a different scale from the single-step + ẍ MSE. Useful only for reproducing the May-30 matrix runs in + docs/small_k_rollout_plan.md §5 (which used this mode). returns ---- @@ -236,6 +264,13 @@ def train( dx2dt2.to("cuda").to(torch.float32) / (dt**2) * model.tau**2 ) # rescale dx2, rather than model output + # model.forward mutates dxdt in-place (dxdt *= tau; dxdt /= dt -> dy/ds). + # The physical-units k-rollout needs the ORIGINAL per-sample dxdt, so clone + # before the model call. The rescaled-units k-rollout uses the post-mutation + # dxdt (= dy/ds) directly, so no clone is needed. 
+ need_persamp = (k_rollout >= 2) and (k_rollout_units == "physical") + dxdt_persamp = dxdt.clone() if need_persamp else None + dx2hat, weights = model(x, dxdt, dt, smoothing) # state: B x L x SD yhat = dx2hat @@ -327,14 +362,44 @@ def px(x): # we take mean over samples to match the loss fn we use (MSE, with mean over samples) total_loss = train_loss + penalty + if k_rollout >= 2: + # Small-k Euler-step rollout consistency (see docs/small_k_rollout_plan.md). + # euler_step_k returns the ground-truth (y, dy) and the recurrently-stepped + # predictions at every intermediate step 1..k, stacked on the last dim; + # loss_fn is applied to both. Gradient flows through the model's d2y output. + if k_rollout_units == "rescaled": + # rescaled time s = t/τ: dxdt (post-mutation) = dy/ds, yhat = d²y/ds². + # Step size = ds = dt/τ. Two MSE terms are commensurable with y. + ds = dt / model.tau + (y_gt, y_pred), (dy_gt, dy_pred) = euler_step_k( + x, dxdt, yhat, ds, k=k_rollout + ) + elif k_rollout_units == "physical": + # physical time t: dy/dt = dxdt_persamp/dt, d²y/dt² = yhat/τ². + # MSE(dy/dt) dominates MSE(y) by ~10⁸; lambda_k absorbs this. 
+ d2y_phys = yhat / (model.tau ** 2) + dy_phys = dxdt_persamp / dt + (y_gt, y_pred), (dy_gt, dy_pred) = euler_step_k( + x, dy_phys, d2y_phys, dt, k=k_rollout + ) + else: + raise ValueError( + f"k_rollout_units must be 'rescaled' or 'physical', got " + f"{k_rollout_units!r}" + ) + k_loss = loss_fn(y_gt, y_pred) + loss_fn(dy_gt, dy_pred) + total_loss = total_loss + lambda_k * k_loss + total_loss.backward() optimizer.step() - + train_losses.append(train_loss.item()) # we should probably be adding val loss here too...ugh writer.add_scalar("Loss/train", train_loss.item(), idx) if reg_weights: writer.add_scalar("Penalty/train", penalty.item(), idx) + if k_rollout >= 2: + writer.add_scalar("Loss/k_rollout", k_loss.item(), idx) if epoch % val_freq == 0: model.eval() From 92de48ba15e0930c8967643811a3f71564f2da05 Mon Sep 17 00:00:00 2001 From: John Pearson Date: Tue, 2 Jun 2026 08:23:49 -0400 Subject: [PATCH 14/15] real-finch training: fix int16 val/test loader normalization + add staging/eval/plot scripts `get_audio_training` divides int16 audio by 32768 to produce float in [-1, 1], but `load_voc_windows` and `load_voc_windows_coldstart` were astyping to float64 without the divide. On synthetic gabo this was masked because those wavs are float32; on real finch int16 data the model was trained on [-0.02, 0.02] while the autonomy scorer was integrating on [-600, 600], so the polynomial closure overflowed at step 0 and `autonomy_score` returned the -5.0 diverge_score on every val/test voc. Replicate the train loader's int16 normalization in both val/test loaders. Also adds two pipeline knobs and a seed-init fix: - `--max-vocs-per-shard` (run_lambda_pipeline): cap training chunks per shard so multi-thousand-voc real datasets don't blow the epoch budget. - `--save-freq` (run_lambda_pipeline): expose checkpoint cadence (needed for per-checkpoint diagnostics of when integration becomes stable). 
- `torch.manual_seed`/`np.random.seed` in train_arneodo_big::main so `--seed` actually varies model init across runs. New scripts under scripts/: - stage_finch_blk445_syllC.py: symlink-only staging of blk445 syllable_C cleaned wavs + annotations into per-day shards matching the pipeline's `{stem}.wav`/`{stem}.txt` glob convention. - eval_checkpoints.py: re-evaluate train R^2, val R^2, and mid-voc rescaled autonomy at every saved checkpoint of a single run. - instrument_rollout.py: per-step RK4 of `integrate_poly_autonomous` with state magnitudes printed at a chosen stride; this is what surfaced the loader bug. - resynth_recon.py: synthesize and write target+recon WAVs for a given checkpoint, using the fixed loader. - plot_recon_grid.py / plot_coldstart_grid.py: 3x4 (waveform, spectrogram) x (target, recon) comparison plots for warm-start and cold-start synthesis. - arneodo_seed_pool.sh: sequential 16-seed Arneodo pool driver; trains each seed, evaluates the last checkpoint, picks the best by mid-voc rescaled autonomy, and writes the winner's recon WAVs + comparison plots. 
Co-Authored-By: Claude Opus 4.7 --- examples/run_lambda_pipeline.py | 12 ++- examples/train_arneodo_big.py | 3 + scripts/arneodo_seed_pool.sh | 92 ++++++++++++++++++++ scripts/eval_checkpoints.py | 111 ++++++++++++++++++++++++ scripts/instrument_rollout.py | 118 ++++++++++++++++++++++++++ scripts/plot_coldstart_grid.py | 127 ++++++++++++++++++++++++++++ scripts/plot_recon_grid.py | 118 ++++++++++++++++++++++++++ scripts/resynth_recon.py | 78 +++++++++++++++++ scripts/stage_finch_blk445_syllC.py | 83 ++++++++++++++++++ 9 files changed, 740 insertions(+), 2 deletions(-) create mode 100755 scripts/arneodo_seed_pool.sh create mode 100644 scripts/eval_checkpoints.py create mode 100644 scripts/instrument_rollout.py create mode 100644 scripts/plot_coldstart_grid.py create mode 100644 scripts/plot_recon_grid.py create mode 100644 scripts/resynth_recon.py create mode 100644 scripts/stage_finch_blk445_syllC.py diff --git a/examples/run_lambda_pipeline.py b/examples/run_lambda_pipeline.py index 16f73ef..55c53e7 100644 --- a/examples/run_lambda_pipeline.py +++ b/examples/run_lambda_pipeline.py @@ -61,6 +61,8 @@ def load_voc_windows(data_dir, n_vocs, start_offset_ms, n): segs = [] for wav in sorted(glob.glob(os.path.join(data_dir, "*.wav")))[:n_vocs]: sr, af = wavfile.read(wav) + if af.dtype == np.int16: + af = af / -np.iinfo(af.dtype).min af = af.astype(np.float64) onoffs = np.atleast_2d(np.loadtxt(wav.replace(".wav", ".txt"))) s = int(onoffs[0][0] * sr) + int(start_offset_ms / 1e3 * sr) @@ -81,6 +83,8 @@ def load_voc_windows_coldstart(data_dir, n_vocs, silence_pad_samples): sr = None for wav in sorted(glob.glob(os.path.join(data_dir, "*.wav")))[:n_vocs]: sr, af = wavfile.read(wav) + if af.dtype == np.int16: + af = af / -np.iinfo(af.dtype).min af = af.astype(np.float64) onoffs = np.atleast_2d(np.loadtxt(wav.replace(".wav", ".txt"))) on_i = int(round(onoffs[0][0] * sr)) @@ -118,6 +122,10 @@ def main(): p.add_argument("--n-kernels", type=int, default=15) 
p.add_argument("--context-len", type=float, default=0.1) p.add_argument("--batch-size", type=int, default=8) + p.add_argument("--max-vocs-per-shard", type=int, default=0, + help="cap training chunks per shard; 0 = legacy 100000 / n_train_dirs.") + p.add_argument("--save-freq", type=int, default=0, + help="checkpoint every N epochs; 0 = legacy max(n_epochs//5, 1).") p.add_argument("--n-val-vocs", type=int, default=3) p.add_argument("--n-test-vocs", type=int, default=3) p.add_argument("--auto-n", type=int, default=3000, help="autonomy window length (samples); " @@ -155,7 +163,7 @@ def main(): # training chunks from the train shards chunks, sr = [], None - per = 100000 // max(1, len(train_dirs)) + per = args.max_vocs_per_shard if args.max_vocs_per_shard > 0 else 100000 // max(1, len(train_dirs)) for d in train_dirs: audio, sr = get_segmented_audio(d, d, max_vocs=per, context_len=args.context_len, seed=args.seed, training=True, extend=True, shuffle_order=True) @@ -193,7 +201,7 @@ def main(): best_model = model_cv_lambdas( dls=dls, dt=dt, n_epochs=args.n_epochs, lr=1e-3, n_kernels=args.n_kernels, expand_factor=10, n_layers=3, d_state=args.d_state, d_conv=4, tau=dt, - model_path=out_dir, save_freq=max(args.n_epochs // 5, 1), + model_path=out_dir, save_freq=args.save_freq if args.save_freq > 0 else max(args.n_epochs // 5, 1), drive_lowpass_ms=args.drive_lowpass_ms, n_seeds=args.n_seeds, selection="autonomy", val_vocs=val_vocs, test_vocs=test_vocs, keep_const=args.keep_const, rescale_autonomy=rescale_for_selection, diff --git a/examples/train_arneodo_big.py b/examples/train_arneodo_big.py index 9616a51..cb8f4af 100644 --- a/examples/train_arneodo_big.py +++ b/examples/train_arneodo_big.py @@ -70,6 +70,9 @@ def main(): p.add_argument("--n-jobs", type=int, default=4) args = p.parse_args() + torch.manual_seed(args.seed) + np.random.seed(args.seed) + out_dir = os.path.abspath(args.out_dir) run_dir = os.path.join(out_dir, "arneodo") os.makedirs(run_dir, exist_ok=True) diff 
--git a/scripts/arneodo_seed_pool.sh b/scripts/arneodo_seed_pool.sh new file mode 100755 index 0000000..d498a9a --- /dev/null +++ b/scripts/arneodo_seed_pool.sh @@ -0,0 +1,92 @@ +#!/usr/bin/env bash +# Drive a sequential Arneodo seed pool: train each seed for `EPOCHS` epochs, +# then run the checkpoint eval to pick the best seed by mid-voc rescaled +# autonomy on val (day85). After all seeds finish, write a comparison table +# + generate recon WAV + plot for the winner. +set -e + +cd /home/pearson/code/ouroboros + +POOL_DIR="${POOL_DIR:-./finch_blk445_syllC_arneodo_pool}" +DATA_GLOB="${DATA_GLOB:-/home/pearson/ouroboros_data/blk445_syllC/day*}" +TRAIN_GLOB="${TRAIN_GLOB:-/home/pearson/ouroboros_data/blk445_syllC/day84}" +N_SEEDS="${N_SEEDS:-16}" +EPOCHS="${EPOCHS:-100}" +SEG="${SEG:-5}" +MAX_VOCS="${MAX_VOCS:-1000}" + +mkdir -p "${POOL_DIR}" +SUMMARY="${POOL_DIR}/pool_summary.csv" +echo "seed,train_r2,val_r2,rescaled_auto,spec_corr,pitch_pen,bounded" > "${SUMMARY}" + +for s in $(seq 0 $((N_SEEDS - 1))); do + SEED_DIR="${POOL_DIR}/seed${s}" + mkdir -p "${SEED_DIR}" + echo "=== seed ${s} train start at $(date +%H:%M:%S) ===" + .venv/bin/python -m examples.train_arneodo_big \ + --data-glob "${TRAIN_GLOB}" \ + --out-dir "${SEED_DIR}" \ + --max-vocs ${MAX_VOCS} \ + --epochs ${EPOCHS} --seg ${SEG} \ + --n-layers 3 --d-state 4 --d-conv 4 --expand-factor 10 \ + --drive-lowpass-ms 1.0 \ + --batch-size 8 --context-len 0.1 \ + --seed ${s} --n-jobs 8 \ + --target-r2 0.99 \ + > "${SEED_DIR}/train.log" 2>&1 + + echo "=== seed ${s} eval start at $(date +%H:%M:%S) ===" + # eval the latest checkpoint only (the script will print all available) + EVAL_OUT="${SEED_DIR}/eval.txt" + .venv/bin/python scripts/eval_checkpoints.py \ + --run-dir "${SEED_DIR}/arneodo" \ + --data-glob "${DATA_GLOB}" 2>&1 \ + | grep -vE "RuntimeWarning|nanmean|loading from|model tau|^\s*$|Train r2|train r2|val r2|^ *spec_corr|power_mat|return vv" \ + > "${EVAL_OUT}" + + # take the LAST line of the eval table as 
the final-checkpoint number + LAST=$(grep -E "^ *[0-9]+ " "${EVAL_OUT}" | tail -1) + echo "seed ${s}: ${LAST}" + # parse columns: epoch train_r2 val_r2 rescaled_auto spec amp_pen pitch_pen bounded + read epoch train_r2 val_r2 auto spec amp pitch bounded <<< "${LAST}" + echo "${s},${train_r2},${val_r2},${auto},${spec},${pitch},${bounded}" >> "${SUMMARY}" +done + +echo "=== all ${N_SEEDS} seeds done at $(date +%H:%M:%S) ===" +echo "summary:" +cat "${SUMMARY}" + +# pick best seed by rescaled_auto (column 4 in summary) +BEST_SEED=$(tail -n +2 "${SUMMARY}" | sort -t, -k4 -gr | head -1 | cut -d, -f1) +echo "=== best seed by autonomy: ${BEST_SEED} ===" + +# find the latest checkpoint of the best seed +BEST_CKPT=$(ls "${POOL_DIR}/seed${BEST_SEED}/arneodo/"*.tar 2>/dev/null \ + | xargs -n1 basename \ + | sed 's/checkpoint_\([0-9]*\)\.tar/\1/' \ + | sort -n | tail -1) +BEST_CKPT_PATH="${POOL_DIR}/seed${BEST_SEED}/arneodo/checkpoint_${BEST_CKPT}.tar" +echo "best ckpt: ${BEST_CKPT_PATH}" + +echo "=== writing recon WAVs + comparison plots for best seed ===" +.venv/bin/python scripts/resynth_recon.py \ + --ckpt "${BEST_CKPT_PATH}" \ + --data-glob "${DATA_GLOB}" \ + --out-dir "${POOL_DIR}/best_recons" \ + --n-vocs 3 2>&1 | grep -vE "RuntimeWarning|power_mat|nanmean|return vv" + +.venv/bin/python scripts/plot_recon_grid.py \ + --ckpt "${BEST_CKPT_PATH}" \ + --data-glob "${DATA_GLOB}" \ + --out "${POOL_DIR}/best_recon_comparison.png" 2>&1 | grep -vE "RuntimeWarning|power_mat|nanmean|return vv" + +.venv/bin/python scripts/plot_coldstart_grid.py \ + --ckpt "${BEST_CKPT_PATH}" \ + --data-glob "${DATA_GLOB}" \ + --out "${POOL_DIR}/best_coldstart_comparison.png" 2>&1 | grep -vE "RuntimeWarning|power_mat|nanmean|return vv" + +echo "=== POOL DONE ===" +echo "summary: ${SUMMARY}" +echo "best seed: ${BEST_SEED} ckpt: ${BEST_CKPT_PATH}" +echo "plots: ${POOL_DIR}/best_recon_comparison.png" +echo " ${POOL_DIR}/best_coldstart_comparison.png" diff --git a/scripts/eval_checkpoints.py 
b/scripts/eval_checkpoints.py new file mode 100644 index 0000000..56965e4 --- /dev/null +++ b/scripts/eval_checkpoints.py @@ -0,0 +1,111 @@ +"""Evaluate every checkpoint in a single-seed run to chart when integration stabilises. + +Loads each `checkpoint_{epoch}.tar` from a poly_lam{lam}_seed{s} directory, computes: + - train R^2 (one-step ẍ-prediction) + - val R^2 + - mid-voc rescaled autonomy (the legacy selection metric) + +Prints a per-epoch table so we can see at what R^2 the autonomous integration first +becomes finite (escapes the diverge_score=-5.0 floor) on real-finch data. + +Usage: + .venv/bin/python scripts/eval_checkpoints.py \\ + --run-dir ./finch_blk445_syllC_diag/poly_lam1.068_seed0 \\ + --data-glob '/home/pearson/ouroboros_data/blk445_syllC/day*' +""" + +import argparse +import glob +import os +import re +import shutil +import tempfile + +import numpy as np +import torch + +from data.load_data import get_segmented_audio +from data.data_utils import get_loaders +from train.train import load_model +from train.eval import autonomy_score, eval_model_error +from examples.run_lambda_pipeline import load_voc_windows + + +def list_checkpoints(run_dir): + pat = re.compile(r"checkpoint_(\d+)\.tar$") + ckpts = [] + for f in glob.glob(os.path.join(run_dir, "checkpoint_*.tar")): + m = pat.search(os.path.basename(f)) + if m: + ckpts.append((int(m.group(1)), f)) + ckpts.sort() + return ckpts + + +def load_specific(ckpt_path): + """Load just this checkpoint by hiding the others in a tempdir.""" + tmp = tempfile.mkdtemp(prefix="ckpt_eval_") + link = os.path.join(tmp, os.path.basename(ckpt_path)) + os.symlink(os.path.abspath(ckpt_path), link) + try: + model, _, _, epoch = load_model(tmp) + finally: + shutil.rmtree(tmp, ignore_errors=True) + return model, epoch + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--run-dir", required=True, + help="poly_lam{lam}_seed{s} directory containing checkpoint_*.tar") + 
p.add_argument("--data-glob", required=True, + help="same glob the pipeline was launched with") + p.add_argument("--max-vocs-per-shard", type=int, default=1000) + p.add_argument("--context-len", type=float, default=0.1) + p.add_argument("--batch-size", type=int, default=8) + p.add_argument("--n-val-vocs", type=int, default=3) + p.add_argument("--auto-n", type=int, default=3000) + p.add_argument("--start-offset-ms", type=float, default=50.0) + p.add_argument("--seed", type=int, default=1234) + p.add_argument("--n-jobs", type=int, default=8) + args = p.parse_args() + + # rebuild the same data state the pipeline used + dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + assert len(dirs) >= 3 + train_dirs, val_dir, _ = dirs[:-2], dirs[-2], dirs[-1] + chunks, sr = [], None + for d in train_dirs: + audio, sr = get_segmented_audio(d, d, max_vocs=args.max_vocs_per_shard, + context_len=args.context_len, seed=args.seed, + training=True, extend=True, shuffle_order=True) + chunks += audio + dt = 1.0 / sr + dls = get_loaders(np.stack(chunks, 0), num_workers=args.n_jobs, + batch_size=args.batch_size, train_size=0.6, cv=True, + seed=args.seed, dt=dt) + val_vocs, _ = load_voc_windows(val_dir, args.n_val_vocs, + args.start_offset_ms, args.auto_n) + print(f"# train chunks={len(chunks)} sr={sr} val_vocs={len(val_vocs)} " + f"L={len(val_vocs[0]) if val_vocs else 0}") + + ckpts = list_checkpoints(args.run_dir) + print(f"# {len(ckpts)} checkpoint(s) in {args.run_dir}") + print(f"{'epoch':>6} {'train_r2':>10} {'val_r2':>10} {'rescaled_auto':>14} " + f"{'spec':>6} {'amp_pen':>8} {'pitch_pen':>10} {'bounded':>8}") + for epoch, path in ckpts: + model, _ = load_specific(path) + model.eval() + with torch.no_grad(): + (_, train_r2), _, _ = eval_model_error(dls, model, dt=dt, comparison="train") + (_, val_r2), _, _ = eval_model_error(dls, model, dt=dt, comparison="val") + auto, _, bd = autonomy_score(model, val_vocs, dt, rescale=True, cold_start=False) + 
print(f"{epoch:>6d} {train_r2:>10.4f} {val_r2:>10.4f} {auto:>14.4f} " + f"{bd['spec_corr']:>6.2f} {bd['amp_pen']:>8.3f} " + f"{bd['pitch_pen']:>10.3f} {bd['bounded_frac']:>8.2f}", flush=True) + del model + torch.cuda.empty_cache() + + +if __name__ == "__main__": + main() diff --git a/scripts/instrument_rollout.py b/scripts/instrument_rollout.py new file mode 100644 index 0000000..5bb645b --- /dev/null +++ b/scripts/instrument_rollout.py @@ -0,0 +1,118 @@ +"""Step-by-step RK4 rollout of integrate_poly_autonomous with per-step diagnostics. + +Lets us see *where* and *how* the autonomous integration blows up: is it step 1 +(initial-condition overflow), is it slow exponential growth, does it survive +some steps then suddenly hit nan? + +Prints the state magnitudes (x, x', kernel-output, dx') at a chosen stride. +""" + +import argparse +import glob +import os +import shutil +import tempfile + +import numpy as np +import torch + +from data.load_data import get_segmented_audio # noqa +from train.train import load_model +from train.eval import deriv_approx_dy, correct +from examples.run_lambda_pipeline import load_voc_windows + + +def load_specific(ckpt_path): + tmp = tempfile.mkdtemp(prefix="ckpt_eval_") + link = os.path.join(tmp, os.path.basename(ckpt_path)) + os.symlink(os.path.abspath(ckpt_path), link) + try: + model, _, _, _ = load_model(tmp) + finally: + shutil.rmtree(tmp, ignore_errors=True) + return model + + +def instrument_one(model, audio, dt, stride=50, max_steps=None): + """Mirror integrate_poly_autonomous's noise_sd>0 path (per-sample RK4) with logging.""" + L = len(audio) + audio_3d = audio[None, :, None] + dy = deriv_approx_dy(audio_3d) + audio_t = torch.from_numpy(audio_3d).to(torch.float32).to("cuda") + dy_t = torch.from_numpy(dy).to(torch.float32).to("cuda") + + with torch.no_grad(): + omega, gamma, _, weights, _ = model.get_funcs(audio_t, dy_t, dt) + omega = omega.detach().cpu().numpy().squeeze() + gamma = gamma.detach().cpu().numpy().squeeze() + 
weights = weights.detach().cpu().numpy() + _, _, P, P2 = weights.shape + ww = weights.reshape(L, 1, 1, P, P2) + + kernel = model.kernel + + x0 = float(audio[0]) + xp0 = (model.tau / dt) * float(dy[0, 0, 0]) + print(f"# IC: x={x0:.4e} xp={xp0:.4e} tau={model.tau:.4e} dt={dt:.4e} L={L}") + print(f"# drives at t=0: omega={omega[0]:.4e} gamma={gamma[0]:.4e} " + f"|w|max={np.abs(weights[0]).max():.4e}") + print(f"{'step':>6} {'x':>14} {'xp':>14} {'omega':>10} {'gamma':>10} " + f"{'kern':>14} {'dxp':>14}") + x, xp = x0, xp0 + end = L - 1 if max_steps is None else min(max_steps, L - 1) + for k in range(end): + om, ga, wk = omega[k], gamma[k], ww[k] + + def f(xx, vv): + kern = float(kernel.forward_given_weights_numpy(np.array([[[xx, vv]]]), wk).squeeze()) + return vv, -(om ** 2) * xx - ga * vv - kern, kern + + k1x, k1v, kern1 = f(x, xp) + k2x, k2v, _ = f(x + 0.5 * k1x, xp + 0.5 * k1v) + k3x, k3v, _ = f(x + 0.5 * k2x, xp + 0.5 * k2v) + k4x, k4v, _ = f(x + k3x, xp + k3v) + x_new = x + (k1x + 2 * k2x + 2 * k3x + k4x) / 6 + xp_new = xp + (k1v + 2 * k2v + 2 * k3v + k4v) / 6 + + if (k % stride == 0) or (not np.isfinite(x_new)) or (not np.isfinite(xp_new)) or k < 5: + print(f"{k:>6d} {x:>14.4e} {xp:>14.4e} {om:>10.3e} {ga:>10.3e} " + f"{kern1:>14.4e} {k1v:>14.4e}") + if (not np.isfinite(x_new)) or (not np.isfinite(xp_new)): + print(f"# NON-FINITE at step {k+1}: x_new={x_new} xp_new={xp_new}") + return k + 1 + x, xp = x_new, xp_new + print(f"# completed {end} steps, final x={x:.4e}, xp={xp:.4e}") + return end + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--ckpt", required=True, + help="path to checkpoint_N.tar") + p.add_argument("--data-glob", required=True) + p.add_argument("--n-val-vocs", type=int, default=3) + p.add_argument("--auto-n", type=int, default=3000) + p.add_argument("--start-offset-ms", type=float, default=50.0) + p.add_argument("--stride", type=int, default=50) + p.add_argument("--max-steps", type=int, default=None) + args = 
p.parse_args() + + model = load_specific(args.ckpt) + model.eval() + print(f"# loaded model from {args.ckpt}") + + dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + val_dir = dirs[-2] + val_vocs, sr = load_voc_windows(val_dir, args.n_val_vocs, args.start_offset_ms, args.auto_n) + dt = 1.0 / sr + print(f"# val_dir={val_dir} sr={sr} n_vocs={len(val_vocs)} L={len(val_vocs[0])}") + + for i, voc in enumerate(val_vocs): + print(f"\n=== voc {i} ===") + voc = np.asarray(voc, dtype=np.float64) + print(f"# target std={np.std(correct(voc)):.4e} range=[{voc.min():.3e}, {voc.max():.3e}]") + instrument_one(model, voc, dt, stride=args.stride, max_steps=args.max_steps) + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_coldstart_grid.py b/scripts/plot_coldstart_grid.py new file mode 100644 index 0000000..8218137 --- /dev/null +++ b/scripts/plot_coldstart_grid.py @@ -0,0 +1,127 @@ +"""Cold-start recon comparison: silence_pad lead-in + full voc. + +Uses load_voc_windows_coldstart so the integration IC is in silence and the model +must ignite the syllable on its own. Output: 3 rows × 4 cols (target wave, target +spec, recon wave, recon spec). The vertical dashed line marks voc onset (end of +the silence lead-in). 
+""" + +import argparse +import glob +import os +import shutil +import tempfile + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib +matplotlib.use("Agg") +from scipy.signal import spectrogram + +from train.train import load_model +from train.eval import generate_autonomous +from examples.run_lambda_pipeline import load_voc_windows_coldstart + + +def load_specific(ckpt_path): + tmp = tempfile.mkdtemp(prefix="ckpt_plot_") + link = os.path.join(tmp, os.path.basename(ckpt_path)) + os.symlink(os.path.abspath(ckpt_path), link) + try: + model, _, _, _ = load_model(tmp) + finally: + shutil.rmtree(tmp, ignore_errors=True) + return model + + +def plot_wave(ax, x, sr, title, onset_ms=None): + if x is None or not np.isfinite(x).all(): + ax.text(0.5, 0.5, "non-finite", ha="center", va="center", transform=ax.transAxes, + fontsize=12, color="red") + ax.set_xticks([]); ax.set_yticks([]) + ax.set_title(title) + return + t = np.arange(len(x)) / sr * 1e3 + ax.plot(t, x, lw=0.6, color="steelblue") + ax.set_xlim(t[0], t[-1]) + if onset_ms is not None: + ax.axvline(onset_ms, ls="--", color="0.5", lw=0.8) + ax.set_xlabel("time (ms)") + ax.set_ylabel("amp") + ax.set_title(title) + + +def plot_spec(ax, x, sr, title, fmax=8000, onset_ms=None): + if x is None or not np.isfinite(x).all(): + ax.text(0.5, 0.5, "non-finite", ha="center", va="center", transform=ax.transAxes, + fontsize=12, color="red") + ax.set_xticks([]); ax.set_yticks([]) + ax.set_title(title) + return + nperseg = min(512, len(x)) + f, t, Sxx = spectrogram(x - np.mean(x), fs=sr, nperseg=nperseg, + noverlap=int(nperseg * 0.875), scaling="spectrum") + m = f <= fmax + S_db = 10 * np.log10(Sxx[m] + 1e-20) + ax.pcolormesh(t * 1e3, f[m] / 1e3, S_db, shading="auto", + cmap="magma", vmin=S_db.max() - 60, vmax=S_db.max()) + if onset_ms is not None: + ax.axvline(onset_ms, ls="--", color="white", lw=0.8) + ax.set_xlabel("time (ms)") + ax.set_ylabel("freq (kHz)") + ax.set_title(title) + + +def main(): + p = 
argparse.ArgumentParser(description=__doc__) + p.add_argument("--ckpt", required=True) + p.add_argument("--data-glob", required=True) + p.add_argument("--out", required=True) + p.add_argument("--n-vocs", type=int, default=3) + p.add_argument("--silence-pad-samples", type=int, default=2205, + help="50 ms at 44.1 kHz") + args = p.parse_args() + + model = load_specific(args.ckpt) + model.eval() + print(f"loaded {args.ckpt}") + + dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + test_dir = dirs[-1] + vocs, sr = load_voc_windows_coldstart(test_dir, args.n_vocs, args.silence_pad_samples) + dt = 1.0 / sr + onset_ms = args.silence_pad_samples / sr * 1e3 + print(f"test_dir={test_dir} sr={sr} n_vocs={len(vocs)} L={len(vocs[0]) if vocs else 0} " + f"onset_at={onset_ms:.1f}ms") + + n = len(vocs) + fig, axs = plt.subplots(n, 4, figsize=(16, 3.0 * n)) + if n == 1: + axs = axs[None, :] + for i, voc in enumerate(vocs): + target = np.asarray(voc, dtype=np.float64) + # cold-start synthesis: rescale=False keeps voc-specific amplitude in the output + recon = generate_autonomous(model, voc, dt, rescale=False, detrend=True, verbose=False) + recon = np.asarray(recon, dtype=np.float64) + recon_ok = np.isfinite(recon).all() + if recon_ok: + L = min(len(target), len(recon)) + target_p, recon_p = target[:L], recon[:L] + else: + target_p, recon_p = target, None + print(f"voc {i}: recon non-finite (max |recon|={np.nanmax(np.abs(recon)):.3e})") + + plot_wave(axs[i, 0], target_p, sr, f"voc {i} target (wave)", onset_ms=onset_ms) + plot_spec(axs[i, 1], target_p, sr, f"voc {i} target (spec)", onset_ms=onset_ms) + plot_wave(axs[i, 2], recon_p, sr, f"voc {i} COLD-START recon (wave)", onset_ms=onset_ms) + plot_spec(axs[i, 3], recon_p, sr, f"voc {i} COLD-START recon (spec)", onset_ms=onset_ms) + + fig.suptitle(f"COLD-START (silence lead-in + voc) — {os.path.basename(args.ckpt)} on {os.path.basename(test_dir)}", + fontsize=12) + fig.tight_layout(rect=[0, 0, 1, 0.97]) + 
fig.savefig(args.out, dpi=120) + print(f"wrote {args.out}") + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_recon_grid.py b/scripts/plot_recon_grid.py new file mode 100644 index 0000000..7c86533 --- /dev/null +++ b/scripts/plot_recon_grid.py @@ -0,0 +1,118 @@ +"""Side-by-side waveform + spectrogram comparison of target vs autonomous recon. + +Produces a 3-row, 4-column figure: rows = vocs, columns = target wave / target spec +/ recon wave / recon spec. Blown-up recons get a 'non-finite' placeholder. +""" + +import argparse +import glob +import os +import shutil +import tempfile + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib +matplotlib.use("Agg") +from scipy.signal import spectrogram + +from train.train import load_model +from train.eval import generate_autonomous +from examples.run_lambda_pipeline import load_voc_windows + + +def load_specific(ckpt_path): + tmp = tempfile.mkdtemp(prefix="ckpt_plot_") + link = os.path.join(tmp, os.path.basename(ckpt_path)) + os.symlink(os.path.abspath(ckpt_path), link) + try: + model, _, _, _ = load_model(tmp) + finally: + shutil.rmtree(tmp, ignore_errors=True) + return model + + +def plot_wave(ax, x, sr, title): + if x is None or not np.isfinite(x).all(): + ax.text(0.5, 0.5, "non-finite", ha="center", va="center", transform=ax.transAxes, + fontsize=12, color="red") + ax.set_xticks([]); ax.set_yticks([]) + ax.set_title(title) + return + t = np.arange(len(x)) / sr * 1e3 + ax.plot(t, x, lw=0.6, color="steelblue") + ax.set_xlim(t[0], t[-1]) + ax.set_xlabel("time (ms)") + ax.set_ylabel("amp") + ax.set_title(title) + + +def plot_spec(ax, x, sr, title, fmax=8000): + if x is None or not np.isfinite(x).all(): + ax.text(0.5, 0.5, "non-finite", ha="center", va="center", transform=ax.transAxes, + fontsize=12, color="red") + ax.set_xticks([]); ax.set_yticks([]) + ax.set_title(title) + return + nperseg = min(512, len(x)) + f, t, Sxx = spectrogram(x - np.mean(x), fs=sr, nperseg=nperseg, + 
noverlap=int(nperseg * 0.875), scaling="spectrum") + m = f <= fmax + S_db = 10 * np.log10(Sxx[m] + 1e-20) + ax.pcolormesh(t * 1e3, f[m] / 1e3, S_db, shading="auto", + cmap="magma", vmin=S_db.max() - 60, vmax=S_db.max()) + ax.set_xlabel("time (ms)") + ax.set_ylabel("freq (kHz)") + ax.set_title(title) + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--ckpt", required=True) + p.add_argument("--data-glob", required=True) + p.add_argument("--out", required=True, help="output PNG path") + p.add_argument("--n-vocs", type=int, default=3) + p.add_argument("--start-offset-ms", type=float, default=50.0) + p.add_argument("--auto-n", type=int, default=3000) + args = p.parse_args() + + model = load_specific(args.ckpt) + model.eval() + print(f"loaded {args.ckpt}") + + dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + test_dir = dirs[-1] + vocs, sr = load_voc_windows(test_dir, args.n_vocs, args.start_offset_ms, args.auto_n) + dt = 1.0 / sr + print(f"test_dir={test_dir} sr={sr} n_vocs={len(vocs)} L={len(vocs[0]) if vocs else 0}") + + n = len(vocs) + fig, axs = plt.subplots(n, 4, figsize=(16, 3.0 * n)) + if n == 1: + axs = axs[None, :] + for i, voc in enumerate(vocs): + target = np.asarray(voc, dtype=np.float64) + recon = generate_autonomous(model, voc, dt, rescale=True, detrend=True, verbose=False) + recon = np.asarray(recon, dtype=np.float64) + # if non-finite, mark as such; otherwise we trim to the shorter length + recon_ok = np.isfinite(recon).all() + if recon_ok: + L = min(len(target), len(recon)) + target_p, recon_p = target[:L], recon[:L] + else: + target_p, recon_p = target, None + + plot_wave(axs[i, 0], target_p, sr, f"voc {i} target (waveform)") + plot_spec(axs[i, 1], target_p, sr, f"voc {i} target (spectrogram)") + plot_wave(axs[i, 2], recon_p, sr, f"voc {i} recon (waveform)") + plot_spec(axs[i, 3], recon_p, sr, f"voc {i} recon (spectrogram)") + + fig.suptitle(f"target vs autonomous recon — 
{os.path.basename(args.ckpt)} on {os.path.basename(test_dir)}", + fontsize=12) + fig.tight_layout(rect=[0, 0, 1, 0.97]) + fig.savefig(args.out, dpi=120) + print(f"wrote {args.out}") + + +if __name__ == "__main__": + main() diff --git a/scripts/resynth_recon.py b/scripts/resynth_recon.py new file mode 100644 index 0000000..c082f5c --- /dev/null +++ b/scripts/resynth_recon.py @@ -0,0 +1,78 @@ +"""Generate an autonomous-reconstruction WAV from a chosen checkpoint, using the +fixed (int16-normalized) val/test loader. Also writes the target voc as a WAV for +side-by-side listening. +""" + +import argparse +import glob +import os +import shutil +import tempfile + +import numpy as np +import torch +from scipy.io import wavfile + +from train.train import load_model +from train.eval import generate_autonomous, autonomy_score +from examples.run_lambda_pipeline import load_voc_windows + + +def load_specific(ckpt_path): + tmp = tempfile.mkdtemp(prefix="ckpt_synth_") + link = os.path.join(tmp, os.path.basename(ckpt_path)) + os.symlink(os.path.abspath(ckpt_path), link) + try: + model, _, _, _ = load_model(tmp) + finally: + shutil.rmtree(tmp, ignore_errors=True) + return model + + +def to_wav(x, sr, path): + x = np.asarray(x, dtype=np.float64) + peak = float(np.abs(x).max()) + if not np.isfinite(peak) or peak < 1e-12: + print(f" WARNING: non-finite or empty signal at {path}") + return + y = (x / peak * 0.95 * 32767).astype(np.int16) + wavfile.write(path, sr, y) + print(f" wrote {path} ({len(y)} samples, sr={sr}, peak_in={peak:.4e})") + + +def main(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--ckpt", required=True) + p.add_argument("--data-glob", required=True) + p.add_argument("--out-dir", required=True) + p.add_argument("--n-vocs", type=int, default=3) + p.add_argument("--start-offset-ms", type=float, default=50.0) + p.add_argument("--auto-n", type=int, default=3000) + args = p.parse_args() + + os.makedirs(args.out_dir, exist_ok=True) + model = 
load_specific(args.ckpt) + model.eval() + print(f"loaded {args.ckpt}") + + dirs = sorted(d for d in glob.glob(args.data_glob) if os.path.isdir(d)) + test_dir = dirs[-1] + print(f"test_dir={test_dir}") + vocs, sr = load_voc_windows(test_dir, args.n_vocs, args.start_offset_ms, args.auto_n) + dt = 1.0 / sr + print(f"sr={sr} n_vocs={len(vocs)} L={len(vocs[0]) if vocs else 0}") + + score, per_seg, bd = autonomy_score(model, vocs, dt, rescale=True, cold_start=False) + print(f"mean rescaled autonomy = {score:+.4f} (spec={bd['spec_corr']:.2f}, " + f"pitch_pen={bd['pitch_pen']:.2f}, bounded={bd['bounded_frac']:.2f})") + print(f"per-voc scores: {[f'{s:+.3f}' for s in per_seg]}") + + for i, voc in enumerate(vocs): + target = np.asarray(voc, dtype=np.float64) + recon = generate_autonomous(model, voc, dt, rescale=True, detrend=True, verbose=False) + to_wav(target, sr, os.path.join(args.out_dir, f"target_voc{i}.wav")) + to_wav(recon, sr, os.path.join(args.out_dir, f"recon_voc{i}.wav")) + + +if __name__ == "__main__": + main() diff --git a/scripts/stage_finch_blk445_syllC.py b/scripts/stage_finch_blk445_syllC.py new file mode 100644 index 0000000..f3ef25b --- /dev/null +++ b/scripts/stage_finch_blk445_syllC.py @@ -0,0 +1,83 @@ +"""Stage blk445 / syllable_C real-finch data for the Ouroboros training pipeline. + +The pipeline (data/load_data.py::get_segmented_audio + examples/run_lambda_pipeline.py) +expects each shard directory to contain paired `{stem}.wav` and `{stem}.txt` files +side by side. The real data lives in two split locations with mismatched suffixes: + + audio: ~/isilon/All_Staff/.../blk445/{day}/double_denoised/{stem}_cleaned.wav + annotations: ~/isilon/.../blk445/segs/{day}/syllable_C/{stem}.txt + +This script symlinks (no copies) non-empty annotations and their matching cleaned +wavs into ~/ouroboros_data/blk445_syllC/day{day}/ with `{stem}.wav` / `{stem}.txt` +filenames the pipeline can glob. Skips annotations whose corresponding wav is +missing. Idempotent. 
+""" + +import os +import sys + +DAYS = [84, 85, 86] +SRC_ROOT = os.path.expanduser( + "~/isilon/All_Staff/birds/mooney/muscimol/Microdialysis/Muscimol/blk445" +) +DST_ROOT = os.path.expanduser("~/ouroboros_data/blk445_syllC") + + +def stage_day(day): + seg_dir = os.path.join(SRC_ROOT, "segs", str(day), "syllable_C") + wav_dir = os.path.join(SRC_ROOT, str(day), "double_denoised") + dst_dir = os.path.join(DST_ROOT, f"day{day}") + os.makedirs(dst_dir, exist_ok=True) + + if not os.path.isdir(seg_dir): + print(f"day{day}: missing seg_dir {seg_dir}", file=sys.stderr) + return 0, 0 + if not os.path.isdir(wav_dir): + print(f"day{day}: missing wav_dir {wav_dir}", file=sys.stderr) + return 0, 0 + + staged = 0 + skipped_empty = 0 + skipped_no_wav = 0 + for fn in sorted(os.listdir(seg_dir)): + if not fn.endswith(".txt"): + continue + seg_path = os.path.join(seg_dir, fn) + if os.path.getsize(seg_path) == 0: + skipped_empty += 1 + continue + stem = fn[:-len(".txt")] + wav_src = os.path.join(wav_dir, f"{stem}_cleaned.wav") + if not os.path.isfile(wav_src): + skipped_no_wav += 1 + continue + + wav_link = os.path.join(dst_dir, f"{stem}.wav") + txt_link = os.path.join(dst_dir, f"{stem}.txt") + for link, target in [(wav_link, wav_src), (txt_link, seg_path)]: + if os.path.lexists(link): + if os.readlink(link) == target: + continue + os.remove(link) + os.symlink(target, link) + staged += 1 + + print( + f"day{day}: staged={staged} (skipped {skipped_empty} empty annotations, " + f"{skipped_no_wav} with missing wav) -> {dst_dir}", + flush=True, + ) + return staged, skipped_no_wav + + +def main(): + os.makedirs(DST_ROOT, exist_ok=True) + total_staged = 0 + for day in DAYS: + n, _ = stage_day(day) + total_staged += n + print(f"TOTAL staged pairs: {total_staged}", flush=True) + + +if __name__ == "__main__": + main() From 22d24f7e95b55e846a4187a5a06cad1c6646f77f Mon Sep 17 00:00:00 2001 From: John Pearson Date: Thu, 4 Jun 2026 15:24:58 -0400 Subject: [PATCH 15/15] gitignore: finch_*/ output 
dirs Catches finch_blk445_syllC_arneodo_pool/ and any future per-bird/syllable run directories, matching the existing pattern used for arneodo_*/poly_*/. Co-Authored-By: Claude Opus 4.7 --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index d869db3..416b628 100644 --- a/.gitignore +++ b/.gitignore @@ -187,3 +187,4 @@ poly_pipeline_full/ poly_seedcull/ poly_smallk_*/ poly_smallk_logs/ +finch_*/