diff --git a/fluster/utils.py b/fluster/utils.py index 44b9206f..5d7b5d63 100644 --- a/fluster/utils.py +++ b/fluster/utils.py @@ -21,6 +21,7 @@ import contextlib import hashlib import http.client +import io import os import platform import random @@ -320,11 +321,35 @@ def normalize_path(path: str) -> str: def _read_wav(path: str) -> Tuple[array.array[int], int, int]: - """Load a WAV file and return (interleaved_samples, n_channels, sampwidth). Supports 16 and 32-bit PCM.""" - with wave.open(path, "rb") as w: - n_channels = w.getnchannels() - sampwidth = w.getsampwidth() - raw = w.readframes(w.getnframes()) + """Load a WAV file and return (interleaved_samples, n_channels, sampwidth). + Supports 16 and 32-bit PCM and WAVE_FORMAT_EXTENSIBLE (0xFFFE). + Python < 3.12 rejects WAVE_FORMAT_EXTENSIBLE, so on failure we locate + the fmt chunk and patch only the wFormatTag — the PCM data layout is identical.""" + try: + with wave.open(path, "rb") as w: + return _extract_wav(w) + except wave.Error: + pass + # Locate the fmt chunk in the RIFF structure and patch only wFormatTag + with open(path, "rb") as f: + data = bytearray(f.read()) + pos = 12 # skip RIFF + size + WAVE + while pos < len(data) - 8: + chunk_id = data[pos : pos + 4] + chunk_size = int.from_bytes(data[pos + 4 : pos + 8], "little") + if chunk_id == b"fmt ": + if int.from_bytes(data[pos + 8 : pos + 10], "little") == 0xFFFE: + data[pos + 8 : pos + 10] = b"\x01\x00" + break + pos += 8 + chunk_size + (chunk_size % 2) + with wave.open(io.BytesIO(bytes(data))) as w: + return _extract_wav(w) + + +def _extract_wav(w: wave.Wave_read) -> Tuple[array.array[int], int, int]: + n_channels = w.getnchannels() + sampwidth = w.getsampwidth() + raw = w.readframes(w.getnframes()) if sampwidth == 2: typecode = "h" elif sampwidth == 4: