diff --git a/assets/test/reference_frames.npy b/assets/test/reference_frames.npy
index f8cfff8..5e274bc 100644
Binary files a/assets/test/reference_frames.npy and b/assets/test/reference_frames.npy differ
diff --git a/configs/train.yaml b/configs/train.yaml
index ebdd128..91ddb57 100644
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -6,6 +6,7 @@ hydra:
 exp_name: voxtream_train
 model_repo: herimor/voxtream2
 dataset_repo: herimor/voxtream2-train
+dataset_revision: null
 dataset_base_dir: null
 dep_former_name: dep_former_csm.safetensors
 dep_former_weight_path: null
@@ -21,6 +22,7 @@ precision: bf16-mixed
 gradient_clip_val: 0.5
 seed: 42
 num_workers: 8
+prefetch_factor: 2
 log_every_n_steps: 50
 gpus: -1
 logging_interval: step
diff --git a/pyproject.toml b/pyproject.toml
index 199bd3f..fb28d1f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,6 +4,17 @@ build-backend = "setuptools.build_meta"
 
 [tool.setuptools]
 license-files = ["LICENSE"]
+include-package-data = true
+
+[tool.setuptools.package-data]
+voxtream = [
+    "configs/*.json",
+    "assets/*.json",
+    "assets/audio/*.wav",
+    "assets/benchmark/*.csv",
+    "assets/test/*.wav",
+    "assets/test/*.npy",
+]
 
 [project]
 name = "voxtream"
@@ -55,6 +66,9 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
+server = ["fastapi", "uvicorn"]
+client = ["websockets", "sounddevice"]
+benchmark = ["pandas", "tqdm"]
 dev = ["black", "isort", "flake8", "mypy", "pytest"]
 
 [project.urls]
diff --git a/tests/test_config_loading.py b/tests/test_config_loading.py
new file mode 100644
index 0000000..1299d9c
--- /dev/null
+++ b/tests/test_config_loading.py
@@ -0,0 +1,42 @@
+import json
+
+import pytest
+
+from voxtream.config import load_generator_config, load_speaking_rate_config
+
+
+def test_load_generator_config_from_default_package_or_repo_path():
+    config = load_generator_config()
+
+    assert config.mimi_sr == 24000
+    assert config.model_repo == "herimor/voxtream2"
+
+
+def test_load_generator_config_rejects_unknown_field(tmp_path):
+    config_path = tmp_path / "generator.json"
+    config = load_generator_config()
+    payload = config.__dict__.copy()
+    payload["unexpected"] = True
+    config_path.write_text(json.dumps(payload))
+
+    with pytest.raises(ValueError, match="unknown fields"):
+        load_generator_config(config_path)
+
+
+def test_load_generator_config_rejects_invalid_ranges(tmp_path):
+    config_path = tmp_path / "generator.json"
+    config = load_generator_config()
+    payload = config.__dict__.copy()
+    payload["mimi_frame_ms"] = 0
+    config_path.write_text(json.dumps(payload))
+
+    with pytest.raises(ValueError, match="mimi_frame_ms"):
+        load_generator_config(config_path)
+
+
+def test_load_speaking_rate_config_validates_shape(tmp_path):
+    config_path = tmp_path / "speaking_rate.json"
+    config_path.write_text(json.dumps({"1": {"duration_state": [], "weight": 1.0, "cfg_gamma": 1.0}}))
+
+    with pytest.raises(ValueError, match="duration_state"):
+        load_speaking_rate_config(config_path)
diff --git a/tests/test_dataset_loading.py b/tests/test_dataset_loading.py
new file mode 100644
index 0000000..24698da
--- /dev/null
+++ b/tests/test_dataset_loading.py
@@ -0,0 +1,38 @@
+import numpy as np
+
+from voxtream.dataset import TrainDataset
+
+
+def test_train_dataset_loads_shards_without_pickle(monkeypatch, tmp_path):
+    calls = []
+
+    def fake_load(path, allow_pickle=False):
+        calls.append((path, allow_pickle))
+        return np.array([1])
+
+    monkeypatch.setattr(np, "load", fake_load)
+    monkeypatch.setattr(TrainDataset, "__len__", lambda self: 1)
+
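+    # Both shards feed the same "audio_codes" key, so the constructor should hit
+    # the patched np.load once per shard with allow_pickle=False and concatenate.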
+    dataset = TrainDataset.__new__(TrainDataset)
+    TrainDataset.__init__(
+        dataset,
+        base_dir=tmp_path,
+        datasets={"one": {"audio_codes": "a.npy"}, "two": {"audio_codes": "b.npy"}},
+        phone_vocab_size=10,
+        audio_vocab_size=10,
+        audio_pad_size=0,
+        num_codebooks=2,
+        audio_delay_frames=1,
+        dtype="int64",
+        audio_window_size=1,
+        pad_len=1,
+        semantic_label_pad=2,
+        num_phones_per_frame=2,
+        phoneme_index_map={"0": 1},
+        cfg_prob=0.0,
+        prompt_length_sec=[1, 2],
+        mimi_tps=12.5,
+    )
+
+    assert all(allow_pickle is False for _, allow_pickle in calls)
+    assert dataset.audio_codes.tolist() == [1, 1]
diff --git a/tests/test_model_pool.py b/tests/test_model_pool.py
new file mode 100644
index 0000000..7dbdb23
--- /dev/null
+++ b/tests/test_model_pool.py
@@ -0,0 +1,11 @@
+def test_model_pool_factories_return_distinct_instances():
+    import pytest
+    pytest.importorskip("torch")
+    pytest.importorskip("torchtune")
+
+    from voxtream.utils.model import MODEL_POOL
+
+    first = MODEL_POOL["phone_former"]()
+    second = MODEL_POOL["phone_former"]()
+
+    assert first is not second
diff --git a/tests/test_prompt_cache.py b/tests/test_prompt_cache.py
new file mode 100644
index 0000000..5c03040
--- /dev/null
+++ b/tests/test_prompt_cache.py
@@ -0,0 +1,43 @@
+import numpy as np
+import pytest
+
+from voxtream.config import load_generator_config
+from voxtream.utils.generator.prompt import (
+    _load_prompt_cache,
+    _prompt_cache_metadata,
+    _prompt_cache_path,
+    _save_prompt_cache,
+)
+
+
+def test_prompt_cache_uses_npz_without_pickle(tmp_path):
+    torch = pytest.importorskip("torch")
+    prompt = tmp_path / "voice.wav"
+    prompt.write_bytes(b"RIFF")
+    config = load_generator_config()
+    cache_path = _prompt_cache_path(prompt)
+    metadata = _prompt_cache_metadata(prompt, config, enhance_prompt=False, apply_vad=False)
+    audio_tokens = torch.zeros((1, 2, 3), dtype=torch.int64)
+    spk_embedding = torch.ones((1, 4), dtype=torch.float32)
+
+    _save_prompt_cache(cache_path, metadata, audio_tokens, spk_embedding)
+
+    with np.load(cache_path, allow_pickle=False) as data:
+        assert set(data.files) == {"metadata", "audio_tokens", "spk_embedding"}
+
+
+def test_prompt_cache_rejects_stale_metadata(tmp_path):
+    torch = pytest.importorskip("torch")
+    prompt = tmp_path / "voice.wav"
+    prompt.write_bytes(b"RIFF")
+    config = load_generator_config()
+    cache_path = _prompt_cache_path(prompt)
+    metadata = _prompt_cache_metadata(prompt, config, enhance_prompt=False, apply_vad=False)
+    audio_tokens = torch.zeros((1, 2, 3), dtype=torch.int64)
+    spk_embedding = torch.ones((1, 4), dtype=torch.float32)
+    _save_prompt_cache(cache_path, metadata, audio_tokens, spk_embedding)
+
+    stale_metadata = dict(metadata)
+    stale_metadata["cache_key"] = "stale"
+
+    assert _load_prompt_cache(cache_path, stale_metadata) is None
diff --git a/tests/test_run_output_regression.py b/tests/test_run_output_regression.py
index 83b1715..4035a0f 100644
--- a/tests/test_run_output_regression.py
+++ b/tests/test_run_output_regression.py
@@ -25,12 +25,24 @@ def test_run_main_output_matches_reference(monkeypatch, tmp_path):
 
     captured_write = {}
 
-    def fake_sf_write(output_path, data, samplerate):
-        captured_write["path"] = str(output_path)
-        captured_write["data"] = np.asarray(data)
-        captured_write["samplerate"] = samplerate
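+    # Minimal stand-in for soundfile.SoundFile: it records every streamed chunk
+    # so the concatenated result can be compared against the reference frames.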
captured_write["chunks"] = [] - monkeypatch.setattr(run.sf, "write", fake_sf_write) + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + captured_write["data"] = np.concatenate(captured_write["chunks"]) + + def write(self, data): + captured_write["chunks"].append(np.asarray(data)) + + monkeypatch.setattr(run.sf, "SoundFile", FakeSoundFile) output_path = tmp_path / "voxtream_run_ref.wav" monkeypatch.setattr( @@ -53,6 +65,8 @@ def fake_sf_write(output_path, data, samplerate): run.main() assert captured_write["path"] == str(output_path) + assert captured_write["mode"] == "w" + assert captured_write["channels"] == 1 assert captured_write["data"].shape == expected_audio.shape np.testing.assert_allclose( captured_write["data"], expected_audio, rtol=1e-5, atol=1e-6 diff --git a/tests/test_server_prompt_inputs.py b/tests/test_server_prompt_inputs.py new file mode 100644 index 0000000..aab2f18 --- /dev/null +++ b/tests/test_server_prompt_inputs.py @@ -0,0 +1,58 @@ +import base64 +from pathlib import Path + +import pytest + +from voxtream.server import ( + MAX_PROMPT_AUDIO_BYTES, + _b64_to_bytes, + _ensure_prompt_audio_file, + _validate_prompt_audio_path, +) + + +def test_validate_prompt_audio_path_rejects_paths_outside_root(tmp_path): + prompt_root = tmp_path / "allowed" + prompt_root.mkdir() + outside = tmp_path / "outside.wav" + outside.write_bytes(b"RIFF") + + with pytest.raises(ValueError, match="inside"): + _validate_prompt_audio_path(str(outside), prompt_root) + + +def test_validate_prompt_audio_path_accepts_file_inside_root(tmp_path): + prompt_root = tmp_path / "allowed" + prompt_root.mkdir() + prompt = prompt_root / "voice.wav" + prompt.write_bytes(b"RIFF") + + assert _validate_prompt_audio_path(str(prompt), prompt_root) == prompt.resolve() + + +def test_b64_to_bytes_rejects_invalid_base64(): + with pytest.raises(ValueError, match="valid base64"): + _b64_to_bytes("not base64?") + + +def test_b64_to_bytes_rejects_large_payload(): + payload = base64.b64encode(b"x" * (MAX_PROMPT_AUDIO_BYTES + 1)).decode() + + with pytest.raises(ValueError, match="exceeds"): + _b64_to_bytes(payload) + + +def test_ensure_prompt_audio_file_creates_temp_for_base64_and_marks_temp(tmp_path): + prompt_path, is_temp = _ensure_prompt_audio_file( + None, + "data:audio/wav;base64," + base64.b64encode(b"RIFF").decode(), + prompt_root=tmp_path, + ) + + try: + assert is_temp is True + assert isinstance(prompt_path, Path) + assert prompt_path.exists() + assert prompt_path.read_bytes() == b"RIFF" + finally: + prompt_path.unlink(missing_ok=True) diff --git a/voxtream/app.py b/voxtream/app.py index aab437a..7084073 100644 --- a/voxtream/app.py +++ b/voxtream/app.py @@ -2,12 +2,18 @@ import json import uuid from pathlib import Path +from typing import Any, cast import gradio as gr import numpy as np import soundfile as sf -from voxtream.config import SpeechGeneratorConfig +from voxtream.config import ( + SpeechGeneratorConfig, + load_generator_config, + load_speaking_rate_config, + resolve_data_path, +) from voxtream.generator import SpeechGenerator from voxtream.utils.app import ( CUSTOM_CSS, @@ -24,12 +30,15 @@ render_text_progress, ) from voxtream.utils.generator import ( - existing_file, text_generator, ) from voxtream.utils.generator.text import build_text_progress_metadata +def _progress_frame(result: tuple[Any, ...]) -> tuple[np.ndarray[Any, Any], dict[str, Any]]: + return cast(np.ndarray[Any, Any], result[0]), cast(dict[str, Any], result[2]) + + def 
+def _progress_frame(result: tuple[Any, ...]) -> tuple[np.ndarray[Any, Any], dict[str, Any]]:
+    return cast(np.ndarray[Any, Any], result[0]), cast(dict[str, Any], result[2])
+
+
 def generation_button_updates(running: bool, paused: bool = False):
     if not running:
         return (
@@ -349,39 +358,36 @@ def main():
     parser.add_argument(
         "-c",
         "--config",
-        type=existing_file,
+        type=Path,
         help="Path to the config file",
         default="configs/generator.json",
     )
     parser.add_argument(
         "--app-config",
-        type=existing_file,
+        type=Path,
         help="Path to the app config file",
         default="configs/app.json",
     )
     parser.add_argument(
         "--spk-rate-config",
-        type=existing_file,
+        type=Path,
         help="Path to the speaking rate config file",
        default="configs/speaking_rate.json",
     )
     parser.add_argument(
         "--examples-config",
-        type=existing_file,
+        type=Path,
         help="Path to the examples config file",
         default="assets/examples.json",
     )
     args = parser.parse_args()
 
-    with open(args.config) as f:
-        config = SpeechGeneratorConfig(**json.load(f))
+    config = load_generator_config(args.config)
+    spk_rate_config = load_speaking_rate_config(args.spk_rate_config)
 
-    with open(args.spk_rate_config) as f:
-        spk_rate_config = json.load(f)
+    app_config = load_app_config(resolve_data_path(args.app_config, "configs/app.json"))
 
-    app_config = load_app_config(args.app_config)
-
-    with open(args.examples_config) as f:
+    with resolve_data_path(args.examples_config, "assets/examples.json").open() as f:
         examples_config = json.load(f)
 
     demo_examples = examples_config.get("examples", [])
@@ -457,53 +463,54 @@ def synthesize_fn(
         buffer = []
         buffer_len = 0
-        total_buffer = []
+        file_path = f"/tmp/voxtream_{uuid.uuid4().hex}.wav"
         stopped = False
 
-        stream_iter = iter(stream)
-        while True:
-            if not generation_control.wait_if_paused():
-                stopped = True
-                break
-            try:
-                frame, _, progress = next(stream_iter)
-            except StopIteration:
-                break
-            if generation_control.is_stopped():
-                stopped = True
-                break
-
-            buffer.append(frame)
-            total_buffer.append(frame)
-            buffer_len += frame.shape[0]
-            plot_update, text_update = visualization.update(progress)
-
-            if buffer_len >= chunk_size:
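+        # Write each frame straight to disk as it arrives; the same file is
+        # served as the finished take, so no full-length buffer stays in memory.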
+        with sf.SoundFile(file_path, "w", samplerate=config.mimi_sr, channels=1) as output_file:
+            stream_iter = iter(stream)
+            while True:
+                if not generation_control.wait_if_paused():
+                    stopped = True
+                    break
+                try:
+                    frame, progress = _progress_frame(next(stream_iter))
+                except StopIteration:
+                    break
                 if generation_control.is_stopped():
                     stopped = True
                     break
-                audio = np.concatenate(buffer)
-                stream_seq += 1
-                yield (
-                    gr.update(),
-                    gr.update(),
-                    plot_update,
-                    text_update,
-                    render_audio_stream(
-                        app_config,
-                        session_id=stream_session_id,
-                        seq=stream_seq,
-                        sample_rate=config.mimi_sr,
-                        audio=float32_to_int16(audio),
-                        active=True,
-                    ),
-                    *generation_button_updates(
-                        running=True, paused=generation_control.is_paused()
-                    ),
-                )
-                buffer = []
-                buffer_len = 0
+                output_file.write(frame)
+                buffer.append(frame)
+                buffer_len += frame.shape[0]
+                plot_update, text_update = visualization.update(progress)
+
+                if buffer_len >= chunk_size:
+                    if generation_control.is_stopped():
+                        stopped = True
+                        break
+                    audio = np.concatenate(buffer)
+                    stream_seq += 1
+                    yield (
+                        gr.update(),
+                        gr.update(),
+                        plot_update,
+                        text_update,
+                        render_audio_stream(
+                            app_config,
+                            session_id=stream_session_id,
+                            seq=stream_seq,
+                            sample_rate=config.mimi_sr,
+                            audio=float32_to_int16(audio),
+                            active=True,
+                        ),
+                        *generation_button_updates(
+                            running=True, paused=generation_control.is_paused()
+                        ),
+                    )
+
+                    buffer = []
+                    buffer_len = 0
 
         stopped = stopped or generation_control.is_stopped()
         if stopped and hasattr(stream, "close"):
@@ -535,18 +542,7 @@ def synthesize_fn(
             ),
         )
 
-        if len(total_buffer) > 0:
-            full_audio = np.concatenate(total_buffer)
-            nfade = min(
-                int(config.mimi_sr * app_config.fade_out_sec), full_audio.shape[0]
-            )
-            if nfade > 0:
-                fade = np.linspace(1.0, 0.0, nfade, dtype=np.float32)
-                full_audio[-nfade:] *= fade
-
-            file_path = f"/tmp/voxtream_{uuid.uuid4().hex}.wav"
-            sf.write(file_path, float32_to_int16(full_audio), config.mimi_sr)
-
+        if Path(file_path).exists() and Path(file_path).stat().st_size > 0:
             speaking_rate_state.stop()
             generation_control.finish()
             yield (
diff --git a/voxtream/assets/audio/english_male.wav b/voxtream/assets/audio/english_male.wav
new file mode 100644
index 0000000..1483c20
Binary files /dev/null and b/voxtream/assets/audio/english_male.wav differ
diff --git a/voxtream/assets/benchmark/meta.csv b/voxtream/assets/benchmark/meta.csv
new file mode 100644
index 0000000..ca2ab50
--- /dev/null
+++ b/voxtream/assets/benchmark/meta.csv
@@ -0,0 +1,12 @@
+prompt_audio,prompt_text,text
+assets/benchmark/common_voice_en_10119832.wav,"We asked over twenty different people, and they all said it was his.",Get the trust fund to the bank early.
+assets/benchmark/common_voice_en_10119832.wav,"We asked over twenty different people, and they all said it was his.",The stained glass offered a hypnotic atmosphere.
+assets/benchmark/common_voice_en_103675.wav,I'm never more aware of a room's acoustics than when I'm trying to enjoy a snack I have no intention of sharing.,"One by one, the campfires were extinguished, and the oasis fell as quiet as the desert."
+assets/benchmark/common_voice_en_103675.wav,I'm never more aware of a room's acoustics than when I'm trying to enjoy a snack I have no intention of sharing.,The boy knew the desert sensed his fear.
+assets/benchmark/common_voice_en_10933823.wav,Sometimes I overthink things which leads me to postpone and ultimately never achieve the goal I had in mind.,"When it comes to the crunch, our company will become insolvent."
+assets/benchmark/common_voice_en_10933823.wav,Sometimes I overthink things which leads me to postpone and ultimately never achieve the goal I had in mind.,The primary coil has fifty turns.
+assets/benchmark/common_voice_en_120405.wav,He approached the mass and was surprised at the size and the shape.,I'm never more aware of a room's acoustics than when I'm trying to enjoy a snack I have no intention of sharing.
+assets/benchmark/common_voice_en_120405.wav,He approached the mass and was surprised at the size and the shape.,The only shadow was that of the few scattered pine trees.
+assets/benchmark/common_voice_en_1205005.wav,"Roaming endlessly around the park, she wants to go home.",The work of the tailor is seen on each side.
+assets/benchmark/common_voice_en_1205005.wav,"Roaming endlessly around the park, she wants to go home.",NASA plans to launch the rocket tomorrow.
+assets/benchmark/common_voice_en_123125.wav,"There's no danger, the boy said, when they had moved on past the encampment.","After all, who doesn’t want to overcome new challenges and achieve great heights?"
diff --git a/voxtream/assets/examples.json b/voxtream/assets/examples.json
new file mode 100644
index 0000000..d3ff3c3
--- /dev/null
+++ b/voxtream/assets/examples.json
@@ -0,0 +1,52 @@
+{
+    "examples": [
+        [
+            "assets/audio/english_male.wav",
+            "Full stream text-to-speech (TTS) for interactive systems must start speaking with minimal delay while remaining controllable as text arrives incrementally. The voice should keep a natural rhythm as each new phrase becomes available. This helps assistants respond quickly while still sounding calm and clear. Careful streaming design makes spoken interaction feel more direct and responsive. It also reduces long silent gaps during complex replies. The listener hears progress while the system continues planning the next phrase."
+        ],
+        [
+            "assets/audio/english_female.wav",
+            "We present VoXtream2, a zero shot full stream text-to-speech (TTS) model with dynamic speaking rate control that can be updated during an utterance. The system can shift its pace while speech is already being produced. This allows a speaker to slow down for difficult content or move faster through simple phrases. Flexible control supports more expressive and useful speech generation. It also lets applications adapt delivery to user attention and context. The same voice can sound measured, concise, or relaxed as needed."
+        ],
+        [
+            "assets/audio/chinese_female.wav",
+            "VoXtream2 combines distribution matching over duration states with classifier free guidance across conditioning signals to improve controllability and synthesis quality. These methods help the model follow timing instructions more reliably. They also preserve natural voice quality when several controls are active. The result is speech that can be guided without sounding rigid or unstable. Better alignment between text, timing, and voice style makes each output easier to shape."
+        ],
+        [
+            "assets/audio/hindi_male.wav",
+            "Prompt text masking enables textless audio prompting, removing the need for prompt transcription. A user can provide a short voice reference without writing down what was said. This makes voice adaptation easier when transcripts are missing or expensive to prepare. The model can still learn useful speaker cues from the audio prompt. It can focus on tone, accent, and speaking style while ignoring unavailable words. This lowers setup effort for demos, research samples, and personal voice interfaces."
+        ],
+        [
+            "assets/audio/spanish_male.wav",
+            "Across standard zero shot benchmarks and a dedicated speaking rate test set, VoXtream2 achieves competitive objective and subjective results against public baselines. The evaluation covers both measured accuracy and human listening preference. Strong results suggest that streaming control does not require a large loss in quality. This makes the approach practical for real conversational systems. Consistent performance across tests gives developers more confidence in deployment."
+        ],
+        [
+            "assets/audio/arabic_female.wav",
+            "In full stream mode, it runs 4 times faster than real time with 74 ms first packet latency on a consumer graphics processor. Low latency helps the system begin speaking before the full response is complete. Faster synthesis also leaves more room for other application work. These properties are important for smooth real time dialogue. Efficient generation can support busy products without requiring unusual hardware. It also makes testing easier because responses arrive quickly during iteration."
+        ],
+        [
+            "assets/audio/french_female.wav",
+            "It has long been argued that conversational agents must be able to generate speech incrementally. Human conversation often depends on quick turns and partial understanding. An agent that waits too long can make the exchange feel broken or unnatural. Incremental generation supports more fluid spoken interaction. It lets a system begin with a confident phrase while later content is still forming. This behavior can make spoken assistants feel attentive, present, and easier to interrupt."
+        ],
+        [
+            "assets/audio/japanese_male.wav",
+            "A Japanese audio recording can be used to prepare the examples list for a future voice sample. The entry keeps the same format as the other language examples. English text makes the placeholder clear for contributors who review the file. The audio path can be updated once the final sample is created. The placeholder supports early testing of selection flows and ordering. It also makes room for future validation of playback, captions, and voice metadata. Reviewers can confirm the option appears correctly."
+        ],
+        [
+            "assets/audio/russian_female.wav",
+            "Recent progress in neural text-to-speech (TTS) synthesis has led to highly natural and intelligible speech generation. Modern models can produce voices with clear pronunciation and expressive prosody. These gains make synthetic speech useful in more demanding interactive settings. The next challenge is to keep that quality while adding fine grained control. Users expect a generated voice to remain stable when speed, style, or prompting changes. Robust models must balance realism, latency, and steering."
+        ],
+        [
+            "assets/audio/swedish_female.wav",
+            "However, most contemporary systems implicitly assume that speaking rate is static across an utterance, typically allowing only coarse, global control over speed. Real speakers often vary their pace within a single response. They may pause before important words or speed through familiar details. A useful speech system should support this kind of local timing control. This makes explanations clearer and keeps long replies from feeling flat. Local control also helps match emphasis to meaning."
+        ],
+        [
+            "assets/audio/portuguese_male.wav",
+            "A Portuguese audio prompt can be used to test how the interface presents another language option. The example text stays in English while the entry reserves space for a future recording. This keeps the data structure ready for multilingual expansion. A final voice sample can later replace the placeholder path. The placeholder also helps verify menus, labels, and playback behavior before the asset exists. It gives reviewers a clear signal that Portuguese support is planned but not complete."
+        ],
+        [
+            "assets/audio/german_male.wav",
+            "A German acoustic prompt can be used to check language selection and example ordering. The English text describes the expected role of the entry without adding translated content. This makes the placeholder easy to identify during development. A complete recording can be added when the voice asset is available. The entry can also reveal layout issues in lists that include many languages. It keeps the example set balanced while the final German sample is prepared. Reviewers can test the flow early."
+        ]
+    ]
+}
diff --git a/voxtream/assets/test/english_male.wav b/voxtream/assets/test/english_male.wav
new file mode 100644
index 0000000..1b68e42
Binary files /dev/null and b/voxtream/assets/test/english_male.wav differ
diff --git a/voxtream/assets/test/reference_frames.npy b/voxtream/assets/test/reference_frames.npy
new file mode 100644
index 0000000..5e274bc
Binary files /dev/null and b/voxtream/assets/test/reference_frames.npy differ
diff --git a/voxtream/config.py b/voxtream/config.py
index d5d036c..f1c5748 100644
--- a/voxtream/config.py
+++ b/voxtream/config.py
@@ -1,5 +1,8 @@
-from dataclasses import dataclass
-from typing import Dict
+import json
+from dataclasses import dataclass, fields
+from importlib import resources
+from pathlib import Path
+from typing import Any
 
 
 @dataclass
@@ -30,13 +33,13 @@ class SpeechGeneratorConfig:
     spk_enc_model_name: str
     spk_enc_train_type: str
     spk_enc_dataset: str
-    phoneme_index_map: Dict
+    phoneme_index_map: dict[str, list[int]]
     phoneme_dict_name: str
     max_prompt_sec: int
     min_prompt_sec: int
     max_phone_tokens: int
     cache_prompt: bool
-    punct_map: Dict
+    punct_map: dict[str, int]
     phonemizer: str
     spk_rate_window_sec: float
     cfg_gamma: float
@@ -53,3 +56,132 @@ class SpeechGeneratorConfig:
     min_speech_seg_sec: float
     min_look_ahead_phones: int
     frame_repeat_counter: int
+
+
+def resolve_data_path(path: str | Path, package_relative_path: str) -> Path:
+    """Resolve a user path, repo checkout path, or packaged data resource."""
+    candidate = Path(path)
+    if candidate.exists():
+        return candidate
+
+    repo_candidate = Path(__file__).resolve().parents[1] / candidate
+    if repo_candidate.exists():
+        return repo_candidate
+
+    resource = resources.files("voxtream").joinpath(package_relative_path)
+    if resource.is_file():
+        with resources.as_file(resource) as resource_path:
+            return resource_path
+
+    raise FileNotFoundError(
+        f"Could not find {path}. Pass an explicit path or reinstall voxtream with package data."
+    )
+
+
+def load_json(path: str | Path, package_relative_path: str) -> object:
+    with resolve_data_path(path, package_relative_path).open() as f:
+        return json.load(f)
+
+
+def _json_object(raw: object, label: str) -> dict[str, Any]:
+    if not isinstance(raw, dict):
+        raise ValueError(f"{label} must be a JSON object")
+    return raw
+
+
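+# Config loading is deliberately strict: missing or unknown fields raise
+# instead of being silently ignored.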
+ ) + + +def load_json(path: str | Path, package_relative_path: str) -> object: + with resolve_data_path(path, package_relative_path).open() as f: + return json.load(f) + + +def _json_object(raw: object, label: str) -> dict[str, Any]: + if not isinstance(raw, dict): + raise ValueError(f"{label} must be a JSON object") + return raw + + +def load_generator_config(path: str | Path = "configs/generator.json") -> SpeechGeneratorConfig: + raw = _json_object(load_json(path, "configs/generator.json"), "Generator config") + + field_names = {field.name for field in fields(SpeechGeneratorConfig)} + missing = sorted(field_names - raw.keys()) + unknown = sorted(raw.keys() - field_names) + if missing: + raise ValueError(f"Generator config missing required fields: {missing}") + if unknown: + raise ValueError(f"Generator config has unknown fields: {unknown}") + + config = SpeechGeneratorConfig(**raw) + validate_generator_config(config) + return config + + +def validate_generator_config(config: SpeechGeneratorConfig) -> None: + positive_int_fields = ( + "num_codebooks", + "num_phones_per_frame", + "mimi_sr", + "mimi_vocab_size", + "mimi_frame_ms", + "spk_enc_sr", + "max_prompt_sec", + "min_prompt_sec", + "max_phone_tokens", + "text_context_length", + "audio_pad_token", + "min_look_ahead_phones", + "frame_repeat_counter", + ) + for name in positive_int_fields: + value = getattr(config, name) + if not isinstance(value, int) or value <= 0: + raise ValueError(f"{name} must be a positive integer") + + if config.max_prompt_sec < config.min_prompt_sec: + raise ValueError("max_prompt_sec must be greater than or equal to min_prompt_sec") + if config.max_audio_length_ms <= 0: + raise ValueError("max_audio_length_ms must be positive") + if config.audio_delay_frames < 0: + raise ValueError("audio_delay_frames must be non-negative") + if config.temperature <= 0: + raise ValueError("temperature must be positive") + if not 0 < config.top_p <= 1: + raise ValueError("top_p must be in the range (0, 1]") + if config.topk <= 0: + raise ValueError("topk must be positive") + if config.spk_rate_window_sec <= 0: + raise ValueError("spk_rate_window_sec must be positive") + if config.min_speech_seg_sec < 0: + raise ValueError("min_speech_seg_sec must be non-negative") + + for mapping_name in ("phoneme_index_map", "punct_map"): + mapping = getattr(config, mapping_name) + if not isinstance(mapping, dict) or not mapping: + raise ValueError(f"{mapping_name} must be a non-empty object") + + for repo_field in ("model_repo", "mimi_repo", "spk_enc_repo"): + value = getattr(config, repo_field) + if not isinstance(value, str) or "/" not in value: + raise ValueError(f"{repo_field} must be a Hugging Face or torch.hub repo id") + + +def load_speaking_rate_config( + path: str | Path = "configs/speaking_rate.json", +) -> dict[str, dict[str, list[int] | float]]: + raw = _json_object( + load_json(path, "configs/speaking_rate.json"), "Speaking-rate config" + ) + if not raw: + raise ValueError("Speaking-rate config must be a non-empty JSON object") + + parsed: dict[str, dict[str, list[int] | float]] = {} + for rate, params in raw.items(): + if not isinstance(params, dict): + raise ValueError(f"Speaking-rate entry {rate!r} must be an object") + duration_state = params.get("duration_state") + weight = params.get("weight") + cfg_gamma = params.get("cfg_gamma") + if not isinstance(duration_state, list) or not duration_state: + raise ValueError(f"duration_state for speaking rate {rate!r} must be a non-empty list") + if not all(isinstance(value, int) and 
+    parsed: dict[str, dict[str, list[int] | float]] = {}
+    for rate, params in raw.items():
+        if not isinstance(params, dict):
+            raise ValueError(f"Speaking-rate entry {rate!r} must be an object")
+        duration_state = params.get("duration_state")
+        weight = params.get("weight")
+        cfg_gamma = params.get("cfg_gamma")
+        if not isinstance(duration_state, list) or not duration_state:
+            raise ValueError(f"duration_state for speaking rate {rate!r} must be a non-empty list")
+        if not all(isinstance(value, int) and value > 0 for value in duration_state):
+            raise ValueError(f"duration_state for speaking rate {rate!r} must contain positive integers")
+        duration_state_values = [int(value) for value in duration_state]
+        if not isinstance(weight, (int, float)) or weight <= 0:
+            raise ValueError(f"weight for speaking rate {rate!r} must be positive")
+        if not isinstance(cfg_gamma, (int, float)) or cfg_gamma <= 0:
+            raise ValueError(f"cfg_gamma for speaking rate {rate!r} must be positive")
+        parsed[str(rate)] = {
+            "duration_state": duration_state_values,
+            "weight": float(weight),
+            "cfg_gamma": float(cfg_gamma),
+        }
+    return parsed
diff --git a/voxtream/configs/app.json b/voxtream/configs/app.json
new file mode 100644
index 0000000..e7cd171
--- /dev/null
+++ b/voxtream/configs/app.json
@@ -0,0 +1,23 @@
+{
+    "min_chunk_sec": 0.01,
+    "fade_out_sec": 0.1,
+    "plot_window_sec": 10.0,
+    "visual_update_sec": 0.25,
+    "future_phone_limit": 25,
+    "plot_width": 1000,
+    "plot_height": 224,
+    "plot_left": 74,
+    "plot_right": 22,
+    "plot_top": 50,
+    "plot_bottom": 42,
+    "plot_y_max": 7,
+    "plot_y_tick": 1,
+    "plot_x_tick_sec": 1,
+    "audio_stream_start_delay_sec": 0.12,
+    "audio_stream_sample_rate": 24000,
+    "speaking_rate_min": 1.0,
+    "speaking_rate_max": 7.0,
+    "speaking_rate_step": 0.1,
+    "speaking_rate_default": 4.0,
+    "min_streaming_rtf": 0.95
+}
diff --git a/voxtream/configs/generator.json b/voxtream/configs/generator.json
new file mode 100644
index 0000000..f6440bf
--- /dev/null
+++ b/voxtream/configs/generator.json
@@ -0,0 +1,58 @@
+{
+    "sil_token": 120,
+    "bos_token": 123,
+    "eos_token": 124,
+    "unk_token": 122,
+    "eop_token": 122,
+    "num_codebooks": 16,
+    "num_phones_per_frame": 2,
+    "audio_delay_frames": 1,
+    "temperature": 0.9,
+    "topk": 5,
+    "top_p": 0.9,
+    "max_audio_length_ms": 60000,
+    "model_repo": "herimor/voxtream2",
+    "model_name": "model.safetensors",
+    "model_config_name": "config.json",
+    "mimi_sr": 24000,
+    "mimi_vocab_size": 2048,
+    "mimi_frame_ms": 80,
+    "mimi_repo": "kyutai/moshiko-pytorch-bf16",
+    "mimi_name": "tokenizer-e351c8d8-checkpoint125.safetensors",
+    "spk_enc_sr": 16000,
+    "spk_enc_repo": "IDRnD/ReDimNet",
+    "spk_enc_model": "ReDimNet",
+    "spk_enc_model_name": "M",
+    "spk_enc_train_type": "ft_mix",
+    "spk_enc_dataset": "vb2+vox2+cnc",
+    "phoneme_dict_name": "phoneme_to_token.json",
+    "max_prompt_sec": 20,
+    "min_prompt_sec": 1,
+    "max_phone_tokens": 2000,
+    "cache_prompt": false,
+    "cfg_gamma": 1.5,
+    "cfg_ac_gamma": 3.0,
+    "text_context": " context",
+    "text_context_length": 18,
+    "spk_proj_weight": 1.5,
+    "audio_pad_token": 2049,
+    "enhance_prompt": false,
+    "sidon_se_reload_model": false,
+    "reset_streaming_state": false,
+    "hf_token": null,
+    "apply_vad": false,
+    "min_speech_seg_sec": 0.3,
+    "min_look_ahead_phones": 3,
+    "phonemizer": "espeak",
+    "spk_rate_window_sec": 3.0,
+    "frame_repeat_counter": 25,
+    "punct_map": {
+        ".": 117,
+        ",": 118,
+        "?": 119,
+        "!": 121
+    },
+    "phoneme_index_map": {
+        "0": [0, 1], "1": [0, 2], "2": [1, 1],
+        "3": [1, 2], "4": [2, 1], "5": [2, 2]
+    }
+}
diff --git a/voxtream/configs/speaking_rate.json b/voxtream/configs/speaking_rate.json
new file mode 100644
index 0000000..05b403f
--- /dev/null
+++ b/voxtream/configs/speaking_rate.json
@@ -0,0 +1,51 @@
+{
+    "1": {
+        "duration_state": [
+            52, 15, 1, 2, 1, 3
+        ],
+        "weight": 3.0,
+        "cfg_gamma": 1.25
+    },
+    "2": {
+        "duration_state": [
+            44, 24, 2, 2, 1, 1
+        ],
+        "weight": 3.0,
+        "cfg_gamma": 1.25
+    },
+    "3": {
+        "duration_state": [
+            24, 23, 2, 3, 1, 1
+        ],
+        "weight": 5.0,
+        "cfg_gamma": 1.5
+    },
+    "4": {
+        "duration_state": [
+            40, 59, 5, 12, 1, 2
+        ],
+        "weight": 5.0,
+        "cfg_gamma": 1.5
+    },
+    "5": {
+        "duration_state": [
+            21, 43, 4, 13, 1, 2
+        ],
+        "weight": 7.0,
+        "cfg_gamma": 2.0
+    },
+    "6": {
+        "duration_state": [
+            27, 69, 6, 30, 2, 7
+        ],
+        "weight": 7.0,
+        "cfg_gamma": 2.0
+    },
+    "7": {
+        "duration_state": [
+            6, 18, 2, 9, 1, 3
+        ],
+        "weight": 10.0,
+        "cfg_gamma": 2.5
+    }
+}
\ No newline at end of file
diff --git a/voxtream/dataset.py b/voxtream/dataset.py
index 7df0cf5..fc9de1d 100644
--- a/voxtream/dataset.py
+++ b/voxtream/dataset.py
@@ -96,17 +96,18 @@ def __init__(
     ):
         # Load data
         data = {}
+        data_chunks = {}
         for name, dataset in datasets.items():
             for key, path in dataset.items():
                 if path is None:
                     data[key] = None
                     continue
-                val = np.load(base_dir / name / path, allow_pickle=True)[()]
-                if key not in data:
-                    data[key] = val
-                else:
-                    data[key] = np.concatenate((data[key], val), axis=0)
+                val = np.load(base_dir / name / path, allow_pickle=False)
+                data_chunks.setdefault(key, []).append(val)
+
+        for key, chunks in data_chunks.items():
+            data[key] = chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0)
 
 
         # Store loaded arrays as attributes
         for key, val in data.items():
@@ -318,14 +319,11 @@ def __getitem__(self, idx: int) -> Tuple[np.ndarray, ...]:
     )
 
 
-from functools import partial
-
-import hydra
-from torch.utils.data import DataLoader
+def main(cfg):
+    from functools import partial
+    from torch.utils.data import DataLoader
 
 
-@hydra.main(version_base=None, config_path="../configs", config_name="train.yaml")
-def main(cfg):
     train_dataset = TrainDataset(
         base_dir=Path(cfg.dataset_base_dir),
         phone_vocab_size=cfg.model.phone_vocab_size,
@@ -351,4 +349,9 @@ def main(cfg):
 
 
 if __name__ == "__main__":
+    import hydra
+
+    main = hydra.main(version_base=None, config_path="../configs", config_name="train.yaml")(
+        main
+    )
     main()
diff --git a/voxtream/generator.py b/voxtream/generator.py
index 3c3507c..f5fd246 100644
--- a/voxtream/generator.py
+++ b/voxtream/generator.py
@@ -1,7 +1,7 @@
 import logging
 import time
 from pathlib import Path
-from typing import Dict, Generator, Iterator
+from typing import Dict, Generator, Iterator, Optional
 
 import numpy as np
 import torch
@@ -145,12 +145,12 @@ def _ensure_mimi_streaming(self, batch_size: int = 1) -> None:
     def generate_stream(
         self,
         prompt_audio_path: Path,
-        text: str | Generator[str, None, None],
-        speaking_rate: Iterator[float] = None,
-        enhance_prompt: bool = None,
-        apply_vad: bool = None,
+        text: str | Iterator[str | None],
+        speaking_rate: Optional[Iterator[float]] = None,
+        enhance_prompt: bool | None = None,
+        apply_vad: bool | None = None,
         return_progress: bool = False,
-        min_streaming_rtf: float = None,
+        min_streaming_rtf: float | None = None,
     ) -> Generator[
         tuple[np.ndarray, float] | tuple[np.ndarray, float, Dict], None, None
     ]:
diff --git a/voxtream/model.py b/voxtream/model.py
index 6b777fe..6dd88ff 100644
--- a/voxtream/model.py
+++ b/voxtream/model.py
@@ -48,13 +48,13 @@ def __init__(self, config: ModelConfig, compile_forward: bool = False):
         self._dep_former_init = self._dep_former
 
         self.phone_former, phone_former_dim = prepare_transformer(
-            MODEL_POOL[config.phone_former]
+            MODEL_POOL[config.phone_former]()
         )
         self.temp_former, temp_former_dim = prepare_transformer(
-            MODEL_POOL[config.temp_former]
+            MODEL_POOL[config.temp_former]()
         )
         self.dep_former, dep_former_dim = prepare_transformer(
-            MODEL_POOL[config.dep_former]
+            MODEL_POOL[config.dep_former]()
         )
         self.phone_embeddings = nn.Embedding(config.phone_vocab_size, phone_former_dim)
diff --git a/voxtream/run.py b/voxtream/run.py
index 67127e3..8d3d1ba 100644
--- a/voxtream/run.py
+++ b/voxtream/run.py
@@ -1,26 +1,33 @@
 import argparse
-import json
 from itertools import repeat
 from pathlib import Path
+from typing import Any, cast
 
 import numpy as np
 import soundfile as sf
 
-from voxtream.config import SpeechGeneratorConfig
+from voxtream.config import (
+    load_generator_config,
+    load_speaking_rate_config,
+    resolve_data_path,
+)
 from voxtream.generator import SpeechGenerator
 from voxtream.utils.generator import (
-    existing_file,
     set_seed,
     text_generator,
 )
 
 
+def _audio_frame(result: tuple[Any, ...]) -> np.ndarray[Any, Any]:
+    return cast(np.ndarray[Any, Any], result[0])
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "-pa",
         "--prompt-audio",
-        type=existing_file,
+        type=Path,
         help="Path to the prompt audio file (5-10 sec of target voice. Max 20 sec).",
         default="assets/audio/english_male.wav",
     )
@@ -40,13 +47,13 @@ def main():
     parser.add_argument(
         "-c",
         "--config",
-        type=existing_file,
+        type=Path,
         help="Path to the config file",
         default="configs/generator.json",
     )
     parser.add_argument(
         "--spk-rate-config",
-        type=existing_file,
+        type=Path,
         help="Path to the speaking rate config file",
         default="configs/speaking_rate.json",
     )
@@ -68,29 +75,29 @@ def main():
     args = parser.parse_args()
     set_seed()
 
-    with open(args.config) as f:
-        config = SpeechGeneratorConfig(**json.load(f))
-
-    with open(args.spk_rate_config) as f:
-        spk_rate_config = json.load(f)
+    config = load_generator_config(args.config)
+    spk_rate_config = load_speaking_rate_config(args.spk_rate_config)
 
     speech_generator = SpeechGenerator(config, spk_rate_config)
 
     if args.text is None:
         speech_generator.logger.error("No text provided.")
-        exit(0)
+        raise SystemExit(2)
 
     speaking_rate = repeat(args.spk_rate) if args.spk_rate is not None else None
     speech_stream = speech_generator.generate_stream(
-        prompt_audio_path=Path(args.prompt_audio),
+        prompt_audio_path=resolve_data_path(
+            args.prompt_audio, "assets/audio/english_male.wav"
+        ),
         text=text_generator(args.text) if args.full_stream else args.text,
         speaking_rate=speaking_rate,
         enhance_prompt=args.prompt_enhancement,
     )
 
-    audio_frames = [audio_frame for audio_frame, _ in speech_stream]
-    sf.write(args.output, np.concatenate(audio_frames), config.mimi_sr)
+    with sf.SoundFile(args.output, "w", samplerate=config.mimi_sr, channels=1) as f:
+        for result in speech_stream:
+            f.write(_audio_frame(result))
 
     speech_generator.logger.info(f"Audio saved to {args.output}")
diff --git a/voxtream/server.py b/voxtream/server.py
index 1118eeb..6744efe 100644
--- a/voxtream/server.py
+++ b/voxtream/server.py
@@ -36,42 +36,81 @@
 import asyncio
 import base64
+import binascii
 import json
+import logging
 import os
 import re
 import tempfile
+import threading
 from contextlib import asynccontextmanager
 from pathlib import Path
-from typing import Iterator, Optional
+from typing import Any, Iterator, Optional, cast
 
 import numpy as np
 import uvicorn
 from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
 from fastapi.responses import HTMLResponse
 
-from voxtream.generator import SpeechGenerator, SpeechGeneratorConfig  # type: ignore
+from voxtream.config import (  # type: ignore
+    SpeechGeneratorConfig,
+    load_generator_config,
+    load_speaking_rate_config,
+)
+from voxtream.generator import SpeechGenerator  # type: ignore
 from voxtream.utils.generator import set_seed  # type: ignore
 
 # ---------- Helpers ----------
 
 DATA_URL_RE = re.compile(r"^data:.*?;base64,(.*)$", re.IGNORECASE)
+ALLOWED_PROMPT_SUFFIXES = {".flac", ".m4a", ".mp3", ".ogg", ".wav"}
+MAX_PROMPT_AUDIO_BYTES = 25 * 1024 * 1024
+MAX_TEXT_CHUNK_CHARS = 4_000
+MAX_INITIAL_TEXT_CHARS = 20_000
+MAX_GENERATION_WORKERS = 1
+LOGGER = logging.getLogger("voxtream.server")
 
 
-def _b64_to_bytes(s: str) -> bytes:
+def _b64_to_bytes(s: str, max_bytes: int = MAX_PROMPT_AUDIO_BYTES) -> bytes:
     m = DATA_URL_RE.match(s)
     payload = m.group(1) if m else s
-    return base64.b64decode(payload)
+    compact_payload = "".join(payload.split())
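+    # Base64 packs 3 raw bytes into 4 characters, so ceil(max_bytes / 3) * 4 is
+    # the longest encoded payload that could still decode to max_bytes.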
+    if len(compact_payload) > ((max_bytes + 2) // 3) * 4:
+        raise ValueError(f"Prompt audio upload exceeds {max_bytes} bytes")
+    try:
+        raw = base64.b64decode(compact_payload, validate=True)
+    except binascii.Error as err:
+        raise ValueError("prompt_audio_b64 is not valid base64") from err
+    if len(raw) > max_bytes:
+        raise ValueError(f"Prompt audio upload exceeds {max_bytes} bytes")
+    return raw
+
+
+def _validate_prompt_audio_path(prompt_audio_path: str, prompt_root: Path) -> Path:
+    path = Path(prompt_audio_path).expanduser().resolve()
+    prompt_root = prompt_root.expanduser().resolve()
+    if not path.is_relative_to(prompt_root):
+        raise ValueError(f"prompt_audio_path must be inside {prompt_root}")
+    if not path.is_file():
+        raise ValueError("prompt_audio_path must point to an existing file")
+    if path.suffix.lower() not in ALLOWED_PROMPT_SUFFIXES:
+        raise ValueError(
+            f"prompt_audio_path must use one of {sorted(ALLOWED_PROMPT_SUFFIXES)}"
+        )
+    if path.stat().st_size > MAX_PROMPT_AUDIO_BYTES:
+        raise ValueError(f"Prompt audio file exceeds {MAX_PROMPT_AUDIO_BYTES} bytes")
+    return path
 
 
 def _ensure_prompt_audio_file(
-    prompt_audio_path: Optional[str], prompt_audio_b64: Optional[str]
-) -> Path:
+    prompt_audio_path: Optional[str], prompt_audio_b64: Optional[str], prompt_root: Path
+) -> tuple[Path, bool]:
     """
     Returns a filesystem Path to the prompt audio.
     If base64 is provided, writes a temp file with the decoded bytes (wav/ogg input supported by voxtream).
     """
     if prompt_audio_path:
-        return Path(prompt_audio_path)
+        return _validate_prompt_audio_path(prompt_audio_path, prompt_root), False
 
     if not prompt_audio_b64:
         raise ValueError(
             "Either 'prompt_audio_path' or 'prompt_audio_b64' must be provided."
@@ -84,7 +123,7 @@ def _ensure_prompt_audio_file(
     fd, tmp = tempfile.mkstemp(prefix="voxtream_prompt_", suffix=suffix)
     with os.fdopen(fd, "wb") as f:
         f.write(raw)
-    return Path(tmp)
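+    # The boolean marks a temp file that the caller is responsible for deleting.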
+    return Path(tmp), True
 
 
 def get_generator_from_state(
@@ -101,15 +140,13 @@ def get_generator_from_state(
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     set_seed()
-    config_path = "configs/generator.json"
-    with open(config_path) as f:
-        config = SpeechGeneratorConfig(**json.load(f))
-    spk_rate_config_path = "configs/speaking_rate.json"
-    with open(spk_rate_config_path) as f:
-        spk_rate_config = json.load(f)
+    config = load_generator_config()
+    spk_rate_config = load_speaking_rate_config()
 
     app.state.config = config
     app.state.speech_generator = SpeechGenerator(config, spk_rate_config)
+    app.state.generation_semaphore = threading.BoundedSemaphore(MAX_GENERATION_WORKERS)
+    app.state.prompt_root = Path(os.environ.get("VOXTREAM_PROMPT_ROOT", ".")).resolve()
 
     try:
         yield
@@ -198,16 +235,51 @@ async def synthesis(ws: WebSocket):
             await ws.close()
             return
 
-        prompt_audio_path: Optional[str] = init.get("prompt_audio_path")
-        prompt_audio_b64: Optional[str] = init.get("prompt_audio_b64")
-        text_initial: Optional[str] = init.get("text")
+        prompt_audio_path_value = init.get("prompt_audio_path")
+        prompt_audio_b64_value = init.get("prompt_audio_b64")
+        text_initial_value = init.get("text")
+        if prompt_audio_path_value is not None and not isinstance(prompt_audio_path_value, str):
+            await ws.send_text(
+                json.dumps({"type": "error", "message": "prompt_audio_path must be a string"})
+            )
+            await ws.close()
+            return
+        if prompt_audio_b64_value is not None and not isinstance(prompt_audio_b64_value, str):
+            await ws.send_text(
+                json.dumps({"type": "error", "message": "prompt_audio_b64 must be a string"})
+            )
+            await ws.close()
+            return
+        if text_initial_value is not None and not isinstance(text_initial_value, str):
+            await ws.send_text(
+                json.dumps({"type": "error", "message": "text must be a string"})
+            )
+            await ws.close()
+            return
+
+        prompt_audio_path: Optional[str] = prompt_audio_path_value
+        prompt_audio_b64: Optional[str] = prompt_audio_b64_value
+        text_initial: Optional[str] = text_initial_value
+        if text_initial is not None and len(text_initial) > MAX_INITIAL_TEXT_CHARS:
+            await ws.send_text(
+                json.dumps({"type": "error", "message": "Initial text is too long"})
+            )
+            await ws.close()
+            return
         full_stream: bool = bool(init.get("full_stream", False))
 
         # Optional: override sample rate (if you down/up-sample on client); we always generate at config.mimi_sr.
         sample_rate = config.mimi_sr
 
+        temp_prompt_path: Optional[Path] = None
        try:
-            prompt_path = _ensure_prompt_audio_file(prompt_audio_path, prompt_audio_b64)
+            prompt_path, is_temp_prompt = _ensure_prompt_audio_file(
+                prompt_audio_path,
+                prompt_audio_b64,
+                prompt_root=ws.app.state.prompt_root,
+            )
+            if is_temp_prompt:
+                temp_prompt_path = prompt_path
        except Exception as e:
             await ws.send_text(
                 json.dumps({"type": "error", "message": f"Invalid prompt audio: {e}"})
@@ -229,18 +301,17 @@ async def synthesis(ws: WebSocket):
 
         # --- 2) Prepare text source (string OR generator) ---
         # If full_stream, we build an iterator fed by subsequent websocket messages.
-        text_source: str | Iterator[str]
-        queue: Optional["asyncio.Queue[Optional[str]]"] = None
+        text_source: str | Iterator[str | None]
 
         # --- Prepare text source ---
         if full_stream:
-            queue: "asyncio.Queue[Optional[str]]" = asyncio.Queue()
+            text_queue: "asyncio.Queue[Optional[str]]" = asyncio.Queue()
             feeder_done = asyncio.Event()
 
             async def recv_text_chunks():
                 try:
                     if text_initial:
-                        await queue.put(text_initial)
+                        await text_queue.put(text_initial)
                     while True:
                         msg = await ws.receive()
                         if msg["type"] == "websocket.disconnect":
                             break
@@ -250,29 +321,34 @@ async def recv_text_chunks():
                                 payload = json.loads(msg["text"])
                             except json.JSONDecodeError:
                                 # treat raw text as a chunk
-                                await queue.put(msg["text"])
+                                await text_queue.put(msg["text"])
                                 continue
                             ev = payload.get("event")
                             if ev == "text":
                                 chunk = payload.get("chunk", "")
+                                if not isinstance(chunk, str):
+                                    continue
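+                                # An oversized chunk ends the stream instead of
+                                # buffering unbounded text.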
+                                if len(chunk) > MAX_TEXT_CHUNK_CHARS:
+                                    await text_queue.put(None)
+                                    break
                                 if chunk:
-                                    await queue.put(chunk)
+                                    await text_queue.put(chunk)
                             elif ev == "eot":
                                 break
                         # ignore binary
                 finally:
-                    await queue.put(None)  # signal end-of-text
+                    await text_queue.put(None)  # signal end-of-text
                     feeder_done.set()
 
             asyncio.create_task(recv_text_chunks())
-            text_source = _QueueIterator(queue, loop)  # <— pass loop in
+            text_source = _QueueIterator(text_queue, loop)  # <— pass loop in
         else:
             # ... (unchanged one-shot fallback)
             text_source = text_initial or ""
 
         # --- Streaming out audio (reworked worker error path) ---
-        audio_q: "asyncio.Queue[tuple[Optional[np.ndarray], Optional[str]]]" = (
+        audio_q: "asyncio.Queue[tuple[np.ndarray[Any, Any] | None, str | None]]" = (
             asyncio.Queue(maxsize=8)
         )
         done_evt = asyncio.Event()
@@ -280,19 +356,34 @@ async def recv_text_chunks():
         def _run_generator():
             return speech_generator.generate_stream(
                 prompt_audio_path=prompt_path,
-                text=text_source,
+                text=iter(text_source) if not isinstance(text_source, str) else text_source,
             )
 
         def _worker():
             err: Optional[str] = None
+            acquired_generation_slot = False
             try:
-                for audio_frame, _meta in _run_generator():
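+                # Non-blocking acquire: a concurrent request is rejected with an
+                # error instead of queueing behind the running generation.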
+                semaphore = ws.app.state.generation_semaphore
+                acquired = semaphore.acquire(blocking=False)
+                if not acquired:
+                    err = "Server is busy; try again later."
+                    return
+                acquired_generation_slot = True
+                for result in _run_generator():
+                    audio_frame = result[0]
                     asyncio.run_coroutine_threadsafe(
-                        audio_q.put((audio_frame, None)), loop
+                        audio_q.put((cast(np.ndarray[Any, Any], audio_frame), None)), loop
                     ).result()
             except Exception as e:
                 err = str(e)
             finally:
+                if acquired_generation_slot:
+                    ws.app.state.generation_semaphore.release()
+                if temp_prompt_path is not None:
+                    try:
+                        temp_prompt_path.unlink(missing_ok=True)
+                    except OSError:
+                        LOGGER.warning("Failed to delete temp prompt %s", temp_prompt_path)
                 # Always send poison pill; attach error message once
                 asyncio.run_coroutine_threadsafe(
                     audio_q.put((None, err)), loop
                 ).result()
@@ -300,22 +391,8 @@ def _worker():
                 # done_evt.set() is not a coroutine; schedule it thread-safely
                 loop.call_soon_threadsafe(done_evt.set)
 
-        import threading
-
         threading.Thread(target=_worker, daemon=True).start()
 
-        # Tell client config before streaming
-        await ws.send_text(
-            json.dumps(
-                {
-                    "type": "config",
-                    "sample_rate": config.mimi_sr,
-                    "dtype": "float32",
-                    "channels": 1,
-                }
-            )
-        )
-
         # Pump frames out; if we receive an error, send it (if socket still open) then end.
         while True:
             frame, err = await audio_q.get()
@@ -326,8 +403,10 @@ def _worker():
                     await ws.send_text(
                         json.dumps({"type": "error", "message": err})
                     )
-                except Exception:
-                    pass
+                except WebSocketDisconnect:
+                    return
+                except Exception as send_err:
+                    LOGGER.debug("Failed to send websocket error", exc_info=send_err)
                 break
             if frame.dtype != np.float32:
                 frame = frame.astype(np.float32, copy=False)
@@ -337,12 +416,14 @@ def _worker():
 
         # graceful end
         try:
             await ws.send_text(json.dumps({"type": "eos"}))
-        except Exception:
-            pass
+        except WebSocketDisconnect:
+            return
+        except Exception as send_err:
+            LOGGER.debug("Failed to send websocket eos", exc_info=send_err)
         try:
             await ws.close()
-        except Exception:
-            pass
+        except Exception as close_err:
+            LOGGER.debug("Failed to close websocket", exc_info=close_err)
 
     except WebSocketDisconnect:
         return
@@ -351,8 +432,8 @@ def _worker():
         try:
             await ws.send_text(json.dumps({"type": "error", "message": str(e)}))
             await ws.close()
-        except Exception:
-            pass
+        except Exception as send_err:
+            LOGGER.debug("Failed to send last-ditch websocket error", exc_info=send_err)
 
 
 def main():
diff --git a/voxtream/train.py b/voxtream/train.py
index 0595646..f2a6c12 100644
--- a/voxtream/train.py
+++ b/voxtream/train.py
@@ -35,7 +35,11 @@ def main(cfg: DictConfig) -> None:
     if cfg.dataset_base_dir is not None:
         base_dir = cfg.dataset_base_dir
     else:
-        base_dir = snapshot_download(cfg.dataset_repo, repo_type="dataset")
+        base_dir = snapshot_download(
+            cfg.dataset_repo,
+            repo_type="dataset",
+            revision=cfg.get("dataset_revision"),
+        )
 
     train_dataset = TrainDataset(
         base_dir=Path(base_dir),
@@ -53,6 +57,9 @@ def main(cfg: DictConfig) -> None:
         num_workers=cfg.num_workers,
         collate_fn=collate_func,
         drop_last=True,
+        pin_memory=torch.cuda.is_available(),
+        persistent_workers=cfg.num_workers > 0,
+        prefetch_factor=cfg.get("prefetch_factor", 2) if cfg.num_workers > 0 else None,
     )
 
     # Callbacks
diff --git a/voxtream/trainer.py b/voxtream/trainer.py
index 54719ad..31f470f 100644
--- a/voxtream/trainer.py
+++ b/voxtream/trainer.py
@@ -28,7 +28,11 @@ def __init__(self, config: DictConfig) -> None:
         self.model = Model(model_config, config.compile_forward)
 
         if config.model_weight_path is not None:
-            weights = torch.load(config.model_weight_path, map_location="cpu")
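+            # weights_only=True keeps torch.load from unpickling arbitrary
+            # objects embedded in the checkpoint.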
+            weights = torch.load(
+                config.model_weight_path,
+                map_location="cpu",
+                weights_only=True,
+            )
             state_dict = {}
             # Remove torch lightning 'model.' prefix
             for k, v in weights["state_dict"].items():
diff --git a/voxtream/utils/dataset/clap_ipa_aligner.py b/voxtream/utils/dataset/clap_ipa_aligner.py
index e6c9ef2..6293fab 100644
--- a/voxtream/utils/dataset/clap_ipa_aligner.py
+++ b/voxtream/utils/dataset/clap_ipa_aligner.py
@@ -81,6 +81,9 @@ def __getitem__(self, idx):
         shuffle=False,
         num_workers=args.num_workers,
         collate_fn=collate_varlen,
+        pin_memory=True,
+        persistent_workers=args.num_workers > 0,
+        prefetch_factor=2 if args.num_workers > 0 else None,
     )
 
     os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
diff --git a/voxtream/utils/dataset/mimi.py b/voxtream/utils/dataset/mimi.py
index fdb5da9..bc21551 100644
--- a/voxtream/utils/dataset/mimi.py
+++ b/voxtream/utils/dataset/mimi.py
@@ -104,7 +104,13 @@ def __getitem__(self, idx):
         sep=args.sep,
     )
     dataloader = DataLoader(
-        dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
+        dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        pin_memory=torch.cuda.is_available(),
+        persistent_workers=args.num_workers > 0,
+        prefetch_factor=2 if args.num_workers > 0 else None,
     )
 
     device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
diff --git a/voxtream/utils/dataset/speaker_encoder.py b/voxtream/utils/dataset/speaker_encoder.py
index a96ee04..888fb5a 100644
--- a/voxtream/utils/dataset/speaker_encoder.py
+++ b/voxtream/utils/dataset/speaker_encoder.py
@@ -99,6 +99,9 @@ def __getitem__(self, idx):
         batch_size=args.batch_size,
         shuffle=False,
         num_workers=args.num_workers,
+        pin_memory=torch.cuda.is_available(),
+        persistent_workers=args.num_workers > 0,
+        prefetch_factor=2 if args.num_workers > 0 else None,
     )
 
     device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
diff --git a/voxtream/utils/generator/prompt.py b/voxtream/utils/generator/prompt.py
index fd4121a..cc0ba4b 100644
--- a/voxtream/utils/generator/prompt.py
+++ b/voxtream/utils/generator/prompt.py
@@ -1,3 +1,5 @@
+import hashlib
+import json
 from pathlib import Path
 from typing import Tuple
 
@@ -12,6 +14,68 @@
 from voxtream.utils.generator.helpers import autocast_ctx
 
 
+PROMPT_CACHE_VERSION = 1
+
+
+def _prompt_cache_path(prompt_audio_path: Path) -> Path:
+    return prompt_audio_path.parent / f"{prompt_audio_path.stem}.prompt.npz"
+
+
+def _prompt_cache_metadata(
+    prompt_audio_path: Path,
+    config: SpeechGeneratorConfig,
+    enhance_prompt: bool,
+    apply_vad: bool,
+) -> dict[str, object]:
+    stat = prompt_audio_path.stat()
+    key_material = {
+        "version": PROMPT_CACHE_VERSION,
+        "path": str(prompt_audio_path.resolve()),
+        "mtime_ns": stat.st_mtime_ns,
+        "size": stat.st_size,
+        "model_repo": config.model_repo,
+        "model_name": config.model_name,
+        "mimi_repo": config.mimi_repo,
+        "mimi_name": config.mimi_name,
+        "spk_enc_repo": config.spk_enc_repo,
+        "spk_enc_model": config.spk_enc_model,
+        "spk_enc_model_name": config.spk_enc_model_name,
+        "num_codebooks": config.num_codebooks,
+        "audio_delay_frames": config.audio_delay_frames,
+        "mimi_sr": config.mimi_sr,
+        "spk_enc_sr": config.spk_enc_sr,
+        "enhance_prompt": enhance_prompt,
+        "apply_vad": apply_vad,
+        "min_speech_seg_sec": config.min_speech_seg_sec,
+    }
+    encoded = json.dumps(key_material, sort_keys=True, separators=(",", ":")).encode()
+    return {"cache_key": hashlib.sha256(encoded).hexdigest(), **key_material}
+
+
+def _load_prompt_cache(prompt_path: Path, expected_metadata: dict[str, object]):
+    with np.load(prompt_path, allow_pickle=False) as prompt_data:
+        if "metadata" not in prompt_data:
+            return None
+        metadata = json.loads(str(prompt_data["metadata"].item()))
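+        # A cache_key mismatch means the prompt file or the relevant config
+        # changed since the cache was written; callers treat None as stale.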
+        if metadata.get("cache_key") != expected_metadata["cache_key"]:
+            return None
+        return prompt_data["audio_tokens"], prompt_data["spk_embedding"]
+
+
+def _save_prompt_cache(
+    prompt_path: Path,
+    metadata: dict[str, object],
+    audio_tokens: torch.Tensor,
+    spk_embedding: torch.Tensor,
+) -> None:
+    np.savez_compressed(
+        prompt_path,
+        metadata=np.array(json.dumps(metadata, sort_keys=True)),
+        audio_tokens=audio_tokens.cpu().numpy(),
+        spk_embedding=spk_embedding.to(device="cpu", dtype=torch.float16).numpy(),
+    )
+
+
 def encode_audio_prompt(
     waveform: torch.Tensor,
     orig_sr: int,
@@ -108,6 +172,9 @@ def extract_speech_frames(
             waveform_16k[:, int(start * vad_sr) : int(end * vad_sr)]
         )
 
+    if not valid_segments:
+        raise ValueError("No speech detected in acoustic prompt")
+
     waveform_vad = torch.cat(valid_segments, dim=1)
     waveform_vad_16k = torch.cat(valid_segments_16k, dim=1)
     return waveform_vad, waveform_vad_16k
@@ -124,18 +191,35 @@ def prepare_prompt(
     spk_enc: torch.nn.Module,
     sidon_se,
     vad: torch.nn.Module,
-    enhance_prompt: bool = None,
-    apply_vad: bool = None,
+    enhance_prompt: bool | None = None,
+    apply_vad: bool | None = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    prompt_path = prompt_audio_path.parent / f"{prompt_audio_path.stem}.prompt.npy"
+    enhance_prompt_enabled = bool(enhance_prompt)
+    apply_vad_enabled = bool(apply_vad)
+    prompt_path = _prompt_cache_path(prompt_audio_path)
+    cache_metadata = _prompt_cache_metadata(
+        prompt_audio_path,
+        config,
+        enhance_prompt=enhance_prompt_enabled,
+        apply_vad=apply_vad_enabled,
+    )
     if config.cache_prompt and prompt_path.exists():
-        prompt_data = np.load(prompt_path, allow_pickle=True).item()
-        audio_tokens = torch.from_numpy(prompt_data["audio_tokens"]).to(device)
-        spk_embedding = torch.from_numpy(prompt_data["spk_embedding"]).to(
-            device, dtype=dtype
-        )
+        cached_prompt = _load_prompt_cache(prompt_path, cache_metadata)
+        if cached_prompt is not None:
+            audio_tokens_np, spk_embedding_np = cached_prompt
+            audio_tokens = torch.from_numpy(audio_tokens_np).to(device)
+            spk_embedding = torch.from_numpy(spk_embedding_np).to(device, dtype=dtype)
+        else:
+            logger.warning("Ignoring stale prompt cache at %s", prompt_path)
+            prompt_path.unlink(missing_ok=True)
+            audio_tokens = None
+            spk_embedding = None
     else:
-        waveform, orig_sr = torchaudio.load(prompt_audio_path)
+        audio_tokens = None
+        spk_embedding = None
+
+    if audio_tokens is None or spk_embedding is None:
+        waveform, orig_sr = torchaudio.load(str(prompt_audio_path))
         if waveform.shape[0] != 1:
             logger.warning(
                 f"Prompt audio has {waveform.shape[0]} channels; converting to mono by averaging."
@@ -143,7 +227,7 @@
             waveform = waveform.mean(dim=0, keepdim=True)
 
         waveform_16k = None
-        if apply_vad:
+        if apply_vad_enabled:
            waveform_16k = torchaudio.functional.resample(
                 waveform, orig_sr, config.spk_enc_sr
             )
@@ -156,9 +240,11 @@ def prepare_prompt(
             min_speech_seg_sec=config.min_speech_seg_sec,
         )
 
-        assert (
-            waveform.shape[1] >= orig_sr * config.min_prompt_sec
-        ), f"Acoustic prompt is too short ({waveform.shape[1] / orig_sr:.2f} seconds); should be at least {config.min_prompt_sec} seconds."
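+        # Raise instead of assert so the check still runs under python -O.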
+        if waveform.shape[1] < orig_sr * config.min_prompt_sec:
+            raise ValueError(
+                f"Acoustic prompt is too short ({waveform.shape[1] / orig_sr:.2f} seconds); "
+                f"should be at least {config.min_prompt_sec} seconds."
+            )
         if waveform.shape[1] > orig_sr * config.max_prompt_sec:
             logger.warning(
                 f"Prompt audio is longer than {config.max_prompt_sec} seconds; trimming."
@@ -168,7 +254,7 @@ def prepare_prompt(
             )
         waveform_mimi = waveform
         sample_rate = orig_sr
 
-        if enhance_prompt:
+        if enhance_prompt_enabled:
             if waveform_16k is None:
                 waveform_16k = torchaudio.functional.resample(
                     waveform, orig_sr, config.spk_enc_sr
@@ -199,15 +285,7 @@ def prepare_prompt(
             )
 
         if config.cache_prompt:
-            np.save(
-                prompt_path,
-                {
-                    "audio_tokens": audio_tokens.cpu().numpy(),
-                    "spk_embedding": spk_embedding.to(
-                        device="cpu", dtype=torch.float16
-                    ).numpy(),
-                },
-            )
+            _save_prompt_cache(prompt_path, cache_metadata, audio_tokens, spk_embedding)
 
     audio_tokens = delay_audio_tokens(
         audio_tokens=audio_tokens,
diff --git a/voxtream/utils/generator/setup.py b/voxtream/utils/generator/setup.py
index 94583df..3a27172 100644
--- a/voxtream/utils/generator/setup.py
+++ b/voxtream/utils/generator/setup.py
@@ -1,6 +1,6 @@
 import json
 from dataclasses import fields
-from typing import Dict, Tuple
+from typing import Any, Tuple, cast
 
 import torch
 from huggingface_hub import hf_hub_download
@@ -16,7 +16,7 @@ def load_generator_model(
     device: str,
     dtype: torch.dtype,
     batch_size: int,
-) -> Tuple[Model, Dict]:
+) -> Tuple[Model, dict[str, Any]]:
     model_config_path = hf_hub_download(
         config.model_repo, config.model_config_name, token=config.hf_token
     )
@@ -70,14 +70,22 @@ def load_mimi_model(
 def load_speaker_encoder(
     config: SpeechGeneratorConfig, device: str, dtype: torch.dtype
 ) -> torch.nn.Module:
-    model = torch.hub.load(
-        config.spk_enc_repo,
-        config.spk_enc_model,
-        model_name=config.spk_enc_model_name,
-        train_type=config.spk_enc_train_type,
-        dataset=config.spk_enc_dataset,
-        trust_repo=True,
-        verbose=False,
+    if config.spk_enc_repo != "IDRnD/ReDimNet":
+        raise ValueError(
+            "Refusing to load untrusted speaker encoder repo. "
+            "Only IDRnD/ReDimNet is allowed by default."
+        )
+    model = cast(
+        torch.nn.Module,
+        torch.hub.load(
+            config.spk_enc_repo,
+            config.spk_enc_model,
+            model_name=config.spk_enc_model_name,
+            train_type=config.spk_enc_train_type,
+            dataset=config.spk_enc_dataset,
+            trust_repo=True,
+            verbose=False,
+        ),
     ).to(device, dtype=dtype)
     model.spec.float()
     model.bn.float()
diff --git a/voxtream/utils/generator/text.py b/voxtream/utils/generator/text.py
index 0c30aa3..debd14d 100644
--- a/voxtream/utils/generator/text.py
+++ b/voxtream/utils/generator/text.py
@@ -1,3 +1,4 @@
+from collections.abc import Iterator
 from dataclasses import dataclass
 from typing import Dict, Generator, List
 
@@ -222,7 +223,7 @@ def prepare_non_streaming_text(
 
 
 def prepare_streaming_text(
-    text_gen: Generator[str, None, None],
+    text_gen: Iterator[str | None],
     phone_tokens: torch.Tensor,
     punct_del_indices: torch.Tensor,
     empty_text_stream: bool,
diff --git a/voxtream/utils/model.py b/voxtream/utils/model.py
index 279cd12..0a91663 100644
--- a/voxtream/utils/model.py
+++ b/voxtream/utils/model.py
@@ -28,17 +28,17 @@ def get_llama3_2(
 
 
 MODEL_POOL = {
-    "phone_former": get_llama3_2(
+    "phone_former": lambda: get_llama3_2(
         num_layers=6, num_heads=8, num_kv_heads=2, embed_dim=1024, intermediate_dim=4096
     ),
-    "temp_former": get_llama3_2(
+    "temp_former": lambda: get_llama3_2(
         num_layers=12,
         num_heads=16,
         num_kv_heads=4,
         embed_dim=1024,
         intermediate_dim=4096,
     ),
-    "dep_former_csm": get_llama3_2(
+    "dep_former_csm": lambda: get_llama3_2(
         num_layers=4, num_heads=8, num_kv_heads=2, embed_dim=1024, intermediate_dim=8192
     ),
 }