From 2501af32e01c327f65d96b20729e4c8d6bf19318 Mon Sep 17 00:00:00 2001 From: Anurag Tomer Date: Thu, 14 May 2026 11:11:44 +0530 Subject: [PATCH 1/7] Make pyaudio an optional dependency in audio_io Defer the pyaudio import to the points where it is actually needed (MicrophoneStream.__enter__, SoundCallBack.__init__, list_*_devices, get_*_info). Default WAV-output flows now work on machines without PortAudio headers installed. When pyaudio is missing, raise an ImportError that explicitly tells the user to install portaudio19-dev first, addressing the VDR finding that fresh-box users got blocked by a bare ModuleNotFoundError with no install instructions. Co-Authored-By: Claude Opus 4.7 (1M context) --- riva/client/audio_io.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/riva/client/audio_io.py b/riva/client/audio_io.py index ea432793..0618158e 100644 --- a/riva/client/audio_io.py +++ b/riva/client/audio_io.py @@ -4,7 +4,17 @@ import queue from typing import Dict, Union, Optional -import pyaudio + +def _require_pyaudio(): + try: + import pyaudio + return pyaudio + except ImportError as e: + raise ImportError( + "pyaudio is required for audio device I/O. Install the system PortAudio " + "headers first (e.g. `apt-get install -y portaudio19-dev` on Debian/Ubuntu, " + "`brew install portaudio` on macOS), then `pip install pyaudio`." + ) from e class MicrophoneStream: @@ -20,6 +30,8 @@ def __init__(self, rate: int, chunk: int, device: int = None) -> None: self.closed = True def __enter__(self): + pyaudio = _require_pyaudio() + self._pa_module = pyaudio self._audio_interface = pyaudio.PyAudio() self._audio_stream = self._audio_interface.open( format=pyaudio.paInt16, @@ -50,7 +62,7 @@ def __exit__(self, type, value, traceback): def _fill_buffer(self, in_data, frame_count, time_info, status_flags): """Continuously collect data from the audio stream into the buffer.""" self._buff.put(in_data) - return None, pyaudio.paContinue + return None, self._pa_module.paContinue def __next__(self) -> bytes: if self.closed: @@ -76,6 +88,7 @@ def __iter__(self): def get_audio_device_info(device_id: int) -> Dict[str, Union[int, float, str]]: + pyaudio = _require_pyaudio() p = pyaudio.PyAudio() info = p.get_device_info_by_index(device_id) p.terminate() @@ -83,6 +96,7 @@ def get_audio_device_info(device_id: int) -> Dict[str, Union[int, float, str]]: def get_default_input_device_info() -> Optional[Dict[str, Union[int, float, str]]]: + pyaudio = _require_pyaudio() p = pyaudio.PyAudio() try: info = p.get_default_input_device_info() @@ -93,6 +107,7 @@ def get_default_input_device_info() -> Optional[Dict[str, Union[int, float, str] def list_output_devices() -> None: + pyaudio = _require_pyaudio() p = pyaudio.PyAudio() print("Output audio devices:") for i in range(p.get_device_count()): @@ -104,6 +119,7 @@ def list_output_devices() -> None: def list_input_devices() -> None: + pyaudio = _require_pyaudio() p = pyaudio.PyAudio() print("Input audio devices:") for i in range(p.get_device_count()): @@ -118,6 +134,7 @@ class SoundCallBack: def __init__( self, output_device_index: Optional[int], sampwidth: int, nchannels: int, framerate: int, ) -> None: + pyaudio = _require_pyaudio() self.pa = pyaudio.PyAudio() self.stream = self.pa.open( output_device_index=output_device_index, From 7e82b441e3012b0ddad3b364f3ab10cac8e8ee25 Mon Sep 17 00:00:00 2001 From: Anurag Tomer Date: Thu, 14 May 2026 11:11:51 +0530 Subject: [PATCH 2/7] Add cli_main decorator with structured CLI exit codes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The riva-asr/nmt/tts client scripts historically exit 0 on most error paths — including "Unavailable model", connection refused, empty/invalid input, and missing files — which causes CI pipelines composing these scripts via && chains to silently swallow real failures. Add a cli_main decorator that translates uncaught exceptions into a small, consistent set of exit codes: 2 = bad input (missing/empty file, ValueError, IsADirectoryError) 3 = gRPC UNAVAILABLE (server down, wrong port, network) 4 = gRPC INVALID_ARGUMENT / NOT_FOUND (bad model/lang/voice) 1 = anything else 130 = SIGINT The decorator also writes the error to stderr so CI logs surface the cause rather than the script swallowing it. Follow-up commit wires this into each client script. Co-Authored-By: Claude Opus 4.7 (1M context) --- riva/client/argparse_utils.py | 53 +++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/riva/client/argparse_utils.py b/riva/client/argparse_utils.py index a8cc1a72..ae0772b1 100644 --- a/riva/client/argparse_utils.py +++ b/riva/client/argparse_utils.py @@ -2,6 +2,59 @@ # SPDX-License-Identifier: MIT import argparse +import functools +import sys + +import grpc + +# Exit codes shared by the CLI scripts. Pipelines that compose these scripts +# rely on a non-zero status to detect failure; see also `cli_main` below. +EXIT_OK = 0 +EXIT_GENERIC_ERROR = 1 +EXIT_BAD_INPUT = 2 # malformed args, missing file, empty/whitespace text, ... +EXIT_UNAVAILABLE = 3 # gRPC UNAVAILABLE (server down, wrong port, ...) +EXIT_INVALID_ARGUMENT = 4 # gRPC INVALID_ARGUMENT or NOT_FOUND (bad model/lang/voice) +EXIT_INTERRUPTED = 130 # SIGINT + + +def _grpc_exit_code(error: grpc.RpcError) -> int: + code = error.code() if callable(getattr(error, "code", None)) else None + if code == grpc.StatusCode.UNAVAILABLE: + return EXIT_UNAVAILABLE + if code in (grpc.StatusCode.INVALID_ARGUMENT, grpc.StatusCode.NOT_FOUND): + return EXIT_INVALID_ARGUMENT + return EXIT_GENERIC_ERROR + + +def cli_main(func): + """Translate exceptions raised by a CLI ``main`` into consistent exit codes. + + Wrapped function may return an int exit code or ``None`` (treated as + ``EXIT_OK``). Unhandled exceptions are caught and mapped: gRPC ``RpcError`` + via status code, ``FileNotFoundError`` / ``ValueError`` → ``EXIT_BAD_INPUT``, + anything else → ``EXIT_GENERIC_ERROR``. The error is also printed to stderr + so CI logs surface the cause. + """ + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + result = func(*args, **kwargs) + return EXIT_OK if result is None else int(result) + except KeyboardInterrupt: + return EXIT_INTERRUPTED + except grpc.RpcError as e: + details = e.details() if callable(getattr(e, "details", None)) else str(e) + print(f"Error: {details}", file=sys.stderr) + return _grpc_exit_code(e) + except (FileNotFoundError, IsADirectoryError, ValueError) as e: + print(f"Error: {e}", file=sys.stderr) + return EXIT_BAD_INPUT + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + return EXIT_GENERIC_ERROR + + return wrapper + def validate_grpc_message_size(value): """Validate that the GRPC message size is within acceptable limits.""" From dbc639e340dd83eb05dfb16bb0ce1a5330245336 Mon Sep 17 00:00:00 2001 From: Anurag Tomer Date: Thu, 14 May 2026 11:12:25 +0530 Subject: [PATCH 3/7] Wire cli_main into asr/nmt/tts client scripts and tighten input validation Address the VDR 26.02 finding that python-clients CLIs exit 0 on most error paths across all three modalities. Each script now: - Wraps main() with @cli_main so gRPC and OS errors propagate to a real exit code instead of being printed and swallowed. - Calls sys.exit(main()) so the chosen exit code reaches the shell. Script-specific fixes: scripts/nmt/nmt.py - Drop the inner request() try/except that swallowed every gRPC status; let cli_main translate it. Empty/whitespace --text and missing --text-file now return EXIT_BAD_INPUT (was: silent exit 0). Document --max-len-variation as decoder-token units with valid range [0, 256], default 20, and Arabic chunking note. scripts/tts/talk.py - Reject whitespace-only --text up front (defense-in-depth pair to the server-side fix in riva-speech that closed the hang on `--text " "`). Drop the broad `except Exception` that stringified gRPC errors and exited 0. scripts/asr/transcribe_file*.py - Replace `print(...); return` on missing input files with EXIT_BAD_INPUT. Remove the silent grpc.RpcError swallow in transcribe_file_offline.py. scripts/asr/transcribe_mic.py + realtime_asr_client.py + tts/talk.py - Pyaudio install hint now mentions `apt-get install -y portaudio19-dev` (Debian/Ubuntu) and `brew install portaudio` (macOS), pairing with the prereqs doc landed in documentation_2. scripts/tts/realtime_tts_client.py - Drop the module-level `from riva.client.audio_io import SoundCallBack` import (it was unused and pulled pyaudio in eagerly, defeating the lazy import). Drop the broad `except Exception` that mapped every failure to exit 1. scripts/nmt/nmt_speech_to_{text,speech}.py - Drop unused `import grpc`; remove the catch-all that printed "Error during translation" and exited 0. Co-Authored-By: Claude Opus 4.7 (1M context) --- scripts/asr/realtime_asr_client.py | 20 +++-- scripts/asr/riva_streaming_asr_client.py | 12 ++- scripts/asr/transcribe_file.py | 19 ++-- scripts/asr/transcribe_file_offline.py | 33 ++++--- scripts/asr/transcribe_mic.py | 24 ++++-- scripts/nmt/nmt.py | 77 +++++++++-------- scripts/nmt/nmt_speech_to_speech.py | 105 +++++++++++------------ scripts/nmt/nmt_speech_to_text.py | 89 ++++++++++--------- scripts/tts/realtime_tts_client.py | 45 +++++----- scripts/tts/talk.py | 38 ++++---- 10 files changed, 251 insertions(+), 211 deletions(-) diff --git a/scripts/asr/realtime_asr_client.py b/scripts/asr/realtime_asr_client.py index 172ec9b2..0ffd1158 100644 --- a/scripts/asr/realtime_asr_client.py +++ b/scripts/asr/realtime_asr_client.py @@ -12,6 +12,7 @@ add_asr_config_argparse_parameters, add_realtime_config_argparse_parameters, add_connection_argparse_parameters, + cli_main, ) @@ -300,17 +301,22 @@ async def main() -> None: import riva.client.audio_io riva.client.audio_io.list_input_devices() except ModuleNotFoundError: - print("PyAudio not available. Please install PyAudio to list audio devices.") + print( + "PyAudio not available. Install the system PortAudio headers first " + "(e.g. `apt-get install -y portaudio19-dev`), then `pip install pyaudio`.", + file=sys.stderr, + ) return setup_signal_handler() + await run_transcription(args) - try: - await run_transcription(args) - except Exception as e: - print(f"Fatal error: {e}") - sys.exit(1) + +@cli_main +def _entry() -> int: + asyncio.run(main()) + return 0 if __name__ == "__main__": - asyncio.run(main()) + sys.exit(_entry()) diff --git a/scripts/asr/riva_streaming_asr_client.py b/scripts/asr/riva_streaming_asr_client.py index f600af66..08a6ea22 100644 --- a/scripts/asr/riva_streaming_asr_client.py +++ b/scripts/asr/riva_streaming_asr_client.py @@ -4,6 +4,7 @@ import argparse import os import queue +import sys import time from pathlib import Path from threading import Thread @@ -11,7 +12,11 @@ import riva.client from riva.client.asr import get_wav_file_parameters -from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters +from riva.client.argparse_utils import ( + add_asr_config_argparse_parameters, + add_connection_argparse_parameters, + cli_main, +) def parse_args() -> argparse.Namespace: @@ -109,7 +114,8 @@ def streaming_transcription_worker( raise -def main() -> None: +@cli_main +def main() -> int: args = parse_args() print("Number of clients:", args.num_clients) print("Number of iteration:", args.num_iterations) @@ -140,4 +146,4 @@ def main() -> None: if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/asr/transcribe_file.py b/scripts/asr/transcribe_file.py index 1849a675..8c3f6c2e 100644 --- a/scripts/asr/transcribe_file.py +++ b/scripts/asr/transcribe_file.py @@ -2,10 +2,16 @@ # SPDX-License-Identifier: MIT import argparse - import os +import sys + import riva.client -from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters +from riva.client.argparse_utils import ( + add_asr_config_argparse_parameters, + add_connection_argparse_parameters, + cli_main, + EXIT_BAD_INPUT, +) def parse_args() -> argparse.Namespace: @@ -61,7 +67,8 @@ def parse_args() -> argparse.Namespace: return args -def main() -> None: +@cli_main +def main() -> int: args = parse_args() if args.list_devices: riva.client.audio_io.list_output_devices() @@ -95,8 +102,8 @@ def main() -> None: return if not os.path.isfile(args.input_file): - print(f"Invalid input file path: {args.input_file}") - return + print(f"Invalid input file path: {args.input_file}", file=sys.stderr) + return EXIT_BAD_INPUT config = riva.client.StreamingRecognitionConfig( config=riva.client.RecognitionConfig( @@ -159,4 +166,4 @@ def main() -> None: if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/asr/transcribe_file_offline.py b/scripts/asr/transcribe_file_offline.py index 92634caa..4c2767aa 100644 --- a/scripts/asr/transcribe_file_offline.py +++ b/scripts/asr/transcribe_file_offline.py @@ -2,12 +2,17 @@ # SPDX-License-Identifier: MIT import os +import sys import argparse from pathlib import Path -import grpc import riva.client -from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters +from riva.client.argparse_utils import ( + add_asr_config_argparse_parameters, + add_connection_argparse_parameters, + cli_main, + EXIT_BAD_INPUT, +) def parse_args() -> argparse.Namespace: @@ -30,7 +35,8 @@ def parse_args() -> argparse.Namespace: return args -def main() -> None: +@cli_main +def main() -> int: args = parse_args() options = [('grpc.max_receive_message_length', args.max_message_length), ('grpc.max_send_message_length', args.max_message_length)] @@ -62,8 +68,8 @@ def main() -> None: return if not os.path.isfile(args.input_file): - print(f"Invalid input file path: {args.input_file}") - return + print(f"Invalid input file path: {args.input_file}", file=sys.stderr) + return EXIT_BAD_INPUT config = riva.client.RecognitionConfig( language_code=args.language_code, @@ -91,14 +97,15 @@ def main() -> None: ) with args.input_file.open('rb') as fh: data = fh.read() - try: - seglst_output_file = None - if args.output_seglst: - seglst_output_file = os.path.basename(args.input_file).split(".")[0] - riva.client.print_offline(response=asr_service.offline_recognize(data, config), speaker_diarization=args.speaker_diarization, seglst_output_file=seglst_output_file) - except grpc.RpcError as e: - print(e.details()) + seglst_output_file = None + if args.output_seglst: + seglst_output_file = os.path.basename(args.input_file).split(".")[0] + riva.client.print_offline( + response=asr_service.offline_recognize(data, config), + speaker_diarization=args.speaker_diarization, + seglst_output_file=seglst_output_file, + ) if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/asr/transcribe_mic.py b/scripts/asr/transcribe_mic.py index 3fd2b5a2..a5d63265 100644 --- a/scripts/asr/transcribe_mic.py +++ b/scripts/asr/transcribe_mic.py @@ -2,16 +2,27 @@ # SPDX-License-Identifier: MIT import argparse +import sys import riva.client -from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters +from riva.client.argparse_utils import ( + add_asr_config_argparse_parameters, + add_connection_argparse_parameters, + cli_main, + EXIT_BAD_INPUT, +) try: import riva.client.audio_io except ModuleNotFoundError as e: - print(f"ModuleNotFoundError: {e}") - print("Please install pyaudio from https://pypi.org/project/PyAudio") - exit(1) + print(f"ModuleNotFoundError: {e}", file=sys.stderr) + print( + "Install the system PortAudio headers first " + "(e.g. `apt-get install -y portaudio19-dev` on Debian/Ubuntu, " + "`brew install portaudio` on macOS), then `pip install pyaudio`.", + file=sys.stderr, + ) + sys.exit(EXIT_BAD_INPUT) def parse_args() -> argparse.Namespace: default_device_info = riva.client.audio_io.get_default_input_device_info() @@ -40,7 +51,8 @@ def parse_args() -> argparse.Namespace: return args -def main() -> None: +@cli_main +def main() -> int: args = parse_args() if args.list_devices: riva.client.audio_io.list_input_devices() @@ -98,4 +110,4 @@ def main() -> None: if __name__ == '__main__': - main() + sys.exit(main()) diff --git a/scripts/nmt/nmt.py b/scripts/nmt/nmt.py index f3d13dbe..d77a49d1 100644 --- a/scripts/nmt/nmt.py +++ b/scripts/nmt/nmt.py @@ -30,12 +30,11 @@ import os import sys -import grpc import riva.client.proto.riva_nmt_pb2 as riva_nmt import riva.client.proto.riva_nmt_pb2_grpc as riva_nmt_srv import riva.client -from riva.client.argparse_utils import add_connection_argparse_parameters +from riva.client.argparse_utils import add_connection_argparse_parameters, cli_main, EXIT_BAD_INPUT def read_dnt_phrases_file(file_path): @@ -78,7 +77,15 @@ def parse_args() -> argparse.Namespace: ) inputs.add_argument("--text-file", type=str, help="Path to file for translation") parser.add_argument("--dnt-phrases-file", type=str, help="Path to file which contains dnt phrases and custom translations") - parser.add_argument("--max-len-variation", type=str, help="Parameter to control the maximum variation between the length of source and translated text in terms of tokens") + parser.add_argument( + "--max-len-variation", + type=str, + help="Maximum allowed difference (in decoder SentencePiece tokens, not characters) " + "between the source and translated text length. Valid range: [0, 256]. Server-side " + "default is 20. Increase this for long inputs that get truncated; high-aspect-ratio " + "languages like Arabic may need additional client-side chunking when the source " + "exceeds ~200 characters even at 256.", + ) parser.add_argument("--model-name", default="", type=str, help="model to use to translate") parser.add_argument( "--source-language-code", type=str, default="en-US", help="Source language code (according to BCP-47 standard)" @@ -93,34 +100,8 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def main() -> None: - def request(inputs,args): - try: - dnt_phrases_input = {} - if args.dnt_phrases_file != None: - dnt_phrases_input = read_dnt_phrases_file(args.dnt_phrases_file) - response = nmt_client.translate( - texts=inputs, - model=args.model_name, - source_language=args.source_language_code, - target_language=args.target_language_code, - future=False, - dnt_phrases_dict=dnt_phrases_input, - max_len_variation=args.max_len_variation, - ) - for translation in response.translations: - print(translation.text) - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.INVALID_ARGUMENT: - result = {'msg': 'invalid arg error'} - elif e.code() == grpc.StatusCode.ALREADY_EXISTS: - result = {'msg': 'already exists error'} - elif e.code() == grpc.StatusCode.UNAVAILABLE: - result = {'msg': 'server unavailable check network'} - else: - result = {'msg': 'error code:{}'.format(e.code())} - print(f"{result['msg']} : {e.details()}") - +@cli_main +def main() -> int: args = parse_args() auth = riva.client.Auth( @@ -134,13 +115,31 @@ def request(inputs,args): ) nmt_client = riva.client.NeuralMachineTranslationClient(auth) - if args.list_models: + def request(inputs): + dnt_phrases_input = {} + if args.dnt_phrases_file is not None: + dnt_phrases_input = read_dnt_phrases_file(args.dnt_phrases_file) + response = nmt_client.translate( + texts=inputs, + model=args.model_name, + source_language=args.source_language_code, + target_language=args.target_language_code, + future=False, + dnt_phrases_dict=dnt_phrases_input, + max_len_variation=args.max_len_variation, + ) + for translation in response.translations: + print(translation.text) + if args.list_models: response = nmt_client.get_config(args.model_name) print(response) return - if args.text_file != None and os.path.exists(args.text_file): + if args.text_file is not None: + if not os.path.exists(args.text_file): + print(f"Invalid input file path: {args.text_file}", file=sys.stderr) + return EXIT_BAD_INPUT with open(args.text_file, "r") as f: batch = [] for line in f: @@ -148,15 +147,17 @@ def request(inputs,args): if line != "": batch.append(line) if len(batch) == args.batch_size: - request(batch, args) + request(batch) batch = [] if len(batch) > 0: - request(batch, args) + request(batch) return - if args.text != "": - request([args.text], args) + if not args.text or not args.text.strip(): + print("No input text provided", file=sys.stderr) + return EXIT_BAD_INPUT + request([args.text]) if __name__ == '__main__': - main() + sys.exit(main()) diff --git a/scripts/nmt/nmt_speech_to_speech.py b/scripts/nmt/nmt_speech_to_speech.py index 7348525b..20f711d9 100644 --- a/scripts/nmt/nmt_speech_to_speech.py +++ b/scripts/nmt/nmt_speech_to_speech.py @@ -1,14 +1,13 @@ import argparse import os +import sys import wave from typing import Iterator -import grpc - import riva.client import riva.client.proto.riva_asr_pb2 as riva_asr_pb2 import riva.client.proto.riva_nmt_pb2 as riva_nmt_pb2 -from riva.client.argparse_utils import add_connection_argparse_parameters +from riva.client.argparse_utils import add_connection_argparse_parameters, cli_main def parse_arguments(): @@ -25,6 +24,7 @@ def parse_arguments(): return parser.parse_args() +@cli_main def main(): args = parse_arguments() @@ -48,59 +48,54 @@ def main(): print(response) return + print(f"Translating speech from {args.source_language} to {args.target_language}") + print(f"Using audio file: {args.audio_file}") + print(f"Server address: {args.server}") + + # Create ASR config + asr_config = riva_asr_pb2.StreamingRecognitionConfig( + config=riva_asr_pb2.RecognitionConfig( + language_code=args.source_language, max_alternatives=1, enable_automatic_punctuation=True + ), + interim_results=True, + ) + + # Create translation config + translation_config = riva_nmt_pb2.TranslationConfig( + source_language_code=args.source_language, target_language_code=args.target_language, + ) + + # Create synthesis config + tts_config = riva_nmt_pb2.SynthesizeSpeechConfig( + encoding=riva.client.AudioEncoding.LINEAR_PCM, + language_code=args.target_language, + voice_name=args.voice, + sample_rate_hz=args.sample_rate_hz, + ) + + # Create streaming config + streaming_config = riva_nmt_pb2.StreamingTranslateSpeechToSpeechConfig( + asr_config=asr_config, translation_config=translation_config, tts_config=tts_config + ) + + responses = nmt_client.streaming_s2s_response_generator( + audio_chunks=riva.client.AudioChunkFileIterator(args.audio_file, 100), streaming_config=streaming_config + ) + + output_file = None try: - print(f"Translating speech from {args.source_language} to {args.target_language}") - print(f"Using audio file: {args.audio_file}") - print(f"Server address: {args.server}") - - # Create ASR config - asr_config = riva_asr_pb2.StreamingRecognitionConfig( - config=riva_asr_pb2.RecognitionConfig( - language_code=args.source_language, max_alternatives=1, enable_automatic_punctuation=True - ), - interim_results=True, - ) - - # Create translation config - translation_config = riva_nmt_pb2.TranslationConfig( - source_language_code=args.source_language, target_language_code=args.target_language, - ) - - # Create synthesis config - tts_config = riva_nmt_pb2.SynthesizeSpeechConfig( - encoding=riva.client.AudioEncoding.LINEAR_PCM, - language_code=args.target_language, - voice_name=args.voice, - sample_rate_hz=args.sample_rate_hz, - ) - - # Create streaming config - streaming_config = riva_nmt_pb2.StreamingTranslateSpeechToSpeechConfig( - asr_config=asr_config, translation_config=translation_config, tts_config=tts_config - ) - - responses = nmt_client.streaming_s2s_response_generator( - audio_chunks=riva.client.AudioChunkFileIterator(args.audio_file, 100), streaming_config=streaming_config - ) - - try: - output_file = None - for response in responses: - if len(response.speech.audio) > 0 and output_file is None: - output_file = wave.open(str(args.output_file), 'wb') - output_file.setnchannels(1) - output_file.setsampwidth(2) - output_file.setframerate(args.sample_rate_hz) - output_file.writeframesraw(response.speech.audio) - - finally: - if output_file is not None: - print(f"Written {output_file.getnframes()} samples to {args.output_file}") - output_file.close() - - except Exception as e: - print(f"Error during translation: {e}") + for response in responses: + if len(response.speech.audio) > 0 and output_file is None: + output_file = wave.open(str(args.output_file), 'wb') + output_file.setnchannels(1) + output_file.setsampwidth(2) + output_file.setframerate(args.sample_rate_hz) + output_file.writeframesraw(response.speech.audio) + finally: + if output_file is not None: + print(f"Written {output_file.getnframes()} samples to {args.output_file}") + output_file.close() if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/nmt/nmt_speech_to_text.py b/scripts/nmt/nmt_speech_to_text.py index bc8c0f2a..66ea3da0 100644 --- a/scripts/nmt/nmt_speech_to_text.py +++ b/scripts/nmt/nmt_speech_to_text.py @@ -1,9 +1,11 @@ import argparse import os +import sys + import riva.client import riva.client.proto.riva_asr_pb2 as riva_asr_pb2 import riva.client.proto.riva_nmt_pb2 as riva_nmt_pb2 -from riva.client.argparse_utils import add_connection_argparse_parameters +from riva.client.argparse_utils import add_connection_argparse_parameters, cli_main def parse_arguments(): parser = argparse.ArgumentParser(description='Riva Speech-to-Text Translation Client') @@ -36,6 +38,7 @@ def parse_arguments(): return parser.parse_args() +@cli_main def main(): args = parse_arguments() @@ -59,50 +62,46 @@ def main(): print(response) return - try: - print(f"Translating speech from {args.source_language} to {args.target_language}") - print(f"Using audio file: {args.audio_file}") - print(f"Server address: {args.server}") - - # Create ASR config - asr_config = riva_asr_pb2.StreamingRecognitionConfig( - config=riva_asr_pb2.RecognitionConfig( - language_code=args.source_language, - max_alternatives=1, - enable_automatic_punctuation=True - ), - interim_results=True - ) - - # Create translation config - translation_config = riva_nmt_pb2.TranslationConfig( - source_language_code=args.source_language, - target_language_code=args.target_language, - model_name=args.model - ) - - # Create streaming config - streaming_config = riva_nmt_pb2.StreamingTranslateSpeechToTextConfig( - asr_config=asr_config, - translation_config=translation_config - ) - - responses = nmt_client.streaming_s2t_response_generator( - audio_chunks=riva.client.AudioChunkFileIterator(args.audio_file, 100), - streaming_config=streaming_config - ) - - final_translation = "" - for response in responses: - for result in response.results: - if result.is_final: - final_translation += result.alternatives[0].transcript - - print(f"Final translation: {final_translation}") - - except Exception as e: - print(f"Error during translation: {e}") + print(f"Translating speech from {args.source_language} to {args.target_language}") + print(f"Using audio file: {args.audio_file}") + print(f"Server address: {args.server}") + + # Create ASR config + asr_config = riva_asr_pb2.StreamingRecognitionConfig( + config=riva_asr_pb2.RecognitionConfig( + language_code=args.source_language, + max_alternatives=1, + enable_automatic_punctuation=True + ), + interim_results=True + ) + + # Create translation config + translation_config = riva_nmt_pb2.TranslationConfig( + source_language_code=args.source_language, + target_language_code=args.target_language, + model_name=args.model + ) + + # Create streaming config + streaming_config = riva_nmt_pb2.StreamingTranslateSpeechToTextConfig( + asr_config=asr_config, + translation_config=translation_config + ) + + responses = nmt_client.streaming_s2t_response_generator( + audio_chunks=riva.client.AudioChunkFileIterator(args.audio_file, 100), + streaming_config=streaming_config + ) + + final_translation = "" + for response in responses: + for result in response.results: + if result.is_final: + final_translation += result.alternatives[0].transcript + + print(f"Final translation: {final_translation}") if __name__ == "__main__": - main() \ No newline at end of file + sys.exit(main()) \ No newline at end of file diff --git a/scripts/tts/realtime_tts_client.py b/scripts/tts/realtime_tts_client.py index 096c0dc0..5277c8bd 100644 --- a/scripts/tts/realtime_tts_client.py +++ b/scripts/tts/realtime_tts_client.py @@ -20,8 +20,7 @@ import websockets from websockets.exceptions import WebSocketException -from riva.client.argparse_utils import add_connection_argparse_parameters -from riva.client.audio_io import SoundCallBack +from riva.client.argparse_utils import add_connection_argparse_parameters, cli_main from riva.client.realtime import RealtimeClientTTS @@ -478,30 +477,30 @@ async def process_single_text(text_idx, text_line): async def main() -> int: """Main entry point for the realtime TTS client.""" args = parse_args() - success = False + success = True setup_signal_handler() - try: - if args.list_voices: - voices = RealtimeClientTTS(args=args).list_voices() - print(json.dumps(voices, indent=4)) - elif args.list_devices: - import riva.client.audio_io - riva.client.audio_io.list_output_devices() + if args.list_voices: + voices = RealtimeClientTTS(args=args).list_voices() + print(json.dumps(voices, indent=4)) + elif args.list_devices: + import riva.client.audio_io + riva.client.audio_io.list_output_devices() + else: + # Use parallel processing if num_parallel_requests > 1 + if args.num_parallel_requests > 1: + logger.info(f"Using parallel processing mode with {args.num_parallel_requests} concurrent requests") + success = await run_parallel_synthesis(args) else: - # Use parallel processing if num_parallel_requests > 1 - if args.num_parallel_requests > 1: - logger.info(f"Using parallel processing mode with {args.num_parallel_requests} concurrent requests") - success = await run_parallel_synthesis(args) - else: - logger.info("Using single request mode") - success = await run_synthesis(args) - return 0 if success else 1 - except Exception as e: - logger.error("Fatal error: %s", e) - return 1 + logger.info("Using single request mode") + success = await run_synthesis(args) + return 0 if success else 1 + + +@cli_main +def _entry() -> int: + return asyncio.run(main()) if __name__ == "__main__": - exit_code = asyncio.run(main()) - sys.exit(exit_code) + sys.exit(_entry()) diff --git a/scripts/tts/talk.py b/scripts/tts/talk.py index 9ac0ea57..d12d45eb 100644 --- a/scripts/tts/talk.py +++ b/scripts/tts/talk.py @@ -2,13 +2,18 @@ # SPDX-License-Identifier: MIT import argparse +import sys import time import wave import json from pathlib import Path import riva.client -from riva.client.argparse_utils import add_connection_argparse_parameters +from riva.client.argparse_utils import ( + add_connection_argparse_parameters, + cli_main, + EXIT_BAD_INPUT, +) from riva.client.proto.riva_audio_pb2 import AudioEncoding from riva.client.tts import parse_custom_configuration @@ -94,16 +99,21 @@ def parse_args() -> argparse.Namespace: import riva.client.audio_io except ModuleNotFoundError as e: print(f"ModuleNotFoundError: {e}") - print("Please install pyaudio from https://pypi.org/project/PyAudio") + print( + "Install the system PortAudio headers first " + "(e.g. `apt-get install -y portaudio19-dev` on Debian/Ubuntu, " + "`brew install portaudio` on macOS), then `pip install pyaudio`." + ) exit(1) return args -def main() -> None: +@cli_main +def main() -> int: args = parse_args() if args.output.is_dir(): - print("Empty output file path not allowed") - return + print("Empty output file path not allowed", file=sys.stderr) + return EXIT_BAD_INPUT if args.list_devices: riva.client.audio_io.list_output_devices() return @@ -148,11 +158,14 @@ def main() -> None: return if not args.text and not args.text_file: - print("No input text provided") - return + print("No input text provided", file=sys.stderr) + return EXIT_BAD_INPUT + if args.text is not None and not args.text.strip(): + print("No input text provided", file=sys.stderr) + return EXIT_BAD_INPUT if args.text and args.text_file: - print("Cannot provide both text and text_file at the same time.") - return + print("Cannot provide both text and text_file at the same time.", file=sys.stderr) + return EXIT_BAD_INPUT try: if args.output_device is not None or args.play_audio: sound_stream = riva.client.audio_io.SoundCallBack( @@ -218,11 +231,6 @@ def main() -> None: sound_stream(resp.audio) if out_f is not None: out_f.writeframesraw(resp.audio) - except Exception as e: - if callable(getattr(e, "details", None)): - print(e.details()) - else: - print(e) finally: if out_f is not None: out_f.close() @@ -231,4 +239,4 @@ def main() -> None: if __name__ == '__main__': - main() + sys.exit(main()) From 87450d8f49102f851ce2e570ea51d7ba9cf2e141 Mon Sep 17 00:00:00 2001 From: Anurag Tomer Date: Fri, 15 May 2026 10:45:38 +0530 Subject: [PATCH 4/7] Send override-only payload for realtime TTS session update MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit VDR 26.02 found that realtime_tts_client.py silently ignored --voice and fell back to the server default (Mia). Tracing the WebSocket flow, the synthesize_session.update payload was built by deep-mutating the response from POST /v1/realtime/synthesis_sessions — an InitialSynthesisSessionConfig that carries id/object/client_secret fields not present in BaseSynthesisSessionConfig (the type the server validates the update against). Carrying those keys through to the override, plus the shallow .copy() + _safe_update_config nested-dict mutation, was the path that let the voice_name override fail to land on published 26.02 NIMs. Build the update payload explicitly from CLI args instead, so only fields the user actually overrode reach the server, in the exact shape documented in the SynthesisSessionUpdateMessage schema. Bump the override summary to INFO so users can see which fields were sent. After the synthesize_session.updated response, compare the server-applied voice_name and language_code against what was requested and log a WARNING on mismatch — defense-in-depth so any future server-side drop surfaces in the client log instead of as a wrong-sounding audio file. Co-Authored-By: Claude Opus 4.7 (1M context) --- riva/client/realtime.py | 113 ++++++++++++++++++++++++++-------------- 1 file changed, 75 insertions(+), 38 deletions(-) diff --git a/riva/client/realtime.py b/riva/client/realtime.py index 98c616ee..a72e4770 100644 --- a/riva/client/realtime.py +++ b/riva/client/realtime.py @@ -7,7 +7,7 @@ import logging import queue import uuid -from typing import Dict, Any, Generator +from typing import Any, Dict, Generator, List, Optional import requests import websockets @@ -677,73 +677,94 @@ def _safe_update_config(self, config: Dict[str, Any], key: str, value: Any, sect logger.debug("Updated %s = %s", key, value) async def _update_session(self, timeout=1): - """Update session configuration by selectively overriding server defaults.""" + """Update session configuration by sending an override-only payload. + + Builds the synthesize_session.update payload directly from CLI args + instead of round-tripping through self.session_config (the response + from POST /v1/realtime/synthesis_sessions). The HTTP response is an + InitialSynthesisSessionConfig and includes id/object/client_secret + fields that are not part of BaseSynthesisSessionConfig (the type the + server expects on the update). Carrying those into the override + payload — and the resulting deep-mutation through _safe_update_config + — was the root cause of --voice silently failing to take effect + against published 26.02 NIMs. + """ logger.info("Updating session configuration...") logger.debug("Server default config: %s", self.session_config) - # Create a copy of the session config from server defaults - session_config = self.session_config.copy() - - # Track what we're overriding - overrides = [] + session_payload: Dict[str, Any] = {} + overrides: List[str] = [] + requested_voice: Optional[str] = None + requested_language: Optional[str] = None - # Update input text synthesis - only override if args are provided - if hasattr(self.args, 'language_code') and self.args.language_code: - self._safe_update_config(session_config, "language_code", self.args.language_code, "input_text_synthesis") + # input_text_synthesis: language_code + voice_name + input_text_synthesis: Dict[str, Any] = {} + if getattr(self.args, "language_code", None): + requested_language = self.args.language_code + input_text_synthesis["language_code"] = requested_language overrides.append("language_code") - - if hasattr(self.args, 'voice') and self.args.voice: - self._safe_update_config(session_config, "voice_name", self.args.voice, "input_text_synthesis") + if getattr(self.args, "voice", None): + requested_voice = self.args.voice + input_text_synthesis["voice_name"] = requested_voice overrides.append("voice_name") + if input_text_synthesis: + session_payload["input_text_synthesis"] = input_text_synthesis - # Update output audio parameters - only override if args are provided - if hasattr(self.args, 'sample_rate_hz') and self.args.sample_rate_hz: - self._safe_update_config(session_config, "sample_rate_hz", self.args.sample_rate_hz, "output_audio_params") + # output_audio_params: sample_rate_hz + audio_format + output_audio_params: Dict[str, Any] = {} + if getattr(self.args, "sample_rate_hz", None): + output_audio_params["sample_rate_hz"] = self.args.sample_rate_hz overrides.append("sample_rate_hz") - - if hasattr(self.args, 'encoding') and self.args.encoding: - self._safe_update_config(session_config, "audio_format", self.args.encoding, "output_audio_params") + if getattr(self.args, "encoding", None): + output_audio_params["audio_format"] = self.args.encoding overrides.append("audio_format") + if output_audio_params: + session_payload["output_audio_params"] = output_audio_params - # Update custom dictionary - only override if args are provided - if hasattr(self.args, 'custom_dictionary') and self.args.custom_dictionary: - self._safe_update_config(session_config, "custom_dictionary", self.args.custom_dictionary) + if getattr(self.args, "custom_dictionary", None): + session_payload["custom_dictionary"] = self.args.custom_dictionary overrides.append("custom_dictionary") - # Update zero-shot config - only override if args are provided - if (hasattr(self.args, 'zero_shot_audio_prompt_file') and self.args.zero_shot_audio_prompt_file): + # zero_shot_config: audio bytes + transcript + quality + if getattr(self.args, "zero_shot_audio_prompt_file", None): + zero_shot_config: Dict[str, Any] = {} try: - with open(self.args.zero_shot_audio_prompt_file, 'rb') as f: + with open(self.args.zero_shot_audio_prompt_file, "rb") as f: audio_data = f.read() - base64_audio_data = base64.b64encode(audio_data).decode('utf-8') - self._safe_update_config(session_config["zero_shot_config"], "audio_prompt_bytes", base64_audio_data) - logger.info("Zero-shot audio prompt bytes: %s", len(base64_audio_data)) + base64_audio_data = base64.b64encode(audio_data).decode("utf-8") + zero_shot_config["audio_prompt_bytes"] = base64_audio_data + logger.info("Zero-shot audio prompt bytes: %s", len(base64_audio_data)) overrides.append("zero_shot_audio_prompt_file") except Exception as e: logger.warning("Failed to load zero-shot audio prompt: %s", e) - - if hasattr(self.args, 'zero_shot_audio_prompt_transcript') and self.args.zero_shot_audio_prompt_transcript: - self._safe_update_config(session_config["zero_shot_config"], "audio_prompt_transcript", self.args.zero_shot_audio_prompt_transcript) + + if getattr(self.args, "zero_shot_audio_prompt_transcript", None): + zero_shot_config["audio_prompt_transcript"] = self.args.zero_shot_audio_prompt_transcript logger.info("Zero-shot audio prompt transcript: %s", self.args.zero_shot_audio_prompt_transcript) overrides.append("zero_shot_transcript") - - if hasattr(self.args, 'zero_shot_prompt_quality') and self.args.zero_shot_prompt_quality: - self._safe_update_config(session_config["zero_shot_config"], "prompt_quality", self.args.zero_shot_prompt_quality) + + if getattr(self.args, "zero_shot_prompt_quality", None): + zero_shot_config["prompt_quality"] = self.args.zero_shot_prompt_quality logger.info("Zero-shot quality: %s", self.args.zero_shot_prompt_quality) overrides.append("zero_shot_prompt_quality") - if hasattr(self.args, 'custom_configuration') and self.args.custom_configuration: + if zero_shot_config: + session_payload["zero_shot_config"] = zero_shot_config + + if getattr(self.args, "custom_configuration", None): custom_config = self._parse_custom_configuration(self.args.custom_configuration) if custom_config: - session_config["custom_configuration"] = custom_config + session_payload["custom_configuration"] = custom_config overrides.append("custom_configuration") - logger.debug("Overriding parameters: %s", overrides) + logger.info("Overriding session parameters: %s", overrides) + if requested_voice: + logger.info("Requested voice_name=%r", requested_voice) update_request = { "event_id": f"event_{uuid.uuid4()}", "type": "synthesize_session.update", - "session": session_config + "session": session_payload, } await self._send_message(update_request) @@ -764,6 +785,22 @@ async def _update_session(self, timeout=1): logger.info("Synthesis session updated successfully") self.session_config = response_data["session"] session_updated = True + acked = self.session_config.get("input_text_synthesis", {}) if isinstance(self.session_config, dict) else {} + acked_voice = acked.get("voice_name") + acked_language = acked.get("language_code") + if requested_voice and acked_voice and acked_voice != requested_voice: + logger.warning( + "Server applied voice_name=%r, but --voice requested %r. " + "Synthesis will use the server-applied voice.", + acked_voice, requested_voice, + ) + elif requested_voice: + logger.info("Server confirmed voice_name=%r", acked_voice) + if requested_language and acked_language and acked_language != requested_language: + logger.warning( + "Server applied language_code=%r, but --language-code requested %r.", + acked_language, requested_language, + ) elif event_type == "error": error_info = response_data.get("error", {}) logger.error("Error: %s", error_info.get("message", "Unknown error")) From 3bb8c8c7f13f4ad52a344397b38385ab787f1c17 Mon Sep 17 00:00:00 2001 From: Anurag Tomer Date: Mon, 18 May 2026 18:57:43 +0530 Subject: [PATCH 5/7] Guard TTS custom_configuration usage for backwards compatibility Only import parse_custom_configuration and pass custom_configuration to synthesize/synthesize_online when --custom-configuration is supplied, so talk.py keeps working against older riva-client wheels that lack the function and the kwarg. Co-Authored-By: Claude Opus 4.7 (1M context) --- scripts/tts/talk.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/scripts/tts/talk.py b/scripts/tts/talk.py index d12d45eb..36eef086 100644 --- a/scripts/tts/talk.py +++ b/scripts/tts/talk.py @@ -15,7 +15,6 @@ EXIT_BAD_INPUT, ) from riva.client.proto.riva_audio_pb2 import AudioEncoding -from riva.client.tts import parse_custom_configuration def read_file_to_dict(file_path): result_dict = {} @@ -192,7 +191,10 @@ def main() -> int: if args.custom_dictionary is not None: custom_dictionary_input = read_file_to_dict(args.custom_dictionary) - custom_configuration_input = parse_custom_configuration(args.custom_configuration) + custom_configuration_kwargs = {} + if args.custom_configuration: + from riva.client.tts import parse_custom_configuration + custom_configuration_kwargs['custom_configuration'] = parse_custom_configuration(args.custom_configuration) print("Generating audio for request...") start = time.time() @@ -203,7 +205,7 @@ def main() -> int: zero_shot_audio_prompt_file=args.zero_shot_audio_prompt_file, zero_shot_quality=(20 if args.zero_shot_quality is None else args.zero_shot_quality), custom_dictionary=custom_dictionary_input, - custom_configuration=custom_configuration_input, + **custom_configuration_kwargs, ) first = True for resp in responses: @@ -223,7 +225,7 @@ def main() -> int: zero_shot_quality=(20 if args.zero_shot_quality is None else args.zero_shot_quality), custom_dictionary=custom_dictionary_input, zero_shot_transcript=args.zero_shot_transcript, - custom_configuration=custom_configuration_input, + **custom_configuration_kwargs, ) stop = time.time() print(f"Time spent: {(stop - start):.3f}s") From 19ff10576ebfe86a9fa27b071310a2b1cb208d9d Mon Sep 17 00:00:00 2001 From: Anurag Tomer Date: Mon, 18 May 2026 19:08:31 +0530 Subject: [PATCH 6/7] Guard cli_main/EXIT_BAD_INPUT imports for backwards compatibility cli_main and EXIT_BAD_INPUT were added recently in argparse_utils and are not present in older riva-client wheels. Wrap their imports in a try/except across all asr/nmt/tts client scripts, falling back to a no-op decorator and EXIT_BAD_INPUT=2 so the scripts keep running against older installed wheels (only the structured exit codes are lost in that case). Co-Authored-By: Claude Opus 4.7 (1M context) --- scripts/asr/realtime_asr_client.py | 6 +++++- scripts/asr/riva_streaming_asr_client.py | 6 +++++- scripts/asr/transcribe_file.py | 8 ++++++-- scripts/asr/transcribe_file_offline.py | 8 ++++++-- scripts/asr/transcribe_mic.py | 8 ++++++-- scripts/nmt/nmt.py | 8 +++++++- scripts/nmt/nmt_speech_to_speech.py | 7 ++++++- scripts/nmt/nmt_speech_to_text.py | 7 ++++++- scripts/tts/realtime_tts_client.py | 7 ++++++- scripts/tts/talk.py | 12 +++++++----- 10 files changed, 60 insertions(+), 17 deletions(-) diff --git a/scripts/asr/realtime_asr_client.py b/scripts/asr/realtime_asr_client.py index 0ffd1158..9c6a5bd1 100644 --- a/scripts/asr/realtime_asr_client.py +++ b/scripts/asr/realtime_asr_client.py @@ -12,8 +12,12 @@ add_asr_config_argparse_parameters, add_realtime_config_argparse_parameters, add_connection_argparse_parameters, - cli_main, ) +try: + from riva.client.argparse_utils import cli_main +except ImportError: + def cli_main(func): + return func def parse_args() -> argparse.Namespace: diff --git a/scripts/asr/riva_streaming_asr_client.py b/scripts/asr/riva_streaming_asr_client.py index 08a6ea22..95fb3c99 100644 --- a/scripts/asr/riva_streaming_asr_client.py +++ b/scripts/asr/riva_streaming_asr_client.py @@ -15,8 +15,12 @@ from riva.client.argparse_utils import ( add_asr_config_argparse_parameters, add_connection_argparse_parameters, - cli_main, ) +try: + from riva.client.argparse_utils import cli_main +except ImportError: + def cli_main(func): + return func def parse_args() -> argparse.Namespace: diff --git a/scripts/asr/transcribe_file.py b/scripts/asr/transcribe_file.py index 8c3f6c2e..f474469a 100644 --- a/scripts/asr/transcribe_file.py +++ b/scripts/asr/transcribe_file.py @@ -9,9 +9,13 @@ from riva.client.argparse_utils import ( add_asr_config_argparse_parameters, add_connection_argparse_parameters, - cli_main, - EXIT_BAD_INPUT, ) +try: + from riva.client.argparse_utils import cli_main, EXIT_BAD_INPUT +except ImportError: + EXIT_BAD_INPUT = 2 + def cli_main(func): + return func def parse_args() -> argparse.Namespace: diff --git a/scripts/asr/transcribe_file_offline.py b/scripts/asr/transcribe_file_offline.py index 4c2767aa..f637b0b6 100644 --- a/scripts/asr/transcribe_file_offline.py +++ b/scripts/asr/transcribe_file_offline.py @@ -10,9 +10,13 @@ from riva.client.argparse_utils import ( add_asr_config_argparse_parameters, add_connection_argparse_parameters, - cli_main, - EXIT_BAD_INPUT, ) +try: + from riva.client.argparse_utils import cli_main, EXIT_BAD_INPUT +except ImportError: + EXIT_BAD_INPUT = 2 + def cli_main(func): + return func def parse_args() -> argparse.Namespace: diff --git a/scripts/asr/transcribe_mic.py b/scripts/asr/transcribe_mic.py index a5d63265..320562dd 100644 --- a/scripts/asr/transcribe_mic.py +++ b/scripts/asr/transcribe_mic.py @@ -8,9 +8,13 @@ from riva.client.argparse_utils import ( add_asr_config_argparse_parameters, add_connection_argparse_parameters, - cli_main, - EXIT_BAD_INPUT, ) +try: + from riva.client.argparse_utils import cli_main, EXIT_BAD_INPUT +except ImportError: + EXIT_BAD_INPUT = 2 + def cli_main(func): + return func try: import riva.client.audio_io diff --git a/scripts/nmt/nmt.py b/scripts/nmt/nmt.py index d77a49d1..6bd035e3 100644 --- a/scripts/nmt/nmt.py +++ b/scripts/nmt/nmt.py @@ -34,7 +34,13 @@ import riva.client.proto.riva_nmt_pb2_grpc as riva_nmt_srv import riva.client -from riva.client.argparse_utils import add_connection_argparse_parameters, cli_main, EXIT_BAD_INPUT +from riva.client.argparse_utils import add_connection_argparse_parameters +try: + from riva.client.argparse_utils import cli_main, EXIT_BAD_INPUT +except ImportError: + EXIT_BAD_INPUT = 2 + def cli_main(func): + return func def read_dnt_phrases_file(file_path): diff --git a/scripts/nmt/nmt_speech_to_speech.py b/scripts/nmt/nmt_speech_to_speech.py index 20f711d9..d06f5819 100644 --- a/scripts/nmt/nmt_speech_to_speech.py +++ b/scripts/nmt/nmt_speech_to_speech.py @@ -7,7 +7,12 @@ import riva.client import riva.client.proto.riva_asr_pb2 as riva_asr_pb2 import riva.client.proto.riva_nmt_pb2 as riva_nmt_pb2 -from riva.client.argparse_utils import add_connection_argparse_parameters, cli_main +from riva.client.argparse_utils import add_connection_argparse_parameters +try: + from riva.client.argparse_utils import cli_main +except ImportError: + def cli_main(func): + return func def parse_arguments(): diff --git a/scripts/nmt/nmt_speech_to_text.py b/scripts/nmt/nmt_speech_to_text.py index 66ea3da0..dbbd7607 100644 --- a/scripts/nmt/nmt_speech_to_text.py +++ b/scripts/nmt/nmt_speech_to_text.py @@ -5,7 +5,12 @@ import riva.client import riva.client.proto.riva_asr_pb2 as riva_asr_pb2 import riva.client.proto.riva_nmt_pb2 as riva_nmt_pb2 -from riva.client.argparse_utils import add_connection_argparse_parameters, cli_main +from riva.client.argparse_utils import add_connection_argparse_parameters +try: + from riva.client.argparse_utils import cli_main +except ImportError: + def cli_main(func): + return func def parse_arguments(): parser = argparse.ArgumentParser(description='Riva Speech-to-Text Translation Client') diff --git a/scripts/tts/realtime_tts_client.py b/scripts/tts/realtime_tts_client.py index 5277c8bd..3144450b 100644 --- a/scripts/tts/realtime_tts_client.py +++ b/scripts/tts/realtime_tts_client.py @@ -20,7 +20,12 @@ import websockets from websockets.exceptions import WebSocketException -from riva.client.argparse_utils import add_connection_argparse_parameters, cli_main +from riva.client.argparse_utils import add_connection_argparse_parameters +try: + from riva.client.argparse_utils import cli_main +except ImportError: + def cli_main(func): + return func from riva.client.realtime import RealtimeClientTTS diff --git a/scripts/tts/talk.py b/scripts/tts/talk.py index 36eef086..d5469e2e 100644 --- a/scripts/tts/talk.py +++ b/scripts/tts/talk.py @@ -9,11 +9,13 @@ from pathlib import Path import riva.client -from riva.client.argparse_utils import ( - add_connection_argparse_parameters, - cli_main, - EXIT_BAD_INPUT, -) +from riva.client.argparse_utils import add_connection_argparse_parameters +try: + from riva.client.argparse_utils import cli_main, EXIT_BAD_INPUT +except ImportError: + EXIT_BAD_INPUT = 2 + def cli_main(func): + return func from riva.client.proto.riva_audio_pb2 import AudioEncoding def read_file_to_dict(file_path): From 929ecb19a18a7907505358bd9c03d4aee77efc0f Mon Sep 17 00:00:00 2001 From: Yuvaraj Dharavath Date: Wed, 20 May 2026 07:13:51 +0000 Subject: [PATCH 7/7] fix: nmt: surface EXIT_BAD_INPUT when --text-file has no non-empty lines --- scripts/nmt/nmt.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/scripts/nmt/nmt.py b/scripts/nmt/nmt.py index 6bd035e3..85861efc 100644 --- a/scripts/nmt/nmt.py +++ b/scripts/nmt/nmt.py @@ -146,6 +146,7 @@ def request(inputs): if not os.path.exists(args.text_file): print(f"Invalid input file path: {args.text_file}", file=sys.stderr) return EXIT_BAD_INPUT + translated_any = False with open(args.text_file, "r") as f: batch = [] for line in f: @@ -154,9 +155,14 @@ def request(inputs): batch.append(line) if len(batch) == args.batch_size: request(batch) + translated_any = True batch = [] if len(batch) > 0: request(batch) + translated_any = True + if not translated_any: + print(f"{args.text_file} contained no non-empty lines", file=sys.stderr) + return EXIT_BAD_INPUT return if not args.text or not args.text.strip():