Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified assets/test/reference_frames.npy
Binary file not shown.
2 changes: 2 additions & 0 deletions configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ hydra:
exp_name: voxtream_train
model_repo: herimor/voxtream2
dataset_repo: herimor/voxtream2-train
dataset_revision: null
dataset_base_dir: null
dep_former_name: dep_former_csm.safetensors
dep_former_weight_path: null
Expand All @@ -21,6 +22,7 @@ precision: bf16-mixed
gradient_clip_val: 0.5
seed: 42
num_workers: 8
prefetch_factor: 2
log_every_n_steps: 50
gpus: -1
logging_interval: step
Expand Down
14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@ build-backend = "setuptools.build_meta"

[tool.setuptools]
license-files = ["LICENSE"]
include-package-data = true

[tool.setuptools.package-data]
voxtream = [
"configs/*.json",
"assets/*.json",
"assets/audio/*.wav",
"assets/benchmark/*.csv",
"assets/test/*.wav",
"assets/test/*.npy",
]

[project]
name = "voxtream"
Expand Down Expand Up @@ -55,6 +66,9 @@ dependencies = [
]

[project.optional-dependencies]
server = ["fastapi", "uvicorn"]
client = ["websockets", "sounddevice"]
benchmark = ["pandas", "tqdm"]
dev = ["black", "isort", "flake8", "mypy", "pytest"]

[project.urls]
Expand Down
42 changes: 42 additions & 0 deletions tests/test_config_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import json

import pytest

from voxtream.config import load_generator_config, load_speaking_rate_config


def test_load_generator_config_from_default_package_or_repo_path():
    """Loading the generator config with no arguments yields the expected values."""
    default_cfg = load_generator_config()

    # Spot-check two fields of the config returned for the default location.
    assert default_cfg.mimi_sr == 24000
    assert default_cfg.model_repo == "herimor/voxtream2"


def test_load_generator_config_rejects_unknown_field(tmp_path):
    """A JSON config carrying an extra key must be rejected with ValueError."""
    # Take the attributes of the config loaded with no arguments and add
    # one key that is not part of it.
    polluted = {**load_generator_config().__dict__, "unexpected": True}

    target = tmp_path / "generator.json"
    target.write_text(json.dumps(polluted))

    with pytest.raises(ValueError, match="unknown fields"):
        load_generator_config(target)


def test_load_generator_config_rejects_invalid_ranges(tmp_path):
    """A config whose ``mimi_frame_ms`` is 0 must be rejected with ValueError."""
    # Reuse the attributes of the config loaded with no arguments, overriding
    # only the field under test.
    out_of_range = {**load_generator_config().__dict__, "mimi_frame_ms": 0}

    target = tmp_path / "generator.json"
    target.write_text(json.dumps(out_of_range))

    with pytest.raises(ValueError, match="mimi_frame_ms"):
        load_generator_config(target)


def test_load_speaking_rate_config_validates_shape(tmp_path):
    """An entry with an empty ``duration_state`` list must raise ValueError."""
    malformed_entry = {"duration_state": [], "weight": 1.0, "cfg_gamma": 1.0}

    target = tmp_path / "speaking_rate.json"
    target.write_text(json.dumps({"1": malformed_entry}))

    with pytest.raises(ValueError, match="duration_state"):
        load_speaking_rate_config(target)
38 changes: 38 additions & 0 deletions tests/test_dataset_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np

from voxtream.dataset import TrainDataset


