From eee73d8470c6ed497a693e625bf049d7b254d66b Mon Sep 17 00:00:00 2001 From: Lorenzo Mangani Date: Thu, 18 Jun 2026 19:56:16 +0200 Subject: [PATCH 1/7] Add TRAIN.md research plan for /train LoRA experimentation UI. Documents ltx-2-mlx slice/preprocess/train pipeline, upstream APIs, RAM constraints, and phased implementation for a Web UI training lab. Co-authored-by: Cursor --- TRAIN.md | 261 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 261 insertions(+) create mode 100644 TRAIN.md diff --git a/TRAIN.md b/TRAIN.md new file mode 100644 index 0000000..d507b1e --- /dev/null +++ b/TRAIN.md @@ -0,0 +1,261 @@ +# Training UI plan (`/train`) + +Research-backed plan for a **ltx-ws** training lab on top of [ltx-2-mlx](https://github.com/dgrauet/ltx-2-mlx) **v0.14.12** (`ltx-trainer-mlx`). + +## What “training” means upstream + +ltx-2-mlx is a three-package monorepo: + +| Package | Role | +|---------|------| +| `ltx-core-mlx` | Model weights, VAE, DiT, Gemma connectors | +| `ltx-pipelines-mlx` | Inference CLI (`generate`, `retake`, …) | +| `ltx-trainer-mlx` | **LoRA / full fine-tune** via flow matching | + +Training is **not** online learning during inference. It is an offline pipeline: + +``` +raw videos → [slice] → clips + captions → [preprocess] → latents + conditions → [train] → LoRA .safetensors +``` + +### CLI entry points (from `ltx_pipelines_mlx/cli.py`) + +All require optional package install: + +```bash +uv pip install \ + "ltx-trainer-mlx @ git+https://github.com/dgrauet/ltx-2-mlx.git@v0.14.12#subdirectory=packages/ltx-trainer" +``` + +| Command | Python API | Purpose | +|---------|------------|---------| +| `ltx-2-mlx slice` | `ltx_trainer_mlx.slice_clips.slice_videos` | Cut long sources into fixed-length, 32-aligned clips (ffmpeg; audio retained) | +| `ltx-2-mlx preprocess` | `ltx_trainer_mlx.preprocess.preprocess_dataset` | Encode clips → `.precomputed/latents/`, `conditions/`, optional `audio_latents/` | +| `ltx-2-mlx train` | `LtxvTrainer(config).train()` | Flow-matching LoRA (or full) training from preprocessed data | + +### Training strategies (validated in `tests/test_trainer_core.py`) + +| Strategy | Config `training_strategy.name` | Notes | +|----------|----------------------------------|-------| +| Text-to-video LoRA | `text_to_video` | Default; `generate_audio: false` for video-only style | +| Joint AV LoRA | `text_to_video` + `generate_audio: true` | Needs `preprocess --with-audio`; v0.14.12 audio path | +| Video-to-video (IC-LoRA) | `video_to_video` | Requires reference latents in preprocessed data; LoRA only | + +Example configs ship in upstream `packages/ltx-trainer/configs/`: + +- `lora_t2v.yaml` — basic T2V style LoRA +- `lora_v2v.yaml` — IC-LoRA / reference-video conditioning +- `lora_av_whisper.yaml` — joint audio+video (whisper/ASMR); uses `transformer-dev.safetensors`, gradient checkpointing + +### Preprocessed data layout + +``` +/ + .precomputed/ + latents/latent_0000.safetensors # video VAE latents + dims/fps + conditions/condition_0000.safetensors # Gemma prompt embeds + audio_latents/latent_0000.safetensors # optional; paired filenames +``` + +Captions: sibling `.txt` per clip (or `--captions` dir with matching stems). + +### Training runtime characteristics + +- **Single-device MLX** on Apple Silicon (unified memory); no DDP. +- **Heavy RAM**: dev transformer + Gemma + activations. `enable_gradient_checkpointing` / CLI `--low-ram` needed on ≤64 GB for dev-base LoRAs. +- **Long-running**: thousands of steps; checkpoints + validation renders on interval. +- **Outputs**: `output_dir/` with checkpoints (`.safetensors`), validation MP4s, saved YAML config. +- **Progress hook**: `LtxvTrainer.train(step_callback=fn)` — `(global_step, total_steps, validation_paths)`. +- **Conflicts with inference**: training and generation both want GPU/RAM; must not run concurrently with `server.py` generation lock. + +### Hardware guidance (from upstream configs + changelog) + +| Workflow | Typical RAM | Resolution / frames | +|----------|-------------|-------------------| +| T2V LoRA (distilled base) | 32–48 GB | 704×480 × 25 frames validation | +| AV style LoRA (dev base + checkpointing) | 64 GB | 192×192 × 97 frames | +| Preprocess only | ~16 GB peak | Encoder + Gemma partial download (v0.14.12) | + +Frame counts must stay **8k+1**; spatial dims **÷32**; training fps should stay near **24** (LTX training distribution). + +--- + +## Gap in ltx-ws today + +- Inference stack only: `ltx_mlx_backend.py`, `/api/generate`, main React UI. +- LoRA **inference** presets exist; no slice/preprocess/train orchestration. +- `ltx-trainer-mlx` not in `requirements.txt` (optional extra). +- Single generation worker; no training job queue. + +--- + +## `/train` page — product goals + +**Experimentation lab**, not a full MLOps platform: + +1. Prepare a small dataset (upload or point at folder). +2. Run preprocess with sensible defaults. +3. Configure and launch a LoRA run (T2V first; AV/V2V later). +4. Watch step progress + validation previews. +5. Register finished LoRA into existing Web UI preset list for inference smoke tests. + +--- + +## Proposed architecture + +### Frontend (`web/`) + +Add client routing (e.g. `react-router-dom`): + +| Route | Page | +|-------|------| +| `/` | Existing generator (`App.tsx` → `GeneratePage`) | +| `/train` | New `TrainPage.tsx` | + +`TrainPage` sections (wizard or tabs): + +1. **Dataset** — upload videos + caption `.txt`; or path to local folder; optional slice settings (interval, res, caption template). +2. **Preprocess** — model id, H×W, max frames, `with_audio`, frame rate; start + progress. +3. **Train** — preset picker (T2V / AV / V2V), hyperparams (rank, steps, LR, checkpoint/val intervals), validation prompts; advanced YAML toggle. +4. **Runs** — list jobs, live step/loss, validation video thumbnails, cancel, download LoRA, **“Add to LoRA library”**. + +Nav link in header: `Generate` | `Train`. + +Vite: `historyApiFallback` already handled by FastAPI `html=True` static mount. + +### Backend (`web_ui.py` + new `ltx_train_backend.py`) + +Optional dependency gate: if `ltx_trainer_mlx` missing, `/api/train/health` returns `ok: false` with install hint. + +``` +web_outputs/ + train/ + / + uploads/ # raw uploads + clips/ # post-slice + preprocessed/ # .precomputed + outputs/ # trainer output_dir + config.yaml # resolved trainer config + status.json # phase, step, logs +``` + +**API sketch:** + +| Method | Path | Purpose | +|--------|------|---------| +| GET | `/api/train/health` | Trainer package installed? ffmpeg? active MLX model path? | +| GET | `/api/train/presets` | Built-in config templates (T2V / AV / V2V) | +| POST | `/api/train/datasets` | Create dataset job dir; accept multipart uploads | +| POST | `/api/train/slice` | `{ dataset_id, interval, res, ... }` | +| POST | `/api/train/preprocess` | `{ dataset_id, model, height, width, with_audio, ... }` | +| POST | `/api/train/runs` | Build `LtxTrainerConfig`, start training | +| GET | `/api/train/runs` | List runs | +| GET | `/api/train/runs/{id}` | Status + stats | +| GET | `/api/train/runs/{id}/events` | SSE: step, loss, validation paths (mirror `/api/runs/{id}/events`) | +| POST | `/api/train/runs/{id}/cancel` | Cooperative cancel | +| GET | `/api/train/runs/{id}/artifacts/{path}` | Validation MP4s, final LoRA | +| POST | `/api/train/runs/{id}/register-lora` | Copy LoRA into `web_outputs/loras` + preset entry | + +**Worker model:** + +- Separate `TrainingWorker` thread pool (`max_workers=1`), analogous to generation executor. +- **Global mutex** with generation: starting train rejects if `server` busy generating (and vice versa) — surface clear UI message. +- Long steps run in `asyncio.to_thread()` / dedicated thread; `step_callback` pushes to asyncio queue for SSE. + +**Config builder:** + +- Start from upstream YAML templates embedded or loaded from repo `train_configs/`. +- Override: `model.model_path` ← resolved from Web UI active MLX model snapshot. +- Override: `data.preprocessed_data_root`, `output_dir`, `optimization.steps`, `lora.rank`, validation prompts. +- Validate via `LtxTrainerConfig.model_validate()` before start. + +### Dependency update + +Add commented optional install in `requirements.txt`: + +```bash +"ltx-trainer-mlx @ git+https://github.com/dgrauet/ltx-2-mlx.git@v0.14.12#subdirectory=packages/ltx-trainer" +``` + +Pin tag to `LTX2_MLX_GIT_TAG` in `ltx_mlx_backend.py`. + +--- + +## Implementation phases + +### Phase 1 — Foundation (MVP) + +- [ ] `training` branch: router + empty `/train` shell + nav +- [ ] `ltx_train_backend.py`: healthcheck, config templates, subprocess/thread wrapper for **preprocess only** +- [ ] Dataset upload API + folder layout +- [ ] Train page: upload videos + captions → preprocess button → log/progress panel +- [ ] Docs in README: optional trainer install + +**Exit criteria:** User can preprocess clips from Web UI; no training yet. + +### Phase 2 — T2V LoRA training + +- [ ] `POST /api/train/runs` wrapping `LtxvTrainer.train(step_callback=...)` +- [ ] SSE progress (step / total / ETA from `TrainingStats`) +- [ ] Cancel flag checked between steps +- [ ] Validation MP4 serving from `output_dir` +- [ ] **Register LoRA** → existing `/api/loras/custom` flow + +**Exit criteria:** End-to-end T2V LoRA on a toy dataset (≥2 clips); use in generator. + +### Phase 3 — Slice + AV presets + +- [ ] Slice API (ffmpeg dependency check) +- [ ] `with_audio` preprocess toggle +- [ ] Presets: `lora_av_whisper` simplified form (audio-only target modules hidden behind preset) +- [ ] RAM warning banners (`--low-ram` → `enable_gradient_checkpointing`) + +### Phase 4 — V2V / IC-LoRA training + +- [ ] Reference video upload + reference latent preprocess path +- [ ] `video_to_video` strategy UI +- [ ] Validation with `reference_videos` + +### Phase 5 — Polish + +- [ ] Resume from checkpoint (`model.load_checkpoint`) +- [ ] W&B optional (`wandb` extra) +- [ ] MCP tool `ltx_train_lora` for agents (optional) + +--- + +## Risks and constraints + +| Risk | Mitigation | +|------|------------| +| OOM during train | Default to q8 model path; expose checkpointing; block train if free RAM estimate low | +| Train + generate concurrent | Global `mlx_busy` lock shared with `LocalVideoGenerator` | +| Preprocess partial HF download | Use same resolved `model_path` as inference (full snapshot already cached) | +| V2V reference latents | Defer to Phase 4; document manual preprocess steps until automated | +| Long jobs lost on server restart | Persist `status.json`; optional resume; warn user | +| Alpha trainer API | Pin v0.14.12; thin adapter layer in `ltx_train_backend.py` | + +--- + +## Testing strategy + +1. **Unit** (no MLX): config builder, path layout, `_sync` job state, YAML merge. +2. **Integration** (MLX machine): preprocess 2 clips → 50-step T2V LoRA → load in generator. +3. **Manual**: `/train` SSE progress, cancel mid-run, register LoRA preset. + +Upstream tests to mirror behavior: `tests/test_trainer_core.py`, `tests/test_trainer_datasets.py`. + +--- + +## Open questions (decide before Phase 2) + +1. **Default base weights for training** — always `transformer-dev.safetensors` (inference-compatible CFG pipelines) vs distilled (faster but different inference path)? +2. **Dataset size limits** — cap uploads (e.g. 2 GB / 50 clips) for Web UI? +3. **Separate process** — run trainer in child process for crash isolation vs in-process thread? +4. **Standalone `web_server.py`** — should `/train` work without full `server.py` WS stack? (Recommend: yes, train-only via FastAPI.) + +--- + +## Immediate next step + +Implement **Phase 1** on `training` branch: routing, `ltx_train_backend.py` skeleton, preprocess job API, minimal `/train` UI. From 2a5cccfada13b5d0efcf4b31ef3d3bbf90831239 Mon Sep 17 00:00:00 2001 From: Lorenzo Mangani Date: Thu, 18 Jun 2026 20:00:24 +0200 Subject: [PATCH 2/7] Expand TRAIN.md with input requirements and SSE job wiring plan. Documents slice/preprocess/train inputs, preset matrix, TrainJob phases, and how to mirror generation SSE for long-running background training. Co-authored-by: Cursor --- TRAIN.md | 299 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 299 insertions(+) diff --git a/TRAIN.md b/TRAIN.md index d507b1e..5274631 100644 --- a/TRAIN.md +++ b/TRAIN.md @@ -259,3 +259,302 @@ Upstream tests to mirror behavior: `tests/test_trainer_core.py`, `tests/test_tra ## Immediate next step Implement **Phase 1** on `training` branch: routing, `ltx_train_backend.py` skeleton, preprocess job API, minimal `/train` UI. + +--- + +## Training inputs (what ltx-2-mlx accepts & requires) + +Training is three optional/required stages. **Train** only consumes **preprocessed** data; everything before that is dataset prep. + +### Stage A — Slice (optional) + +**API:** `ltx_trainer_mlx.slice_clips.slice_videos` · **Requires:** `ffmpeg` on PATH + +| Input | Required | Notes | +|-------|----------|-------| +| `sources` | yes | One or more video files or directories | +| `out_dir` | yes | Per-source subfolders of clips | +| `interval` | no (default 4s) | Clip length; ignored if `timecodes_file` set | +| `timecodes_file` | no | `start,end` per line | +| `res` | no (default `384x384`) | `WxH`, both **÷32** | +| `fps` | no (default 24) | Output fps | +| `fit` | no | `crop` or `pad` | +| `min_length` | no | Drop clips shorter than N seconds | +| `max_clips` / `sample` | no | Cap + even/sequential sampling | +| `skip_start` / `skip_end` | no | Trim intros/outros | +| `caption_template` | no | Writes identical `.txt` beside each clip | +| `crf` | no | x264 quality | + +**Outputs:** `clip_XXX.mp4` + optional `clip_XXX.txt` (caption seed for editing). + +--- + +### Stage B — Preprocess (required before train) + +**API:** `ltx_trainer_mlx.preprocess.preprocess_dataset` · **Requires:** local MLX model dir, Gemma (HF id) + +| Input | Required | Notes | +|-------|----------|-------| +| `videos_dir` | yes | `.mp4/.mov/.avi/.mkv/.webm`; recursive (slice subfolders OK) | +| `output_dir` | yes | Creates `output_dir/.precomputed/` | +| `model_dir` | yes | **Local path** to MLX snapshot (encoders only; partial HF download OK in v0.14.12) | +| `gemma_model_id` | no | Default `mlx-community/gemma-3-12b-it-4bit` | +| `target_height` / `target_width` | no | **÷32**; default = native per clip | +| `max_frames` | no | Default 97; must be **8k+1** | +| `captions_dir` | no | `.txt` per video stem; else **filename stem** used as prompt | +| `caption_ext` | no | Default `.txt` | +| `with_audio` | no | Adds `audio_latents/`; **required** if training with `generate_audio: true` | +| `frame_rate` | no | Written into latent metadata; default = probed fps | + +**Outputs:** + +``` +/.precomputed/ + latents/latent_0000.safetensors # video VAE latent + dims/fps metadata + conditions/condition_0000.safetensors # Gemma video+audio prompt embeds + audio_latents/latent_0000.safetensors # optional; same index as video latent +``` + +**V2V add-on (Phase 4):** IC-LoRA also needs `reference_latents/latent_XXXX.safetensors` paired by index (separate encode pass — not in basic `preprocess_dataset` today; manual or custom script per upstream `lora_v2v.yaml`). + +--- + +### Stage C — Train (LoRA / full) + +**API:** `LtxvTrainer(LtxTrainerConfig).train(step_callback=…)` · **Requires:** `ltx-trainer-mlx`, preprocessed data, **local** `model.model_path` + +#### Hard requirements (`LtxTrainerConfig` validation) + +| Field | Required | Notes | +|-------|----------|-------| +| `model.model_path` | yes | Existing **local** directory (not URL) | +| `model.text_encoder_path` | yes* | Gemma id/path; *skipped if no validation prompts | +| `model.training_mode` | yes | `lora` (UI default) or `full` | +| `lora` block | yes if `lora` mode | `rank`, `alpha`, `dropout`, `target_modules` | +| `data.preprocessed_data_root` | yes | Path to dataset root (parent of `.precomputed`) | +| `training_strategy.name` | yes | `text_to_video` or `video_to_video` | +| `optimization.steps` | yes | Often 1000–3000+ | +| `output_dir` | yes | Checkpoints + validation MP4s | + +#### Common optional / preset fields + +| Field | Purpose | +|-------|---------| +| `model.transformer_file` | e.g. `transformer-dev.safetensors` (AV/style LoRAs) | +| `model.load_checkpoint` | Resume from prior checkpoint dir/file | +| `training_strategy.generate_audio` | Joint AV training (needs audio latents) | +| `optimization.enable_gradient_checkpointing` | **Required** on 64 GB for dev-base; maps to CLI `--low-ram` | +| `optimization.batch_size`, `gradient_accumulation_steps`, `learning_rate`, schedulers | Standard training knobs | +| `validation.*` | Prompts, `video_dims` (W,H,F), `interval`, `inference_steps`, `reference_videos` (V2V) | +| `checkpoints.interval` / `keep_last_n` | Intermediate `.safetensors` | +| `flow_matching.timestep_sampling_mode` | Default `shifted_logit_normal` | +| `seed` | Reproducibility | +| `wandb.*` / `hub.*` | Off by default in UI | + +#### Strategy matrix (what we wire first) + +| Preset | `training_strategy` | Preprocess | `transformer_file` | RAM hint | +|--------|---------------------|------------|----------------------|----------| +| **T2V style** | `text_to_video`, `generate_audio: false` | standard | auto (distilled OK) | 32–48 GB | +| **AV style** | `text_to_video`, `generate_audio: true` | `--with-audio` | `transformer-dev.safetensors` | 64 GB + checkpointing | +| **IC-LoRA V2V** | `video_to_video`, LoRA only | + `reference_latents/` | LoRA | defer Phase 4 | + +#### Trainer outputs + +- `output_dir/checkpoint-XXXX.safetensors` (LoRA weights) +- `output_dir/validation_step_XXXX_*.mp4` (when `validation.interval` set) +- `output_dir/config.yaml` (resolved config copy) +- Final return: `(saved_path: Path, TrainingStats)` — steps/sec, peak GB, total time + +#### Trainer progress hooks (for our adapter) + +| Hook | Data available | +|------|----------------| +| `step_callback(global_step, total_steps, validation_paths)` | Step index, validation MP4 paths after val steps | +| `TrainingProgress.update_training` | `loss`, `lr`, `step_time` (internal — we patch or subclass to expose) | +| `disable_progress_bars=True` | Logs loss every 5 steps to logger (fallback) | + +**No built-in cancel** — cooperative cancel via `step_callback` raising `TrainingCancelledError` between steps. + +--- + +## What we wire in ltx-ws (scope by preset) + +### Phase 1–2 UI fields → upstream mapping + +| UI control | Maps to | +|------------|---------| +| Upload videos + caption files | `videos_dir` (+ optional `captions_dir`) | +| “Slice first” toggle + interval/res/fps/template | `slice_videos(...)` → `clips/` | +| Model picker | `model_dir` = same resolved snapshot as inference (`state.active_model`) | +| Resolution / max frames / with audio | `preprocess_dataset(...)` | +| Preset: T2V / AV | Load embedded YAML template → override paths & steps | +| Rank, steps, LR, val interval, val prompts | `LtxTrainerConfig` overrides | +| Low RAM toggle | `optimization.enable_gradient_checkpointing=true` | +| Run name | `output_dir` subfolder + preset label | + +### Phase 4 additions (V2V) + +| UI control | Maps to | +|------------|---------| +| Reference video per target clip | `reference_latents/` preprocess + `validation.reference_videos` | +| `reference_downscale_factor` | validation config | + +### Out of scope for v1 UI (CLI / advanced YAML only) + +- `training_mode: full` (full fine-tune) +- W&B / Hub push +- Custom `target_modules` (expose in “Advanced YAML” panel later) +- Timecode-list slicing + +--- + +## Long-running background jobs & client updates + +Reuse the **generation run pattern** in `web_ui.py` — it already solves queueing, SSE, cancel, and persistence. Training jobs are longer and multi-phase but fit the same model. + +### Job model: `TrainJob` (extends run concepts) + +One **`job_id`** spans all phases (not three separate IDs): + +```text +phase: queued → slicing → preprocessing → training → done | failed | cancelled +``` + +Persisted to `web_outputs/train//status.json` (+ index in `settings.json`). + +```json +{ + "job_id": "...", + "phase": "training", + "preset": "t2v", + "created_at": "...", + "step": 420, + "total_steps": 3000, + "loss": 0.0842, + "lr": 0.00035, + "eta_s": 3600, + "peak_memory_gb": 28.4, + "validation_clips": [{"step": 400, "url": "/api/train/jobs/.../validation/400_0.mp4"}], + "artifact_lora": "/api/train/jobs/.../lora.safetensors", + "error": null +} +``` + +### Worker architecture + +``` +┌─────────────────────────────────────────────────────────┐ +│ FastAPI (web_ui) │ +│ POST /api/train/jobs → enqueue job_id │ +│ GET /api/train/jobs/{id}/events → SSE (EventSource) │ +└───────────────────────┬─────────────────────────────────┘ + │ + ┌───────────────▼───────────────┐ + │ _train_worker_loop (async) │ ← mirror _worker_loop + │ asyncio.Queue[job_id] │ + └───────────────┬───────────────┘ + │ asyncio.to_thread() + ┌───────────────▼───────────────┐ + │ ltx_train_backend.py │ + │ · slice_videos (ffmpeg) │ + │ · preprocess_dataset (MLX) │ + │ · LtxvTrainer.train (MLX) │ + └───────────────────────────────┘ +``` + +**MLX exclusivity:** shared `AppState.mlx_busy: asyncio.Lock` — training and generation cannot overlap (same as today’s single gen executor). `POST /api/generate` returns 409 if train active; `POST /api/train/jobs` returns 409 if gen active. + +**Threading:** MLX training blocks the GIL/Metal for minutes–hours; run entire `slice` / `preprocess` / `train` in **`asyncio.to_thread()`** (or dedicated `ThreadPoolExecutor(max_workers=1)`), same as LoRA downloads. Main asyncio loop stays responsive for SSE pings. + +### SSE event schema (mirror `/api/runs/{id}/events`) + +Client uses **`EventSource`** on `/api/train/jobs/{job_id}/events` (same as `subscribeRun` in `App.tsx`). Optional later: WS `train_progress` for raw `server.py` clients. + +| Event `type` | When | Payload | +|--------------|------|---------| +| `job_started` | Job dequeued | `job_id`, `preset`, `phases` | +| `phase_started` | slice / preprocess / train begin | `phase`, `message` | +| `phase_progress` | preprocess clip N/M | `phase`, `current`, `total`, `message` | +| `train_step` | each optim step | `step`, `total`, `loss`, `lr`, `step_time_s`, `eta_s`, `peak_memory_gb` | +| `train_validation` | val interval | `step`, `videos: [{url, prompt}]` | +| `train_checkpoint` | checkpoint saved | `step`, `path` | +| `ping` | 120s idle | `{}` | +| `job_done` | success | `artifact_lora`, `stats`, `register_lora_url` | +| `error` | failure | `message`, `phase` | +| `job_complete` | always (finally) | `job_id`, `status` | + +**Loss streaming:** wrap `TrainingProgress.update_training` in `ltx_train_backend.py` to push `loss`/`lr` into a thread-safe queue drained by the training thread’s `step_callback`. Avoid duplicating the 200-line train loop. + +**Cancel:** `POST /api/train/jobs/{id}/cancel` sets `job.cancelled=True`; `step_callback` checks flag and raises `TrainingCancelledError` → `phase: cancelled`, emit `job_complete`. + +**Reconnect:** SSE handler replays `status.json` snapshot then attaches to live queue (same pattern as completed runs in `run_events`). + +### Frontend (`/train`) + +- `subscribeTrainJob(jobId)` — clone of `subscribeRun` with `train_step` / `train_validation` handlers +- Progress bar: reuse `formatProgressMessage` / `formatMmSs` from `progress.ts` +- Phase stepper: Slice → Preprocess → Train +- Validation gallery: thumbnails from `train_validation` events +- **Background-friendly:** user can navigate away; job continues; reconnect via job list + SSE +- Header badge: “Training step 420/3000” when job active (poll `/api/train/jobs/active` or keep SSE open globally) + +### WebSocket (optional Phase 2b) + +For `server.py` WS clients (videofentanyl), add message types parallel to generation: + +- `train_job_status` — polled or pushed during training +- Not required for Web UI (SSE is enough and already works through Vite proxy) + +--- + +## Minimal API surface (revised) + +| Method | Path | Purpose | +|--------|------|---------| +| GET | `/api/train/health` | `ltx_trainer_mlx` installed, ffmpeg, model path resolved | +| GET | `/api/train/presets` | T2V / AV templates + field metadata | +| POST | `/api/train/jobs` | Create job: uploads refs OR multipart in same request | +| GET | `/api/train/jobs` | List jobs (active + history) | +| GET | `/api/train/jobs/{id}` | `status.json` snapshot | +| GET | `/api/train/jobs/{id}/events` | **SSE** progress stream | +| POST | `/api/train/jobs/{id}/cancel` | Cooperative cancel | +| GET | `/api/train/jobs/{id}/artifacts/{name}` | LoRA, validation MP4s | +| POST | `/api/train/jobs/{id}/register-lora` | → existing custom LoRA preset | + +Single **`POST /api/train/jobs`** body (Phase 2): + +```json +{ + "preset": "t2v", + "name": "my_style_lora", + "slice": { "enabled": false }, + "preprocess": { "width": 704, "height": 480, "max_frames": 97, "with_audio": false, "frame_rate": 24 }, + "train": { "steps": 2000, "rank": 64, "learning_rate": 5e-4, "validation_prompts": ["..."], "validation_interval": 500, "checkpoint_interval": 500, "low_ram": false }, + "video_paths": ["uploaded-id-1", "uploaded-id-2"], + "caption_paths": ["uploaded-id-1.txt"] +} +``` + +--- + +## Revised implementation phases + +### Phase 1 — Job shell + preprocess +- `TrainJob` + worker queue + SSE skeleton +- Multipart upload → `web_outputs/train//raw/` +- Preprocess phase only; `phase_progress` events + +### Phase 2 — T2V training end-to-end +- `LtxvTrainer` wrapper + `train_step` / `train_validation` SSE +- Cancel + `status.json` persistence +- Register LoRA → inference presets +- MLX lock vs generation + +### Phase 3 — Slice + AV preset +- Slice phase in job pipeline +- `with_audio` + `lora_av` simplified preset + +### Phase 4 — V2V +- Reference latent preprocess + validation reference videos + From e7cbf9f98511189ef4b9e4aea7de39057978d806 Mon Sep 17 00:00:00 2001 From: Lorenzo Mangani Date: Thu, 18 Jun 2026 20:08:49 +0200 Subject: [PATCH 3/7] Use durable local paths for training; prefer HF hub cache for weights. Drop /tmp placeholders from train presets; inject paths under web_outputs/train/. Resolve MLX weights from local dirs, VIDEOFENTANYL_MODELS, or HF hub cache before downloading. Copy finished LoRAs to repo loras/ or VIDEOFENTANYL_LORA_DIR. Co-authored-by: Cursor --- TRAIN.md | 10 + ltx_mlx_backend.py | 48 ++++ ltx_train_backend.py | 502 ++++++++++++++++++++++++++++++++++++ requirements.txt | 4 + train_configs/lora_av.yaml | 53 ++++ train_configs/lora_t2v.yaml | 56 ++++ 6 files changed, 673 insertions(+) create mode 100644 ltx_train_backend.py create mode 100644 train_configs/lora_av.yaml create mode 100644 train_configs/lora_t2v.yaml diff --git a/TRAIN.md b/TRAIN.md index 5274631..a54e2a0 100644 --- a/TRAIN.md +++ b/TRAIN.md @@ -423,6 +423,16 @@ phase: queued → slicing → preprocessing → training → done | failed | can Persisted to `web_outputs/train//status.json` (+ index in `settings.json`). +**Storage policy (no `/tmp`):** + +| Asset | Location | +|-------|----------| +| Uploads, clips, preprocessed latents, checkpoints, validation MP4s | `/train//` | +| Base MLX weights (preprocess + train) | Local path, `$VIDEOFENTANYL_MODELS` / `/models/`, or existing **HF hub cache** (`HF_HOME` / `~/.cache/huggingface`) via `resolve_mlx_weights_directory` | +| Finished LoRA for inference | Copied to `$VIDEOFENTANYL_LORA_DIR` or `/loras/` when registered | + +Preset YAML files under `train_configs/` hold **hyperparameters only** — paths are injected at job start. + ```json { "job_id": "...", diff --git a/ltx_mlx_backend.py b/ltx_mlx_backend.py index 06d0ec9..b038533 100644 --- a/ltx_mlx_backend.py +++ b/ltx_mlx_backend.py @@ -141,6 +141,50 @@ def _model_snapshot_present(dest: Path) -> bool: return bool(has_config and has_weights) +def _hf_hub_cache_roots() -> list[Path]: + """Candidate Hugging Face hub cache roots (``HF_HOME``, ``HUGGINGFACE_HUB_CACHE``, defaults).""" + candidates: list[Path] = [] + hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE", "").strip() + if hub_cache: + candidates.append(Path(hub_cache).expanduser().resolve()) + hf_home = os.environ.get("HF_HOME", "").strip() + if hf_home: + candidates.append(Path(hf_home).expanduser().resolve()) + xdg = os.environ.get("XDG_CACHE_HOME", "").strip() + if xdg: + candidates.append((Path(xdg).expanduser() / "huggingface").resolve()) + candidates.append((Path.home() / ".cache" / "huggingface").resolve()) + seen: set[Path] = set() + unique: list[Path] = [] + for path in candidates: + if path not in seen: + seen.add(path) + unique.append(path) + return unique + + +def find_hf_hub_snapshot(repo_id: str) -> Path | None: + """Return the newest materialized weights tree under the HF hub cache, if any.""" + slug = repo_id.strip().replace("/", "--") + best: Path | None = None + best_mtime = -1.0 + for cache_root in _hf_hub_cache_roots(): + snaps_dir = cache_root / "hub" / f"models--{slug}" / "snapshots" + if not snaps_dir.is_dir(): + continue + try: + for snap in snaps_dir.iterdir(): + if not snap.is_dir() or not _model_snapshot_present(snap): + continue + mtime = snap.stat().st_mtime + if mtime >= best_mtime: + best_mtime = mtime + best = snap.resolve() + except OSError: + continue + return best + + def hf_local_weights_directory(repo_id: str, explicit_model_dir: str | None) -> Path: """ Directory where we store a full ``snapshot_download`` for ``repo_id``. @@ -253,6 +297,10 @@ def resolve_mlx_weights_directory(model: str, explicit_model_dir: str | None) -> "Install with: pip install huggingface_hub\n" "Or use a local directory for --model." ) from e + hub_snap = find_hf_hub_snapshot(raw) + if hub_snap is not None: + log.info("Using Hugging Face hub cache snapshot for %r at %s", raw, hub_snap) + return str(hub_snap) dest = hf_local_weights_directory(raw, explicit_model_dir) dest.mkdir(parents=True, exist_ok=True) if _model_snapshot_present(dest): diff --git a/ltx_train_backend.py b/ltx_train_backend.py new file mode 100644 index 0000000..04d1d81 --- /dev/null +++ b/ltx_train_backend.py @@ -0,0 +1,502 @@ +# SPDX-License-Identifier: Apache-2.0 +"""MLX LoRA training adapter for ltx-trainer-mlx (optional dependency). + +Storage policy (no ``/tmp``): +- Per-job artifacts live under ``/train//`` (raw uploads, clips, + preprocessed latents, checkpoints, validation MP4s). +- Base MLX weights resolve via :func:`resolve_mlx_weights_directory` — explicit local + path, ``$VIDEOFENTANYL_MODELS`` / ``/models/``, or an existing Hugging Face + hub cache snapshot before downloading. +- Finished LoRAs copied for inference via :func:`register_trained_lora` into + ``$VIDEOFENTANYL_LORA_DIR`` or ``/loras/``. +""" + +from __future__ import annotations + +import logging +import re +import shutil +import time +from contextlib import contextmanager +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable + +import yaml + +from ltx_mlx_backend import ( + LTX2_MLX_GIT_TAG, + _local_lora_cache_dir, + _nearest_valid_frames, + resolve_mlx_weights_directory, +) + +log = logging.getLogger("ltx_train") + +REPO_ROOT = Path(__file__).resolve().parent +TRAIN_CONFIGS_DIR = REPO_ROOT / "train_configs" +DEFAULT_GEMMA = "mlx-community/gemma-3-12b-it-4bit" + +TRAINER_INSTALL_HINT = ( + f'uv pip install "ltx-trainer-mlx @ git+https://github.com/dgrauet/ltx-2-mlx.git@{LTX2_MLX_GIT_TAG}' + f'#subdirectory=packages/ltx-trainer"' +) + + +class TrainingCancelledError(Exception): + """Raised when a cooperative training cancel is requested.""" + + +@dataclass +class TrainPresetInfo: + id: str + label: str + description: str + ram_hint: str + with_audio: bool + low_ram_default: bool + + +TRAIN_PRESETS: dict[str, TrainPresetInfo] = { + "t2v": TrainPresetInfo( + id="t2v", + label="Text-to-video style", + description="Video-only LoRA on the default distilled/dev stack.", + ram_hint="32–48 GB unified memory", + with_audio=False, + low_ram_default=False, + ), + "av": TrainPresetInfo( + id="av", + label="Audio + video style", + description="Joint AV LoRA (whisper/ASMR-style); dev transformer + checkpointing.", + ram_hint="64 GB recommended", + with_audio=True, + low_ram_default=True, + ), +} + + +@dataclass +class SliceOptions: + enabled: bool = False + interval: float = 4.0 + res: str = "384x384" + fps: float = 24.0 + fit: str = "crop" + caption_template: str | None = None + max_clips: int | None = None + + +@dataclass +class PreprocessOptions: + width: int | None = 704 + height: int | None = 480 + max_frames: int = 97 + with_audio: bool = False + frame_rate: float | None = 24.0 + + +@dataclass +class TrainHyperparams: + steps: int = 2000 + rank: int = 64 + learning_rate: float = 5e-4 + validation_prompts: list[str] = field(default_factory=lambda: ["a cinematic landscape at sunset"]) + validation_interval: int = 500 + checkpoint_interval: int = 500 + low_ram: bool = False + seed: int = 42 + + +@dataclass +class TrainJobRequest: + preset: str + name: str + model_id: str + model_dir: str | None + slice: SliceOptions = field(default_factory=SliceOptions) + preprocess: PreprocessOptions = field(default_factory=PreprocessOptions) + train: TrainHyperparams = field(default_factory=TrainHyperparams) + + +EventCallback = Callable[[dict[str, Any]], None] +CancelCheck = Callable[[], bool] + + +def trainer_available() -> bool: + try: + import ltx_trainer_mlx # noqa: F401 + + return True + except ImportError: + return False + + +def trainer_health(*, ffmpeg_required: bool = False) -> dict[str, Any]: + ok = trainer_available() + ffmpeg = bool(shutil.which("ffmpeg")) + return { + "ok": ok and (ffmpeg or not ffmpeg_required), + "trainer_installed": ok, + "ffmpeg_available": ffmpeg, + "install_hint": None if ok else TRAINER_INSTALL_HINT, + "presets": [p.__dict__ for p in TRAIN_PRESETS.values()], + "configs_dir": str(TRAIN_CONFIGS_DIR), + } + + +def job_root(output_dir: Path, job_id: str) -> Path: + return output_dir.resolve() / "train" / job_id + + +@dataclass(frozen=True) +class TrainJobPaths: + """All mutable training artifacts for one job (under ``web_outputs``).""" + + root: Path + raw: Path + clips: Path + captions: Path + preprocessed: Path + outputs: Path + config: Path + + def ensure_dirs(self) -> None: + for d in (self.root, self.raw, self.clips, self.captions, self.preprocessed, self.outputs): + d.mkdir(parents=True, exist_ok=True) + + +def training_job_paths(output_dir: Path, job_id: str) -> TrainJobPaths: + root = job_root(output_dir, job_id) + return TrainJobPaths( + root=root, + raw=root / "raw", + clips=root / "clips", + captions=root / "captions", + preprocessed=root / "preprocessed", + outputs=root / "outputs", + config=root / "config.yaml", + ) + + +def register_trained_lora(lora_path: Path, *, name: str) -> Path: + """Copy a finished LoRA into the persistent local cache for inference presets.""" + src = lora_path.expanduser().resolve() + if not src.is_file(): + raise FileNotFoundError(f"LoRA weights not found: {src}") + dest_dir = _local_lora_cache_dir() + dest_dir.mkdir(parents=True, exist_ok=True) + slug = re.sub(r"[^\w.\-]+", "_", (name or "trained_lora").strip()).strip("._") or "trained_lora" + dest = (dest_dir / f"{slug}.safetensors").resolve() + shutil.copy2(src, dest) + return dest + + +def status_path(output_dir: Path, job_id: str) -> Path: + return job_root(output_dir, job_id) / "status.json" + + +def load_status(output_dir: Path, job_id: str) -> dict[str, Any] | None: + path = status_path(output_dir, job_id) + if not path.is_file(): + return None + try: + import json + + return json.loads(path.read_text(encoding="utf-8")) + except (OSError, ValueError): + return None + + +def save_status(output_dir: Path, job_id: str, payload: dict[str, Any]) -> None: + import json + + root = job_root(output_dir, job_id) + root.mkdir(parents=True, exist_ok=True) + path = status_path(output_dir, job_id) + path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + + +def _preset_yaml_path(preset: str) -> Path: + key = (preset or "t2v").strip().lower() + path = TRAIN_CONFIGS_DIR / f"lora_{key}.yaml" + if not path.is_file(): + path = TRAIN_CONFIGS_DIR / "lora_t2v.yaml" + if not path.is_file(): + raise FileNotFoundError(f"No training preset config for {preset!r}") + return path + + +def build_trainer_config(req: TrainJobRequest, *, paths: TrainJobPaths) -> Any: + from ltx_trainer_mlx.config import LtxTrainerConfig + + raw = yaml.safe_load(_preset_yaml_path(req.preset).read_text(encoding="utf-8")) + model_path = resolve_mlx_weights_directory(req.model_id, req.model_dir) + + model_block = dict(raw.get("model") or {}) + model_block["model_path"] = model_path + raw["model"] = model_block + raw["data"] = {"preprocessed_data_root": str(paths.preprocessed.resolve())} + raw["output_dir"] = str(paths.outputs.resolve()) + raw["seed"] = int(req.train.seed) + + raw["optimization"]["steps"] = int(req.train.steps) + raw["optimization"]["learning_rate"] = float(req.train.learning_rate) + if req.train.low_ram: + raw["optimization"]["enable_gradient_checkpointing"] = True + + lora = raw.get("lora") or {} + lora["rank"] = int(req.train.rank) + lora["alpha"] = int(req.train.rank) + raw["lora"] = lora + + prompts = [p.strip() for p in req.train.validation_prompts if str(p).strip()] + if not prompts: + prompts = ["a cinematic landscape at sunset"] + val = raw.get("validation") or {} + val["prompts"] = prompts + val["interval"] = int(req.train.validation_interval) + val["skip_initial_validation"] = True + w = int(req.preprocess.width or 704) + h = int(req.preprocess.height or 480) + nf = _nearest_valid_frames(int(req.preprocess.max_frames)) + val["video_dims"] = [w, h, nf] + val["frame_rate"] = float(req.preprocess.frame_rate or 24.0) + val["generate_audio"] = bool(TRAIN_PRESETS.get(req.preset, TRAIN_PRESETS["t2v"]).with_audio) + raw["validation"] = val + + ckpt = raw.get("checkpoints") or {} + ckpt["interval"] = int(req.train.checkpoint_interval) + raw["checkpoints"] = ckpt + + strat = raw.get("training_strategy") or {} + preset_info = TRAIN_PRESETS.get(req.preset, TRAIN_PRESETS["t2v"]) + strat["generate_audio"] = preset_info.with_audio + raw["training_strategy"] = strat + + return LtxTrainerConfig(**raw) + + +@contextmanager +def _metrics_hook(on_metrics: Callable[[dict[str, float]], None] | None): + if on_metrics is None: + yield + return + from ltx_trainer_mlx import progress as progress_mod + + orig_cls = progress_mod.TrainingProgress + orig_update = orig_cls.update_training + + def patched_update( + self, + *, + loss: float, + lr: float, + step_time: float, + advance: bool = True, + ) -> None: + orig_update(self, loss=loss, lr=lr, step_time=step_time, advance=advance) + if advance: + try: + on_metrics({"loss": float(loss), "lr": float(lr), "step_time_s": float(step_time)}) + except Exception: + pass + + progress_mod.TrainingProgress.update_training = patched_update # type: ignore[method-assign] + try: + yield + finally: + progress_mod.TrainingProgress.update_training = orig_update # type: ignore[method-assign] + + +def _check_cancel(should_cancel: CancelCheck | None) -> None: + if should_cancel and should_cancel(): + raise TrainingCancelledError("Training cancelled") + + +def run_train_job( + req: TrainJobRequest, + *, + output_dir: Path, + job_id: str, + on_event: EventCallback | None = None, + should_cancel: CancelCheck | None = None, +) -> dict[str, Any]: + """Execute slice → preprocess → train for one job.""" + if not trainer_available(): + raise RuntimeError(f"ltx-trainer-mlx is not installed. {TRAINER_INSTALL_HINT}") + + paths = training_job_paths(output_dir, job_id) + paths.ensure_dirs() + + status: dict[str, Any] = { + "job_id": job_id, + "name": req.name, + "preset": req.preset, + "phase": "queued", + "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "step": 0, + "total_steps": int(req.train.steps), + "job_dir": str(paths.root), + "model_path": None, + "error": None, + } + save_status(output_dir, job_id, status) + + def emit(event: dict[str, Any]) -> None: + nonlocal status + status.update({k: v for k, v in event.items() if k != "type"}) + save_status(output_dir, job_id, status) + if on_event: + on_event(event) + + videos_dir = paths.raw + captions_dir: str | None = None + + try: + if req.slice.enabled: + _check_cancel(should_cancel) + status["phase"] = "slicing" + emit({"type": "phase_started", "phase": "slicing", "message": "Slicing source videos…"}) + if not shutil.which("ffmpeg"): + raise RuntimeError("ffmpeg is required for slice") + from ltx_trainer_mlx.slice_clips import slice_videos + + sources = sorted( + p + for p in paths.raw.iterdir() + if p.suffix.lower() in {".mp4", ".mov", ".avi", ".mkv", ".webm"} + ) + if not sources: + raise ValueError("No video files found in upload") + count = slice_videos( + [str(p) for p in sources], + str(paths.clips), + interval=float(req.slice.interval), + res=str(req.slice.res), + fps=float(req.slice.fps), + fit=str(req.slice.fit), + caption_template=req.slice.caption_template, + max_clips=req.slice.max_clips, + ) + emit({"type": "phase_progress", "phase": "slicing", "message": f"Created {count} clips"}) + videos_dir = paths.clips + captions_dir = None + else: + txts = list(paths.raw.glob("*.txt")) + if txts: + paths.captions.mkdir(parents=True, exist_ok=True) + for t in txts: + shutil.copy2(t, paths.captions / t.name) + captions_dir = str(paths.captions) + + _check_cancel(should_cancel) + status["phase"] = "preprocessing" + emit({"type": "phase_started", "phase": "preprocessing", "message": "Encoding latents…"}) + from ltx_trainer_mlx.preprocess import preprocess_dataset + + model_path = resolve_mlx_weights_directory(req.model_id, req.model_dir) + status["model_path"] = model_path + nf = _nearest_valid_frames(int(req.preprocess.max_frames)) + preset_info = TRAIN_PRESETS.get(req.preset, TRAIN_PRESETS["t2v"]) + with_audio = req.preprocess.with_audio or preset_info.with_audio + preprocess_dataset( + videos_dir=str(videos_dir), + output_dir=str(paths.preprocessed), + model_dir=model_path, + gemma_model_id=DEFAULT_GEMMA, + target_height=int(req.preprocess.height) if req.preprocess.height else None, + target_width=int(req.preprocess.width) if req.preprocess.width else None, + max_frames=nf, + captions_dir=captions_dir, + with_audio=with_audio, + frame_rate=float(req.preprocess.frame_rate) if req.preprocess.frame_rate else None, + ) + emit({"type": "phase_progress", "phase": "preprocessing", "message": "Preprocess complete"}) + + _check_cancel(should_cancel) + status["phase"] = "training" + status["total_steps"] = int(req.train.steps) + emit({"type": "phase_started", "phase": "training", "message": "Training LoRA…"}) + + from ltx_trainer_mlx.trainer import LtxvTrainer + + config = build_trainer_config(req, paths=paths) + paths.config.write_text(yaml.safe_dump(config.model_dump(mode="json")), encoding="utf-8") + + train_t0 = time.time() + last_metrics: dict[str, float] = {} + + def on_metrics(m: dict[str, float]) -> None: + last_metrics.update(m) + + def step_callback(step: int, total: int, validation_paths: list) -> None: + _check_cancel(should_cancel) + elapsed = max(time.time() - train_t0, 1e-6) + eta_s = (elapsed / max(step, 1)) * max(total - step, 0) + payload: dict[str, Any] = { + "type": "train_step", + "phase": "training", + "step": int(step), + "total_steps": int(total), + "eta_s": round(eta_s, 1), + } + if last_metrics: + payload.update(last_metrics) + emit(payload) + if validation_paths: + rels = [] + for vp in validation_paths: + p = Path(vp) + try: + rel = p.relative_to(paths.outputs) + except ValueError: + rel = p.name + rels.append( + { + "step": int(step), + "filename": str(rel), + "url": f"/api/train/jobs/{job_id}/artifacts/{rel.as_posix()}", + } + ) + status.setdefault("validation_clips", []).extend(rels) + emit({"type": "train_validation", "step": int(step), "videos": rels}) + + with _metrics_hook(on_metrics): + trainer = LtxvTrainer(config) + saved_path, stats = trainer.train( + disable_progress_bars=True, + step_callback=step_callback, + ) + + lora_path = Path(saved_path) + artifact_url = f"/api/train/jobs/{job_id}/artifacts/{lora_path.name}" + status["phase"] = "done" + status["artifact_lora"] = str(lora_path) + status["artifact_url"] = artifact_url + status["stats"] = stats.model_dump() if hasattr(stats, "model_dump") else dict(stats) + save_status(output_dir, job_id, status) + emit( + { + "type": "job_done", + "artifact_url": artifact_url, + "artifact_name": lora_path.name, + "stats": status["stats"], + } + ) + return status + + except TrainingCancelledError: + status["phase"] = "cancelled" + status["error"] = "Cancelled" + save_status(output_dir, job_id, status) + emit({"type": "error", "phase": status.get("phase"), "message": "Cancelled"}) + raise + except Exception as exc: + log.exception("Train job %s failed", job_id) + status["phase"] = "failed" + status["error"] = str(exc) + save_status(output_dir, job_id, status) + emit({"type": "error", "phase": status.get("phase"), "message": str(exc)}) + raise diff --git a/requirements.txt b/requirements.txt index b0f3f89..82dde9a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,7 @@ mcp>=1.0.0 # # Tag must match ltx_mlx_backend.LTX2_MLX_GIT_TAG (v0.14.12). # Upstream: https://github.com/dgrauet/ltx-2-mlx +# +# Optional trainer (training branch / Web UI /train): +# uv pip install \ +# "ltx-trainer-mlx @ git+https://github.com/dgrauet/ltx-2-mlx.git@v0.14.12#subdirectory=packages/ltx-trainer" diff --git a/train_configs/lora_av.yaml b/train_configs/lora_av.yaml new file mode 100644 index 0000000..dc879f6 --- /dev/null +++ b/train_configs/lora_av.yaml @@ -0,0 +1,53 @@ +# Hyperparameters only — paths injected at runtime (see lora_t2v.yaml header). +model: + transformer_file: transformer-dev.safetensors + text_encoder_path: mlx-community/gemma-3-12b-it-4bit + training_mode: lora + +lora: + rank: 32 + alpha: 32 + dropout: 0.0 + target_modules: + - audio_attn1 + - audio_attn2 + - video_to_audio_attn + +optimization: + learning_rate: 1.5e-4 + steps: 2000 + batch_size: 1 + gradient_accumulation_steps: 1 + max_grad_norm: 1.0 + weight_decay: 0.0 + enable_gradient_checkpointing: true + scheduler_type: linear + scheduler_params: + start_factor: 1.0 + end_factor: 0.1 + +training_strategy: + name: text_to_video + generate_audio: true + +flow_matching: + timestep_sampling_mode: shifted_logit_normal + +validation: + prompts: + - "a person speaking softly close to a microphone, intimate ASMR" + video_dims: [192, 192, 97] + frame_rate: 24.0 + inference_steps: 8 + interval: 200 + guidance_scale: 4.0 + stg_scale: 0.0 + seed: 42 + generate_audio: true + skip_initial_validation: true + +checkpoints: + interval: 100 + keep_last_n: 10 + +seed: 42 diff --git a/train_configs/lora_t2v.yaml b/train_configs/lora_t2v.yaml new file mode 100644 index 0000000..eaf724d --- /dev/null +++ b/train_configs/lora_t2v.yaml @@ -0,0 +1,56 @@ +# Hyperparameters only — paths are injected at runtime: +# model_path → local dir / $VIDEOFENTANYL_MODELS / HF hub cache (resolve_mlx_weights_directory) +# preprocessed → /train//preprocessed +# output_dir → /train//outputs +model: + text_encoder_path: mlx-community/gemma-3-12b-it-4bit + training_mode: lora + +lora: + rank: 64 + alpha: 64 + dropout: 0.0 + target_modules: + - to_k + - to_q + - to_v + - to_out.0 + +optimization: + learning_rate: 5.0e-4 + steps: 2000 + batch_size: 1 + gradient_accumulation_steps: 1 + max_grad_norm: 1.0 + weight_decay: 0.0 + enable_gradient_checkpointing: false + scheduler_type: linear + scheduler_params: + start_factor: 1.0 + end_factor: 0.1 + +training_strategy: + name: text_to_video + generate_audio: false + +flow_matching: + timestep_sampling_mode: shifted_logit_normal + +validation: + prompts: + - "a cinematic shot of a cat walking through a sunlit garden" + video_dims: [704, 480, 25] + frame_rate: 24.0 + inference_steps: 8 + interval: 500 + guidance_scale: 4.0 + stg_scale: 0.0 + seed: 42 + generate_audio: false + skip_initial_validation: true + +checkpoints: + interval: 500 + keep_last_n: 3 + +seed: 42 From 1ee5b5ac0393ff862324ffda5097cc99fe3c8229 Mon Sep 17 00:00:00 2001 From: Lorenzo Mangani Date: Thu, 18 Jun 2026 20:18:07 +0200 Subject: [PATCH 4/7] Add /train wizard UI and training job API with live SSE progress. Wire web_train routes and worker queue into the Web UI, add a four-step LoRA training flow with preset picker, uploads, hyperparameters, and validation gallery plus one-click register to the generation library. Co-authored-by: Cursor --- web/package-lock.json | 60 +++- web/package.json | 3 +- web/src/App.tsx | 12 +- web/src/Layout.tsx | 25 ++ web/src/TrainPage.tsx | 764 ++++++++++++++++++++++++++++++++++++++++++ web/src/api/train.ts | 113 +++++++ web/src/index.css | 648 +++++++++++++++++++++++++++++++++++ web/src/main.tsx | 12 +- web/src/types.ts | 44 +++ web_train.py | 488 +++++++++++++++++++++++++++ web_ui.py | 15 +- 11 files changed, 2172 insertions(+), 12 deletions(-) create mode 100644 web/src/Layout.tsx create mode 100644 web/src/TrainPage.tsx create mode 100644 web/src/api/train.ts create mode 100644 web_train.py diff --git a/web/package-lock.json b/web/package-lock.json index f728aab..d757da3 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -9,7 +9,8 @@ "version": "0.1.0", "dependencies": { "react": "^19.1.0", - "react-dom": "^19.1.0" + "react-dom": "^19.1.0", + "react-router-dom": "^7.6.2" }, "devDependencies": { "@types/react": "^19.1.0", @@ -1318,6 +1319,19 @@ "dev": true, "license": "MIT" }, + "node_modules/cookie": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-1.1.1.tgz", + "integrity": "sha512-ei8Aos7ja0weRpFzJnEA9UHJ/7XQmqglbRwnf2ATjcB9Wq874VKH9kfjjirM6UhU2/E5fFYadylyhFldcqSidQ==", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, "node_modules/csstype": { "version": "3.2.3", "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.2.3.tgz", @@ -1604,6 +1618,44 @@ "node": ">=0.10.0" } }, + "node_modules/react-router": { + "version": "7.18.0", + "resolved": "https://registry.npmjs.org/react-router/-/react-router-7.18.0.tgz", + "integrity": "sha512-pTTGt8J+ji1NOmYnjzT+bAJy/1zD+Jp4ziO6cL7T3ZLvXKtusO7BpFqlRXitqpcPVqllsIXFHRMt+2/k3Xn6HQ==", + "license": "MIT", + "dependencies": { + "cookie": "^1.0.1", + "set-cookie-parser": "^2.6.0" + }, + "engines": { + "node": ">=20.0.0" + }, + "peerDependencies": { + "react": ">=18", + "react-dom": ">=18" + }, + "peerDependenciesMeta": { + "react-dom": { + "optional": true + } + } + }, + "node_modules/react-router-dom": { + "version": "7.18.0", + "resolved": "https://registry.npmjs.org/react-router-dom/-/react-router-dom-7.18.0.tgz", + "integrity": "sha512-Fi0yY6kgtKae/Th2xibdWK0KSdYZ4B53Gyf6wRtomOKWgpNm7H7+DyfDhncdz9FKbpS+1jmDhg3F4WoGJ+yFOA==", + "license": "MIT", + "dependencies": { + "react-router": "7.18.0" + }, + "engines": { + "node": ">=20.0.0" + }, + "peerDependencies": { + "react": ">=18", + "react-dom": ">=18" + } + }, "node_modules/rollup": { "version": "4.62.0", "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.62.0.tgz", @@ -1665,6 +1717,12 @@ "semver": "bin/semver.js" } }, + "node_modules/set-cookie-parser": { + "version": "2.7.2", + "resolved": "https://registry.npmjs.org/set-cookie-parser/-/set-cookie-parser-2.7.2.tgz", + "integrity": "sha512-oeM1lpU/UvhTxw+g3cIfxXHyJRc/uidd3yK1P242gzHds0udQBYzs3y8j4gCCW+ZJ7ad0yctld8RYO+bdurlvw==", + "license": "MIT" + }, "node_modules/source-map-js": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", diff --git a/web/package.json b/web/package.json index c9e5fba..7ba7a84 100644 --- a/web/package.json +++ b/web/package.json @@ -10,7 +10,8 @@ }, "dependencies": { "react": "^19.1.0", - "react-dom": "^19.1.0" + "react-dom": "^19.1.0", + "react-router-dom": "^7.6.2" }, "devDependencies": { "@types/react": "^19.1.0", diff --git a/web/src/App.tsx b/web/src/App.tsx index 2121997..8f642e6 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -1398,12 +1398,8 @@ export default function App() { }, [prompt, fitPromptHeight]); return ( -
-
-
- LTX-WS - Videofentanyl -
+ <> +
-
+
@@ -2227,6 +2223,6 @@ export default function App() {
- + ); } diff --git a/web/src/Layout.tsx b/web/src/Layout.tsx new file mode 100644 index 0000000..1eccf02 --- /dev/null +++ b/web/src/Layout.tsx @@ -0,0 +1,25 @@ +import { NavLink, Outlet } from "react-router-dom"; + +export default function Layout() { + return ( +
+
+
+
+ LTX-WS + Videofentanyl +
+ +
+
+ +
+ ); +} diff --git a/web/src/TrainPage.tsx b/web/src/TrainPage.tsx new file mode 100644 index 0000000..7595a22 --- /dev/null +++ b/web/src/TrainPage.tsx @@ -0,0 +1,764 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import type { Config, TrainHealth, TrainJob, TrainPreset } from "./types"; +import { + cancelTrainJob, + createTrainJob, + fetchTrainHealth, + fetchTrainJobs, + fetchTrainPresets, + registerTrainedLora, + subscribeTrainJob, + type TrainManifest, +} from "./api/train"; + +type WizardStep = "dataset" | "preprocess" | "train" | "runs"; + +const STEPS: { id: WizardStep; label: string; hint: string }[] = [ + { id: "dataset", label: "Dataset", hint: "Videos & captions" }, + { id: "preprocess", label: "Preprocess", hint: "Resolution & frames" }, + { id: "train", label: "Train", hint: "Hyperparameters" }, + { id: "runs", label: "Runs", hint: "Progress & output" }, +]; + +function formatEta(seconds?: number): string { + if (seconds == null || !Number.isFinite(seconds)) return "—"; + const s = Math.max(0, Math.round(seconds)); + const h = Math.floor(s / 3600); + const m = Math.floor((s % 3600) / 60); + const sec = s % 60; + if (h > 0) return `${h}h ${m}m`; + if (m > 0) return `${m}m ${sec}s`; + return `${sec}s`; +} + +function phaseLabel(phase?: string): string { + switch (phase) { + case "slicing": + return "Slicing videos"; + case "preprocessing": + return "Preprocessing latents"; + case "training": + return "Training LoRA"; + case "done": + return "Complete"; + case "failed": + return "Failed"; + case "cancelled": + return "Cancelled"; + default: + return phase || "Queued"; + } +} + +export default function TrainPage() { + const [step, setStep] = useState("dataset"); + const [health, setHealth] = useState(null); + const [presets, setPresets] = useState([]); + const [config, setConfig] = useState(null); + const [jobs, setJobs] = useState([]); + const [activeJobId, setActiveJobId] = useState(null); + const [error, setError] = useState(null); + const [submitting, setSubmitting] = useState(false); + const [registering, setRegistering] = useState(false); + + const [name, setName] = useState("My LoRA"); + const [preset, setPreset] = useState("t2v"); + const [videos, setVideos] = useState([]); + const [dragOver, setDragOver] = useState(false); + const fileInputRef = useRef(null); + + const [sliceEnabled, setSliceEnabled] = useState(false); + const [sliceInterval, setSliceInterval] = useState(4); + const [sliceRes, setSliceRes] = useState("384x384"); + const [sliceFps, setSliceFps] = useState(24); + const [sliceFit, setSliceFit] = useState("crop"); + const [captionTemplate, setCaptionTemplate] = useState(""); + + const [width, setWidth] = useState(704); + const [height, setHeight] = useState(480); + const [maxFrames, setMaxFrames] = useState(97); + const [withAudio, setWithAudio] = useState(false); + + const [steps, setSteps] = useState(2000); + const [rank, setRank] = useState(64); + const [learningRate, setLearningRate] = useState(0.0005); + const [validationPrompts, setValidationPrompts] = useState( + "a cinematic landscape at sunset\na person walking through neon rain", + ); + const [validationInterval, setValidationInterval] = useState(500); + const [checkpointInterval, setCheckpointInterval] = useState(500); + const [lowRam, setLowRam] = useState(false); + const [seed, setSeed] = useState(42); + + const activeJob = useMemo( + () => jobs.find((j) => j.id === activeJobId) ?? null, + [jobs, activeJobId], + ); + + const selectedPreset = useMemo( + () => presets.find((p) => p.id === preset) ?? presets[0], + [presets, preset], + ); + + const refreshJobs = useCallback(async () => { + try { + const list = await fetchTrainJobs(); + setJobs(list); + } catch { + /* ignore */ + } + }, []); + + useEffect(() => { + fetch("/api/config") + .then((r) => r.json()) + .then((c: Config) => setConfig(c)) + .catch(() => {}); + fetchTrainHealth().then(setHealth).catch(() => {}); + fetchTrainPresets().then(setPresets).catch(() => {}); + refreshJobs(); + }, [refreshJobs]); + + useEffect(() => { + if (!selectedPreset) return; + setWithAudio(selectedPreset.with_audio); + setLowRam(selectedPreset.low_ram_default); + }, [selectedPreset?.id]); + + const manifest = useMemo((): TrainManifest => { + const prompts = validationPrompts + .split("\n") + .map((p) => p.trim()) + .filter(Boolean); + return { + name, + preset, + model_id: config?.preferred_model || config?.active_model || "auto", + slice: { + enabled: sliceEnabled, + interval: sliceInterval, + res: sliceRes, + fps: sliceFps, + fit: sliceFit, + caption_template: captionTemplate.trim() || undefined, + }, + preprocess: { + width, + height, + max_frames: maxFrames, + with_audio: withAudio, + frame_rate: 24, + }, + train: { + steps, + rank, + learning_rate: learningRate, + validation_prompts: prompts.length ? prompts : ["a cinematic landscape at sunset"], + validation_interval: validationInterval, + checkpoint_interval: checkpointInterval, + low_ram: lowRam, + seed, + }, + }; + }, [ + name, + preset, + config, + sliceEnabled, + sliceInterval, + sliceRes, + sliceFps, + sliceFit, + captionTemplate, + width, + height, + maxFrames, + withAudio, + steps, + rank, + learningRate, + validationPrompts, + validationInterval, + checkpointInterval, + lowRam, + seed, + ]); + + const updateJob = useCallback((jobId: string, patch: Partial) => { + setJobs((prev) => + prev.map((j) => (j.id === jobId ? { ...j, ...patch } : j)), + ); + }, []); + + useEffect(() => { + if (!activeJobId) return; + const unsub = subscribeTrainJob(activeJobId, (event) => { + const type = String(event.type || ""); + if (type === "phase_started") { + updateJob(activeJobId, { phase: String(event.phase || ""), status: "running" }); + } else if (type === "train_step") { + updateJob(activeJobId, { + step: Number(event.step) || 0, + total_steps: Number(event.total_steps) || 0, + loss: event.loss != null ? Number(event.loss) : undefined, + lr: event.lr != null ? Number(event.lr) : undefined, + eta_s: event.eta_s != null ? Number(event.eta_s) : undefined, + phase: "training", + status: "running", + }); + } else if (type === "train_validation") { + const videos = (event.videos as TrainJob["validation_clips"]) || []; + setJobs((prev) => + prev.map((j) => + j.id === activeJobId + ? { ...j, validation_clips: [...(j.validation_clips || []), ...videos] } + : j, + ), + ); + } else if (type === "job_done") { + updateJob(activeJobId, { + status: "done", + phase: "done", + artifact_url: String(event.artifact_url || ""), + artifact_name: String(event.artifact_name || ""), + }); + } else if (type === "error") { + updateJob(activeJobId, { + status: event.message === "Cancelled" ? "cancelled" : "failed", + phase: event.message === "Cancelled" ? "cancelled" : "failed", + error: String(event.message || "Error"), + }); + } else if (type === "snapshot" && event.job) { + const snap = event.job as TrainJob; + updateJob(activeJobId, snap); + } + if (type === "job_complete") { + refreshJobs(); + } + }); + return unsub; + }, [activeJobId, updateJob, refreshJobs]); + + function addFiles(fileList: FileList | File[]) { + const incoming = Array.from(fileList).filter((f) => + /\.(mp4|mov|avi|mkv|webm|txt)$/i.test(f.name), + ); + if (!incoming.length) return; + setVideos((prev) => { + const names = new Set(prev.map((f) => f.name)); + const merged = [...prev]; + for (const f of incoming) { + if (!names.has(f.name)) merged.push(f); + } + return merged; + }); + } + + async function startTraining() { + setError(null); + const videoFiles = videos.filter((f) => !f.name.toLowerCase().endsWith(".txt")); + if (!videoFiles.length) { + setError("Add at least one video file (.mp4, .mov, …)."); + setStep("dataset"); + return; + } + if (!health?.trainer_installed) { + setError("Install ltx-trainer-mlx on the server (see install hint below)."); + return; + } + if (health.generation_active) { + setError("Wait for the current generation to finish before training."); + return; + } + setSubmitting(true); + try { + const result = await createTrainJob(manifest, videos); + const job: TrainJob = { + id: result.job_id, + name: result.name, + preset: result.preset, + status: "queued", + phase: "queued", + created_at: new Date().toISOString(), + total_steps: steps, + validation_clips: [], + }; + setJobs((prev) => [job, ...prev]); + setActiveJobId(result.job_id); + setStep("runs"); + } catch (exc) { + setError(exc instanceof Error ? exc.message : String(exc)); + } finally { + setSubmitting(false); + } + } + + async function handleCancel() { + if (!activeJobId) return; + try { + await cancelTrainJob(activeJobId); + updateJob(activeJobId, { status: "cancelled", phase: "cancelled" }); + } catch (exc) { + setError(exc instanceof Error ? exc.message : String(exc)); + } + } + + async function handleRegister() { + if (!activeJobId || !activeJob) return; + setRegistering(true); + setError(null); + try { + const result = await registerTrainedLora(activeJobId, name, 1.0); + updateJob(activeJobId, { registered_lora_id: result.id }); + } catch (exc) { + setError(exc instanceof Error ? exc.message : String(exc)); + } finally { + setRegistering(false); + } + } + + const trainProgress = + activeJob?.total_steps && activeJob.total_steps > 0 + ? Math.min(100, ((activeJob.step || 0) / activeJob.total_steps) * 100) + : 0; + + const videoCount = videos.filter((f) => !f.name.toLowerCase().endsWith(".txt")).length; + const captionCount = videos.filter((f) => f.name.toLowerCase().endsWith(".txt")).length; + + return ( +
+
+
+

Train a LoRA

+

+ Upload clips, preprocess latents, and fine-tune a style LoRA for generation — all on your Mac. +

+
+
+ + {health?.trainer_installed ? "Trainer ready" : "Trainer not installed"} + {health?.training_active && Training} + {health?.generation_active && Gen active} +
+
+ + {!health?.trainer_installed && health?.install_hint && ( +
+ Install training support + {health.install_hint} +
+ )} + + {error && ( +
+ {error} +
+ )} + +
+ + +
+ {step === "dataset" && ( +
+

Dataset

+

+ Drop training videos here. Optional .txt caption files with matching names are used when + slicing is off. +

+ + + +
+ {presets.map((p) => ( + + ))} +
+ +
{ + e.preventDefault(); + setDragOver(true); + }} + onDragLeave={() => setDragOver(false)} + onDrop={(e) => { + e.preventDefault(); + setDragOver(false); + if (e.dataTransfer.files.length) addFiles(e.dataTransfer.files); + }} + onClick={() => fileInputRef.current?.click()} + role="button" + tabIndex={0} + onKeyDown={(e) => { + if (e.key === "Enter" || e.key === " ") fileInputRef.current?.click(); + }} + > + e.target.files && addFiles(e.target.files)} + /> +
+
Drop videos or click to browse
+
+ {videoCount} video{videoCount !== 1 ? "s" : ""} + {captionCount > 0 ? ` · ${captionCount} caption file${captionCount !== 1 ? "s" : ""}` : ""} +
+
+ + {videos.length > 0 && ( +
    + {videos.map((f) => ( +
  • + {f.name} + {(f.size / 1024 / 1024).toFixed(1)} MB + +
  • + ))} +
+ )} + +
+ Auto-slice long videos +
+ + {sliceEnabled && ( +
+ + + + + +
+ )} +
+
+ +
+ +
+
+ )} + + {step === "preprocess" && ( +
+

Preprocess

+

+ Latents are encoded at this resolution. Frames are rounded to valid LTX lengths (8k+1). +

+
+ + + + +
+
+ + +
+
+ )} + + {step === "train" && ( +
+

Training

+

+ Typical runs use 1k–3k steps. Validation clips appear during training so you can judge quality early. +

+
+ + + + + + + +