Skip to content
53 changes: 53 additions & 0 deletions riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,59 @@
# SPDX-License-Identifier: MIT

import argparse
import functools
import sys

import grpc

# Exit codes shared by the CLI scripts. Pipelines that compose these scripts
# rely on a non-zero status to detect failure; see also `cli_main` below.
EXIT_OK = 0
EXIT_GENERIC_ERROR = 1
EXIT_BAD_INPUT = 2 # malformed args, missing file, empty/whitespace text, ...
EXIT_UNAVAILABLE = 3 # gRPC UNAVAILABLE (server down, wrong port, ...)
EXIT_INVALID_ARGUMENT = 4 # gRPC INVALID_ARGUMENT or NOT_FOUND (bad model/lang/voice)
EXIT_INTERRUPTED = 130 # SIGINT


def _grpc_exit_code(error: grpc.RpcError) -> int:
code = error.code() if callable(getattr(error, "code", None)) else None
if code == grpc.StatusCode.UNAVAILABLE:
return EXIT_UNAVAILABLE
if code in (grpc.StatusCode.INVALID_ARGUMENT, grpc.StatusCode.NOT_FOUND):
return EXIT_INVALID_ARGUMENT
return EXIT_GENERIC_ERROR


def cli_main(func):
"""Translate exceptions raised by a CLI ``main`` into consistent exit codes.

Wrapped function may return an int exit code or ``None`` (treated as
``EXIT_OK``). Unhandled exceptions are caught and mapped: gRPC ``RpcError``
via status code, ``FileNotFoundError`` / ``ValueError`` → ``EXIT_BAD_INPUT``,
anything else → ``EXIT_GENERIC_ERROR``. The error is also printed to stderr
so CI logs surface the cause.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
result = func(*args, **kwargs)
return EXIT_OK if result is None else int(result)
except KeyboardInterrupt:
return EXIT_INTERRUPTED
except grpc.RpcError as e:
details = e.details() if callable(getattr(e, "details", None)) else str(e)
print(f"Error: {details}", file=sys.stderr)
return _grpc_exit_code(e)
except (FileNotFoundError, IsADirectoryError, ValueError) as e:
print(f"Error: {e}", file=sys.stderr)
return EXIT_BAD_INPUT
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
return EXIT_GENERIC_ERROR

return wrapper


def validate_grpc_message_size(value):
"""Validate that the GRPC message size is within acceptable limits."""
Expand Down
21 changes: 19 additions & 2 deletions riva/client/audio_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@
import queue
from typing import Dict, Union, Optional

import pyaudio

def _require_pyaudio():
try:
import pyaudio
return pyaudio
except ImportError as e:
raise ImportError(
"pyaudio is required for audio device I/O. Install the system PortAudio "
"headers first (e.g. `apt-get install -y portaudio19-dev` on Debian/Ubuntu, "
"`brew install portaudio` on macOS), then `pip install pyaudio`."
) from e


class MicrophoneStream:
Expand All @@ -20,6 +30,8 @@ def __init__(self, rate: int, chunk: int, device: int = None) -> None:
self.closed = True

def __enter__(self):
pyaudio = _require_pyaudio()
self._pa_module = pyaudio
self._audio_interface = pyaudio.PyAudio()
self._audio_stream = self._audio_interface.open(
format=pyaudio.paInt16,
Expand Down Expand Up @@ -50,7 +62,7 @@ def __exit__(self, type, value, traceback):
def _fill_buffer(self, in_data, frame_count, time_info, status_flags):
"""Continuously collect data from the audio stream into the buffer."""
self._buff.put(in_data)
return None, pyaudio.paContinue
return None, self._pa_module.paContinue

def __next__(self) -> bytes:
if self.closed:
Expand All @@ -76,13 +88,15 @@ def __iter__(self):


def get_audio_device_info(device_id: int) -> Dict[str, Union[int, float, str]]:
pyaudio = _require_pyaudio()
p = pyaudio.PyAudio()
info = p.get_device_info_by_index(device_id)
p.terminate()
return info


def get_default_input_device_info() -> Optional[Dict[str, Union[int, float, str]]]:
pyaudio = _require_pyaudio()
p = pyaudio.PyAudio()
try:
info = p.get_default_input_device_info()
Expand All @@ -93,6 +107,7 @@ def get_default_input_device_info() -> Optional[Dict[str, Union[int, float, str]


def list_output_devices() -> None:
pyaudio = _require_pyaudio()
p = pyaudio.PyAudio()
print("Output audio devices:")
for i in range(p.get_device_count()):
Expand All @@ -104,6 +119,7 @@ def list_output_devices() -> None:


def list_input_devices() -> None:
pyaudio = _require_pyaudio()
p = pyaudio.PyAudio()
print("Input audio devices:")
for i in range(p.get_device_count()):
Expand All @@ -118,6 +134,7 @@ class SoundCallBack:
def __init__(
self, output_device_index: Optional[int], sampwidth: int, nchannels: int, framerate: int,
) -> None:
pyaudio = _require_pyaudio()
self.pa = pyaudio.PyAudio()
self.stream = self.pa.open(
output_device_index=output_device_index,
Expand Down
113 changes: 75 additions & 38 deletions riva/client/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
import queue
import uuid
from typing import Dict, Any, Generator
from typing import Any, Dict, Generator, List, Optional

import requests
import websockets
Expand Down Expand Up @@ -677,73 +677,94 @@ def _safe_update_config(self, config: Dict[str, Any], key: str, value: Any, sect
logger.debug("Updated %s = %s", key, value)

async def _update_session(self, timeout=1):
"""Update session configuration by selectively overriding server defaults."""
"""Update session configuration by sending an override-only payload.

