-
Notifications
You must be signed in to change notification settings - Fork 49
Expand file tree
/
Copy pathriva_streaming_asr_client.py
More file actions
123 lines (113 loc) · 5.15 KB
/
riva_streaming_asr_client.py
File metadata and controls
123 lines (113 loc) · 5.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
import argparse
import os
import queue
import time
from pathlib import Path
from threading import Thread
from typing import Union
import riva.client
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters
from riva.client.asr import get_wav_file_parameters
def parse_args() -> argparse.Namespace:
    """Build the command-line parser for this script and parse the process arguments.

    Script-specific options (thread count, iteration count, input file, realtime
    simulation, chunk duration, interim results) are declared here. Connection and
    ASR configuration options are appended by ``add_connection_argparse_parameters``
    and ``add_asr_config_argparse_parameters``.

    Returns:
        The parsed arguments.

    The process exits through ``parser.error`` when ``--max-alternatives``,
    ``--num-parallel-requests``, ``--num-iterations`` or ``--chunk-duration-ms``
    is smaller than 1.
    """
    parser = argparse.ArgumentParser(
        description="Streaming transcription via Riva AI Services. Unlike `scripts/asr/transcribe_file.py` script, "
        "this script can perform transcription several times on same audio if `--num-iterations` is "
        # The flag declared below is `--num-parallel-requests`; the text used to name a non-existent flag.
        "greater than 1. If `--num-parallel-requests` is greater than 1, then a file will be transcribed "
        "independently "
        "in several threads. Unlike other ASR scripts, this script does not print output but saves it in files "
        "which names follow a format `output_<thread_num>.txt`.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--num-parallel-requests", default=1, type=int, help="Number of client threads.")
    parser.add_argument("--num-iterations", default=1, type=int, help="Number of iterations over the file.")
    parser.add_argument(
        "--input-file", required=True, type=Path, help="Name of the WAV file with LINEAR_PCM encoding to transcribe."
    )
    parser.add_argument(
        "--simulate-realtime",
        action='store_true',
        help="Option to simulate realtime transcription. Audio fragments are sent to a server at a pace that mimics "
        "normal speech.",
    )
    parser.add_argument("--chunk-duration-ms", type=int, default=100, help="Chunk duration in milliseconds.")
    parser.add_argument(
        "--interim-results", default=False, action='store_true', help="Print intermediate transcripts",
    )
    parser = add_connection_argparse_parameters(parser)
    parser = add_asr_config_argparse_parameters(
        parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True
    )
    args = parser.parse_args()
    if args.max_alternatives < 1:
        parser.error("`--max-alternatives` must be greater than or equal to 1")
    # Zero or negative counts/durations would make the script start no threads or
    # send no audio without any diagnostic, so reject them up front.
    positive_int_options = (
        ("--num-parallel-requests", args.num_parallel_requests),
        ("--num-iterations", args.num_iterations),
        ("--chunk-duration-ms", args.chunk_duration_ms),
    )
    for flag, value in positive_int_options:
        if value < 1:
            parser.error(f"`{flag}` must be greater than or equal to 1")
    return args
def streaming_transcription_worker(
    args: argparse.Namespace, output_file: Union[str, os.PathLike], thread_i: int, exception_queue: queue.Queue
) -> None:
    """Run the streaming transcription loop of one client thread.

    A connection and ASR service are created from ``args``, a streaming recognition
    config is assembled, and the input file is streamed ``args.num_iterations`` times.
    Every iteration hands the responses to ``riva.client.print_streaming`` together
    with the same ``output_file``.

    Args:
        args: parsed command-line arguments of the script.
        output_file: path passed to ``print_streaming``; ``~`` is expanded first.
        thread_i: index of this worker, reported together with any failure.
        exception_queue: queue that receives ``(exception, thread_i)`` on failure.

    Raises:
        BaseException: anything raised inside the transcription code is first put
            on ``exception_queue`` and then re-raised unchanged.
    """
    output_file = Path(output_file).expanduser()
    try:
        connection = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
        asr_service = riva.client.ASRService(connection)

        # NOTE(review): the command-line value is negated before it is passed as
        # `verbatim_transcripts` -- confirm against the flag's definition in
        # `add_asr_config_argparse_parameters`.
        recognition_config = riva.client.RecognitionConfig(
            language_code=args.language_code,
            max_alternatives=args.max_alternatives,
            profanity_filter=args.profanity_filter,
            enable_automatic_punctuation=args.automatic_punctuation,
            verbatim_transcripts=not args.verbatim_transcripts,
            enable_word_time_offsets=args.word_time_offsets,
            model=args.model_name,
        )
        streaming_config = riva.client.StreamingRecognitionConfig(
            config=recognition_config,
            interim_results=args.interim_results,
        )
        riva.client.add_word_boosting_to_config(
            streaming_config, args.boosted_words_file, args.boosted_words_score
        )

        for _iteration in range(args.num_iterations):
            # Pace the chunks like live audio only when realtime simulation was requested.
            if args.simulate_realtime:
                pacing = riva.client.sleep_audio_length
            else:
                pacing = None
            chunk_source = riva.client.AudioChunkFileIterator(
                args.input_file, args.chunk_duration_ms, delay_callback=pacing
            )
            with chunk_source as audio_chunk_iterator:
                response_stream = asr_service.streaming_response_generator(
                    audio_chunks=audio_chunk_iterator, streaming_config=streaming_config,
                )
                riva.client.print_streaming(
                    responses=response_stream,
                    input_file=args.input_file,
                    output_file=output_file,
                    additional_info='time',
                    word_time_offsets=args.word_time_offsets,
                )
    except BaseException as failure:
        # Report the failure to the thread that polls the queue, then let it propagate.
        exception_queue.put((failure, thread_i))
        raise
def main() -> None:
    """Transcribe the input file in several worker threads and wait for all of them.

    One thread per ``--num-parallel-requests`` is started, each running
    ``streaming_transcription_worker`` with the output file ``output_<i>.txt``.
    The main thread polls the shared exception queue and the thread liveness until
    every worker has finished.

    Raises:
        RuntimeError: if a worker reported an exception. The worker's exception is
            attached as ``__cause__`` so its traceback is not lost.
    """
    args = parse_args()
    print("Number of clients:", args.num_parallel_requests)
    print("Number of iteration:", args.num_iterations)
    print("Input file:", args.input_file)

    exception_queue = queue.Queue()
    threads = []
    for i in range(args.num_parallel_requests):
        t = Thread(target=streaming_transcription_worker, args=[args, f"output_{i:d}.txt", i, exception_queue])
        t.start()
        threads.append(t)

    while True:
        _raise_first_worker_failure(exception_queue)
        if not any(t.is_alive() for t in threads):
            break
        time.sleep(0.05)
    # A worker may enqueue its exception and exit after the poll above but before
    # the liveness check; check once more so such a failure is not reported as success.
    _raise_first_worker_failure(exception_queue)

    for i in range(args.num_parallel_requests):
        print(f"Thread {i} done, output written to output_{i}.txt")


def _raise_first_worker_failure(exception_queue: queue.Queue) -> None:
    """Raise ``RuntimeError`` for the first queued ``(exception, thread_index)`` pair, if any."""
    try:
        exc, thread_i = exception_queue.get(block=False)
    except queue.Empty:
        return
    raise RuntimeError(f"A thread with index {thread_i} failed with error:\n{exc}") from exc
# Run the client only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()