def test_train_dataset_loads_shards_without_pickle(monkeypatch, tmp_path):
    """TrainDataset must call ``np.load`` with ``allow_pickle`` left at False."""
    recorded = []

    def fake_load(path, allow_pickle=False):
        # Record how the loader was invoked and hand back a one-element shard.
        recorded.append((path, allow_pickle))
        return np.array([1])

    monkeypatch.setattr(np, "load", fake_load)
    monkeypatch.setattr(TrainDataset, "__len__", lambda self: 1)

    init_kwargs = {
        "base_dir": tmp_path,
        "datasets": {
            "one": {"audio_codes": "a.npy"},
            "two": {"audio_codes": "b.npy"},
        },
        "phone_vocab_size": 10,
        "audio_vocab_size": 10,
        "audio_pad_size": 0,
        "num_codebooks": 2,
        "audio_delay_frames": 1,
        "dtype": "int64",
        "audio_window_size": 1,
        "pad_len": 1,
        "semantic_label_pad": 2,
        "num_phones_per_frame": 2,
        "phoneme_index_map": {"0": 1},
        "cfg_prob": 0.0,
        "prompt_length_sec": [1, 2],
        "mimi_tps": 12.5,
    }

    # Allocate first, then initialise explicitly on the bare instance.
    dataset = TrainDataset.__new__(TrainDataset)
    TrainDataset.__init__(dataset, **init_kwargs)

    assert all(pickle_flag is False for _, pickle_flag in recorded)
    # Two datasets, each yielding the fake shard [1].
    assert dataset.audio_codes.tolist() == [1, 1]
11 changes: 11 additions & 0 deletions tests/test_model_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
def test_model_pool_factories_return_distinct_instances():
    """Calling the same MODEL_POOL factory twice must yield two different objects."""
    # Plain function-scope import instead of calling ``__import__`` directly.
    import pytest

    # Skip when either optional dependency cannot be imported.
    pytest.importorskip("torch")
    pytest.importorskip("torchtune")

    # Imported only after the skip checks above have passed.
    from voxtream.utils.model import MODEL_POOL

    first = MODEL_POOL["phone_former"]()
    second = MODEL_POOL["phone_former"]()

    assert first is not second
43 changes: 43 additions & 0 deletions tests/test_prompt_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np
import pytest

from voxtream.config import load_generator_config
from voxtream.utils.generator.prompt import (
_load_prompt_cache,
_prompt_cache_metadata,
_prompt_cache_path,
_save_prompt_cache,
)


def test_prompt_cache_uses_npz_without_pickle(tmp_path):
    """A saved prompt cache must be readable by ``np.load`` with pickling disabled."""
    torch = pytest.importorskip("torch")

    wav_file = tmp_path / "voice.wav"
    wav_file.write_bytes(b"RIFF")

    generator_cfg = load_generator_config()
    cache_file = _prompt_cache_path(wav_file)
    cache_meta = _prompt_cache_metadata(
        wav_file, generator_cfg, enhance_prompt=False, apply_vad=False
    )
    tokens = torch.zeros((1, 2, 3), dtype=torch.int64)
    speaker = torch.ones((1, 4), dtype=torch.float32)

    _save_prompt_cache(cache_file, cache_meta, tokens, speaker)

    expected_entries = {"metadata", "audio_tokens", "spk_embedding"}
    with np.load(cache_file, allow_pickle=False) as archive:
        assert set(archive.files) == expected_entries


def test_prompt_cache_rejects_stale_metadata(tmp_path):
    """Loading a cache with a different ``cache_key`` in the metadata returns None."""
    torch = pytest.importorskip("torch")

    wav_file = tmp_path / "voice.wav"
    wav_file.write_bytes(b"RIFF")

    generator_cfg = load_generator_config()
    cache_file = _prompt_cache_path(wav_file)
    fresh_meta = _prompt_cache_metadata(
        wav_file, generator_cfg, enhance_prompt=False, apply_vad=False
    )
    tokens = torch.zeros((1, 2, 3), dtype=torch.int64)
    speaker = torch.ones((1, 4), dtype=torch.float32)
    _save_prompt_cache(cache_file, fresh_meta, tokens, speaker)

    # Copy the metadata and change only the cache key.
    outdated_meta = dict(fresh_meta)
    outdated_meta.update({"cache_key": "stale"})

    assert _load_prompt_cache(cache_file, outdated_meta) is None
24 changes: 19 additions & 5 deletions tests/test_run_output_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,24 @@ def test_run_main_output_matches_reference(monkeypatch, tmp_path):