Builds the synthesize_session.update payload directly from CLI args
instead of round-tripping through self.session_config (the response
from POST /v1/realtime/synthesis_sessions). The HTTP response is an
InitialSynthesisSessionConfig and includes id/object/client_secret
fields that are not part of BaseSynthesisSessionConfig (the type the
server expects on the update). Carrying those into the override
payload — and the resulting deep-mutation through _safe_update_config
— was the root cause of --voice silently failing to take effect
against published 26.02 NIMs.
"""
logger.info("Updating session configuration...")
logger.debug("Server default config: %s", self.session_config)

# Create a copy of the session config from server defaults
session_config = self.session_config.copy()

# Track what we're overriding
overrides = []
session_payload: Dict[str, Any] = {}
overrides: List[str] = []
requested_voice: Optional[str] = None
requested_language: Optional[str] = None

# Update input text synthesis - only override if args are provided
if hasattr(self.args, 'language_code') and self.args.language_code:
self._safe_update_config(session_config, "language_code", self.args.language_code, "input_text_synthesis")
# input_text_synthesis: language_code + voice_name
input_text_synthesis: Dict[str, Any] = {}
if getattr(self.args, "language_code", None):
requested_language = self.args.language_code
input_text_synthesis["language_code"] = requested_language
overrides.append("language_code")

if hasattr(self.args, 'voice') and self.args.voice:
self._safe_update_config(session_config, "voice_name", self.args.voice, "input_text_synthesis")
if getattr(self.args, "voice", None):
requested_voice = self.args.voice
input_text_synthesis["voice_name"] = requested_voice
overrides.append("voice_name")
if input_text_synthesis:
session_payload["input_text_synthesis"] = input_text_synthesis

# Update output audio parameters - only override if args are provided
if hasattr(self.args, 'sample_rate_hz') and self.args.sample_rate_hz:
self._safe_update_config(session_config, "sample_rate_hz", self.args.sample_rate_hz, "output_audio_params")
# output_audio_params: sample_rate_hz + audio_format
output_audio_params: Dict[str, Any] = {}
if getattr(self.args, "sample_rate_hz", None):
output_audio_params["sample_rate_hz"] = self.args.sample_rate_hz
overrides.append("sample_rate_hz")

if hasattr(self.args, 'encoding') and self.args.encoding:
self._safe_update_config(session_config, "audio_format", self.args.encoding, "output_audio_params")
if getattr(self.args, "encoding", None):
output_audio_params["audio_format"] = self.args.encoding
overrides.append("audio_format")
if output_audio_params:
session_payload["output_audio_params"] = output_audio_params

# Update custom dictionary - only override if args are provided
if hasattr(self.args, 'custom_dictionary') and self.args.custom_dictionary:
self._safe_update_config(session_config, "custom_dictionary", self.args.custom_dictionary)
if getattr(self.args, "custom_dictionary", None):
session_payload["custom_dictionary"] = self.args.custom_dictionary
overrides.append("custom_dictionary")

# Update zero-shot config - only override if args are provided
if (hasattr(self.args, 'zero_shot_audio_prompt_file') and self.args.zero_shot_audio_prompt_file):
# zero_shot_config: audio bytes + transcript + quality
if getattr(self.args, "zero_shot_audio_prompt_file", None):
zero_shot_config: Dict[str, Any] = {}
try:
with open(self.args.zero_shot_audio_prompt_file, 'rb') as f:
with open(self.args.zero_shot_audio_prompt_file, "rb") as f:
audio_data = f.read()
base64_audio_data = base64.b64encode(audio_data).decode('utf-8')
self._safe_update_config(session_config["zero_shot_config"], "audio_prompt_bytes", base64_audio_data)
logger.info("Zero-shot audio prompt bytes: %s", len(base64_audio_data))
base64_audio_data = base64.b64encode(audio_data).decode("utf-8")
zero_shot_config["audio_prompt_bytes"] = base64_audio_data
logger.info("Zero-shot audio prompt bytes: %s", len(base64_audio_data))
overrides.append("zero_shot_audio_prompt_file")
except Exception as e:
logger.warning("Failed to load zero-shot audio prompt: %s", e)
if hasattr(self.args, 'zero_shot_audio_prompt_transcript') and self.args.zero_shot_audio_prompt_transcript:
self._safe_update_config(session_config["zero_shot_config"], "audio_prompt_transcript", self.args.zero_shot_audio_prompt_transcript)

if getattr(self.args, "zero_shot_audio_prompt_transcript", None):
zero_shot_config["audio_prompt_transcript"] = self.args.zero_shot_audio_prompt_transcript
logger.info("Zero-shot audio prompt transcript: %s", self.args.zero_shot_audio_prompt_transcript)
overrides.append("zero_shot_transcript")
if hasattr(self.args, 'zero_shot_prompt_quality') and self.args.zero_shot_prompt_quality:
self._safe_update_config(session_config["zero_shot_config"], "prompt_quality", self.args.zero_shot_prompt_quality)

if getattr(self.args, "zero_shot_prompt_quality", None):
zero_shot_config["prompt_quality"] = self.args.zero_shot_prompt_quality
logger.info("Zero-shot quality: %s", self.args.zero_shot_prompt_quality)
overrides.append("zero_shot_prompt_quality")

if hasattr(self.args, 'custom_configuration') and self.args.custom_configuration:
if zero_shot_config:
session_payload["zero_shot_config"] = zero_shot_config

if getattr(self.args, "custom_configuration", None):
custom_config = self._parse_custom_configuration(self.args.custom_configuration)
if custom_config:
session_config["custom_configuration"] = custom_config
session_payload["custom_configuration"] = custom_config
overrides.append("custom_configuration")

logger.debug("Overriding parameters: %s", overrides)
logger.info("Overriding session parameters: %s", overrides)
if requested_voice:
logger.info("Requested voice_name=%r", requested_voice)

update_request = {
"event_id": f"event_{uuid.uuid4()}",
"type": "synthesize_session.update",
"session": session_config
"session": session_payload,
}

await self._send_message(update_request)
Expand All @@ -764,6 +785,22 @@ async def _update_session(self, timeout=1):
logger.info("Synthesis session updated successfully")
self.session_config = response_data["session"]
session_updated = True
acked = self.session_config.get("input_text_synthesis", {}) if isinstance(self.session_config, dict) else {}
acked_voice = acked.get("voice_name")
acked_language = acked.get("language_code")
if requested_voice and acked_voice and acked_voice != requested_voice:
logger.warning(
"Server applied voice_name=%r, but --voice requested %r. "
"Synthesis will use the server-applied voice.",
acked_voice, requested_voice,
)
elif requested_voice:
logger.info("Server confirmed voice_name=%r", acked_voice)
if requested_language and acked_language and acked_language != requested_language:
logger.warning(
"Server applied language_code=%r, but --language-code requested %r.",
acked_language, requested_language,
)
elif event_type == "error":
error_info = response_data.get("error", {})
logger.error("Error: %s", error_info.get("message", "Unknown error"))
Expand Down
24 changes: 17 additions & 7 deletions scripts/asr/realtime_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
add_realtime_config_argparse_parameters,
add_connection_argparse_parameters,
)
try:
from riva.client.argparse_utils import cli_main
except ImportError:
def cli_main(func):
return func


def parse_args() -> argparse.Namespace:
Expand Down Expand Up @@ -300,17 +305,22 @@ async def main() -> None:
import riva.client.audio_io
riva.client.audio_io.list_input_devices()
except ModuleNotFoundError:
print("PyAudio not available. Please install PyAudio to list audio devices.")
print(
"PyAudio not available. Install the system PortAudio headers first "
"(e.g. `apt-get install -y portaudio19-dev`), then `pip install pyaudio`.",
file=sys.stderr,
)
return

setup_signal_handler()
await run_transcription(args)

try:
await run_transcription(args)
except Exception as e:
print(f"Fatal error: {e}")
sys.exit(1)

@cli_main
def _entry() -> int:
asyncio.run(main())
return 0


if __name__ == "__main__":
asyncio.run(main())
sys.exit(_entry())
16 changes: 13 additions & 3 deletions scripts/asr/riva_streaming_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,23 @@
import argparse
import os
import queue
import sys
import time
from pathlib import Path
from threading import Thread
from typing import Union

import riva.client
from riva.client.asr import get_wav_file_parameters
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters
from riva.client.argparse_utils import (
add_asr_config_argparse_parameters,
add_connection_argparse_parameters,
)
try:
from riva.client.argparse_utils import cli_main
except ImportError:
def cli_main(func):
return func


def parse_args() -> argparse.Namespace:
Expand Down Expand Up @@ -109,7 +118,8 @@ def streaming_transcription_worker(
raise


def main() -> None:
@cli_main
def main() -> int:
args = parse_args()
print("Number of clients:", args.num_clients)
print("Number of iteration:", args.num_iterations)
Expand Down Expand Up @@ -140,4 +150,4 @@ def main() -> None:


if __name__ == "__main__":
main()
sys.exit(main())
Loading