captured_write = {}

def fake_sf_write(output_path, data, samplerate):
captured_write["path"] = str(output_path)
captured_write["data"] = np.asarray(data)
captured_write["samplerate"] = samplerate
class FakeSoundFile:
def __init__(self, output_path, mode, samplerate, channels):
captured_write["path"] = str(output_path)
captured_write["mode"] = mode
captured_write["samplerate"] = samplerate
captured_write["channels"] = channels
captured_write["chunks"] = []

monkeypatch.setattr(run.sf, "write", fake_sf_write)
def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
captured_write["data"] = np.concatenate(captured_write["chunks"])

def write(self, data):
captured_write["chunks"].append(np.asarray(data))

monkeypatch.setattr(run.sf, "SoundFile", FakeSoundFile)

output_path = tmp_path / "voxtream_run_ref.wav"
monkeypatch.setattr(
Expand All @@ -53,6 +65,8 @@ def fake_sf_write(output_path, data, samplerate):
run.main()

assert captured_write["path"] == str(output_path)
assert captured_write["mode"] == "w"
assert captured_write["channels"] == 1
assert captured_write["data"].shape == expected_audio.shape
np.testing.assert_allclose(
captured_write["data"], expected_audio, rtol=1e-5, atol=1e-6
Expand Down
58 changes: 58 additions & 0 deletions tests/test_server_prompt_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import base64
from pathlib import Path

import pytest

from voxtream.server import (
MAX_PROMPT_AUDIO_BYTES,
_b64_to_bytes,
_ensure_prompt_audio_file,
_validate_prompt_audio_path,
)


def test_validate_prompt_audio_path_rejects_paths_outside_root(tmp_path):
    """A file that is a sibling of the prompt root must raise ValueError."""
    allowed_dir = tmp_path / "allowed"
    allowed_dir.mkdir()

    # The file exists, but lives next to the allowed directory, not in it.
    stray_file = tmp_path / "outside.wav"
    stray_file.write_bytes(b"RIFF")

    with pytest.raises(ValueError, match="inside"):
        _validate_prompt_audio_path(str(stray_file), allowed_dir)


def test_validate_prompt_audio_path_accepts_file_inside_root(tmp_path):
    """A file inside the prompt root is returned as its resolved path."""
    allowed_dir = tmp_path / "allowed"
    allowed_dir.mkdir()

    voice_file = allowed_dir / "voice.wav"
    voice_file.write_bytes(b"RIFF")

    validated = _validate_prompt_audio_path(str(voice_file), allowed_dir)
    assert validated == voice_file.resolve()


def test_b64_to_bytes_rejects_invalid_base64():
    """A string that is not base64 must raise ValueError."""
    malformed = "not base64?"

    with pytest.raises(ValueError, match="valid base64"):
        _b64_to_bytes(malformed)


def test_b64_to_bytes_rejects_large_payload():
    """A payload one byte over MAX_PROMPT_AUDIO_BYTES must raise ValueError."""
    oversized_raw = b"x" * (MAX_PROMPT_AUDIO_BYTES + 1)
    encoded = base64.b64encode(oversized_raw).decode()

    with pytest.raises(ValueError, match="exceeds"):
        _b64_to_bytes(encoded)


def test_ensure_prompt_audio_file_creates_temp_for_base64_and_marks_temp(tmp_path):
    """A base64 data URI is written to a file and flagged as temporary."""
    encoded = base64.b64encode(b"RIFF").decode()
    data_uri = "data:audio/wav;base64," + encoded

    # No path is supplied (None); only the inline base64 payload.
    prompt_path, is_temp = _ensure_prompt_audio_file(
        None, data_uri, prompt_root=tmp_path
    )

    try:
        assert is_temp is True
        assert isinstance(prompt_path, Path)
        assert prompt_path.exists()
        assert prompt_path.read_bytes() == b"RIFF"
    finally:
        # Remove the created file whether or not the assertions passed.
        prompt_path.unlink(missing_ok=True)
Loading