@@ -47,34 +47,15 @@ using ::executorch::runtime::EValue;
4747
4848namespace {
4949
50- // TDT duration values
50+ // TDT duration values (hardcoded for simplicity; in the NeMo implementation they come from the model config):
51+ // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/models/rnnt_models.py#L230-L238
5152const std::vector<int> DURATIONS = {0, 1, 2, 3, 4};
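+ // A minimal sketch of how this table is consumed (assuming the duration head's
+ // logits are ordered to match this vector; `dur_idx` is its argmax):
+ //   int64_t skip = DURATIONS[dur_idx]; // e.g. dur_idx == 3 -> advance 3 frames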
52- // NeMo: TDT maps a duration-class argmax -> "skip" (advance) in encoder frames.
53- // - Viable duration choices come from loss/config into decoding_cfg.durations:
54- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/models/rnnt_models.py#L230-L238
55- // - Greedy TDT decoding uses: skip = self.durations[d_k]
56- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2665-L2669
57- // Divergence: we hardcode {0,1,2,3,4} here (matches current Parakeet TDT
58- // export). If the exported model's duration set changes, this must be updated
59- // to match or timestamps/decoding will drift.
6053
6154struct TokenTimestamp {
6255 int64_t id;
6356 int64_t start_offset; // encoder frame index
6457 int64_t end_offset; // encoder frame index
6558};
66- // NeMo: TDT timing is represented on the Hypothesis as two parallel lists:
67- // - `Hypothesis.timestamp` holds the encoder frame index where each non-blank
68- // token was emitted.
69- // - `Hypothesis.token_duration` holds the predicted duration/skip for that token.
70- // These are later converted to `{start_offset, end_offset}` via
71- // RNNTDecoding._compute_offsets_tdt().
72- // - Where NeMo records `timestamp` + `token_duration` during greedy TDT decode:
73- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2604-L2693
74- // - Where NeMo converts them into offsets:
75- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1128-L1156
76- // Divergence: this ExecuTorch example stores `{start_offset,end_offset}` directly
77- // as it decodes. We don't preserve NeMo's intermediate `timestep` list.
7859
7960struct SubwordTimestamp {
8061 std::string text;
@@ -91,13 +72,6 @@ struct WordTimestamp {
9172 double start_sec;
9273 double end_sec;
9374};
94- // NeMo: word-level timestamps are built from per-token offsets with
95- // `get_words_offsets()` and then converted from offsets->seconds by
96- // `process_timestamp_outputs()`.
97- // - Word grouping: https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224
98- // - Offsets->seconds: https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L428-L479
99- // Note: NeMo's `timestamp['char']` for Parakeet is per-subword token offsets
100- // (not true per-character).
10175
10276struct SegmentTimestamp {
10377 std::string text;
@@ -106,20 +80,8 @@ struct SegmentTimestamp {
10680 double start_sec;
10781 double end_sec;
10882};
109- // NeMo: segment-level timestamps are built from word offsets with
110- // `get_segment_offsets()`.
111- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227-L327
112- // Divergence: we only segment on terminal punctuation (.,!,?) and do not
113- // implement NeMo's optional `segment_gap_threshold` behavior.
11483
11584bool is_ascii_punctuation_only(const std::string& s) {
116- // NeMo: TDT punctuation timestamp refinement is applied when a punctuation
117- // token appears long after the previous token; NeMo "pins" punctuation to the
118- // previous token's end offset.
119- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
120- // Divergence: NeMo checks membership in a model-specific `supported_punctuation`
121- // set (can include non-ASCII). Here we approximate by checking ASCII
122- // `std::ispunct()` on bytes.
12385 if (s.empty()) {
12486 return false;
12587 }
@@ -132,11 +94,6 @@ bool is_ascii_punctuation_only(const std::string& s) {
13294}
13395
13496size_t ltrim_ascii_whitespace(const std::string& s) {
135- // NeMo: word boundaries for BPE/WPE are detected via tokenizer-type-specific
136- // logic in `get_words_offsets()` (word delimiter char, special markers, etc).
137- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L79-L99
138- // Divergence: we treat leading *ASCII whitespace* in the decoded piece as the
139- // only word boundary signal.
14097 size_t i = 0;
14198 while (i < s.size() && std::isspace(static_cast<unsigned char>(s[i]))) {
14299 i++;
@@ -148,12 +105,8 @@ std::vector<SubwordTimestamp> tokens_to_subword_timestamps(
148105 const std::vector<TokenTimestamp>& tokens,
149106 tokenizers::Tokenizer* tokenizer,
150107 double seconds_per_encoder_frame) {
151- // NeMo reference: TDT per-token "char" timestamps are computed in
152- // `compute_rnnt_timestamps()` via `_compute_offsets_tdt()` and
153- // `_refine_timestamps_tdt()`:
154- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L991
155- // NeMo: "char" timestamps for Parakeet-TDT correspond to per-subword token
156- // offsets, with a TDT punctuation refinement step.
108+ // NeMo reference for the TDT per-token "char" timestamp computation:
109+ // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L991
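+ // Offsets convert to seconds as offset * seconds_per_encoder_frame; e.g.
+ // (illustrative values) start_offset 125 at 0.08 s/frame -> start_sec 10.0.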
157110 std::vector<SubwordTimestamp> subwords;
158111 if (!tokenizer) {
159112 return subwords;
@@ -201,14 +154,7 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
201154 tokenizers::Tokenizer* tokenizer,
202155 double seconds_per_encoder_frame) {
203156 // NeMo reference for word grouping (subword/char offsets -> word offsets):
204- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224
205- //
206- // Divergences from NeMo:
207- // - NeMo builds words from decoded token offsets (and handles tokenizer types);
208- // here we build words by incrementally decoding each token and using leading
209- // ASCII whitespace as the boundary.
210- // - NeMo returns `Hypothesis.timestamp['char']` in addition to word/segment;
211- // this example also emits per-subword timestamps.
157+ // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224
212158 std::vector<WordTimestamp> words;
213159 if (!tokenizer || tokens.empty()) {
214160 return words;
@@ -251,15 +197,14 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
251197 continue;
252198 }
253199
200+ // TDT sometimes emits punctuation long after the preceding token, so pin it to the previous token's end.
201+ // NeMo applies the same correction:
202+ // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
203+ // Divergence: NeMo consults `supported_punctuation` from the model; here we
204+ // approximate punctuation detection (ASCII-only) via `is_ascii_punctuation_only()`.
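+ // e.g. (illustrative): a "." token at frames [120, 121] with the previous
+ // token ending at frame 85 is pinned to [85, 85].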
254205 TokenTimestamp adjusted = token_ts;
255206 const bool is_punct = is_ascii_punctuation_only(trimmed_piece);
256207 if (is_punct && has_prev_end_offset) {
257- // TDT can sometimes emit punctuation long after the preceding word. Pin
258- // punctuation timing to the previous token end.
259- // NeMo: RNNTDecoding._refine_timestamps_tdt() applies the same correction:
260- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
261- // Divergence: NeMo consults `supported_punctuation` from the model; here we
262- // approximate punctuation detection (ASCII-only) via `is_ascii_punctuation_only()`.
263208 adjusted.start_offset = prev_end_offset;
264209 adjusted.end_offset = prev_end_offset;
265210 }
@@ -269,11 +214,9 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
269214 current_start_offset = adjusted.start_offset;
270215 current_end_offset = adjusted.end_offset;
271216 } else if (had_leading_ws && !is_punct) {
272- // NeMo: `get_words_offsets()` decides when a new word starts using
273- // tokenizer-aware rules (delimiter markers, WPE "##" prefixes, etc):
274- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L79-L99
275- // Divergence: our boundary rule is strictly "decoded piece had leading
276- // ASCII whitespace and is not punctuation".
217+ // NeMo builds words from decoded token offsets with tokenizer-aware rules:
218+ // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L79-L99
219+ // Here we simplify, building words per-token and using leading whitespace as the boundary.
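+ // e.g. (illustrative): decoded pieces {" hel", "lo", " world"} group into
+ // words {"hello", "world"}.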
277220 emit_word();
278221 current_word = trimmed_piece;
279222 current_start_offset = adjusted.start_offset;
@@ -295,9 +238,7 @@ std::vector<SegmentTimestamp> words_to_segment_timestamps(
295238 const std::vector<WordTimestamp>& words,
296239 double seconds_per_encoder_frame) {
297240 // NeMo reference for segment grouping (word offsets -> segment offsets):
298- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227-L327
299- // Divergence: we only segment on terminal punctuation (.,!,?) and do not
300- // implement NeMo's optional `segment_gap_threshold` splitting.
241+ // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227-L327
301242 std::vector<SegmentTimestamp> segments;
302243 if (words.empty()) {
303244 return segments;
@@ -338,9 +279,9 @@ std::vector<SegmentTimestamp> words_to_segment_timestamps(
338279
339280 if (!word.text.empty()) {
340281 char last = word.text.back();
282+ // Divergence from NeMo: we segment only on terminal punctuation ('.', '!', '?') rather than
283+ // the configurable `segment_delimiter_tokens`, and we skip `segment_gap_threshold` splitting.
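+ // e.g. (illustrative): words {"Hi", "there.", "Bye."} -> segments
+ // {"Hi there.", "Bye."}.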
341284 if (last == '.' || last == '!' || last == '?') {
342- // NeMo: segment delimiters are configurable (default includes '.', '?', '!'):
343- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L287-L296
344285 emit_segment();
345286 }
346287 }
@@ -359,19 +300,6 @@ std::vector<TokenTimestamp> greedy_decode_executorch(
359300 int64_t num_rnn_layers = 2,
360301 int64_t pred_hidden = 640,
361302 int64_t max_symbols_per_step = 10) {
362- // NeMo reference for greedy TDT decoding (where the *token timing* originates):
363- // - Core greedy TDT loop (token argmax + duration argmax + skip advance):
364- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2627-L2717
365- // - NeMo records per-token `timestamp` + `token_duration` here:
366- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2684-L2693
367- //
368- // Divergences from NeMo:
369- // - We implement a single-item (B=1) decoder loop directly over ExecuTorch
370- // exported methods (`joint_project_*`, `decoder_predict`, `joint`).
371- // - We take argmax directly on raw logits (no log_softmax); this matches
372- // NeMo's argmax choice but we do not compute scores/confidence.
373- // - NeMo's loop structure uses an explicit inner loop for `skip==0` label
374- // looping; here we emulate it with `dur==0` and `symbols_on_frame`.
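+ // Single-item (B=1) greedy TDT loop over the exported methods
+ // (`decoder_predict`, `joint`, `joint_project_*`): each step takes the argmax
+ // of the token logits as the label and the argmax of the duration logits as
+ // the encoder-frame advance.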
375303 std::vector<TokenTimestamp> hypothesis;
376304 int64_t num_token_classes = vocab_size + 1;
377305
@@ -524,20 +452,9 @@ std::vector<TokenTimestamp> greedy_decode_executorch(
524452 int64_t dur = DURATIONS[dur_idx];
525453
526454 if (k == blank_id) {
527- // NeMo: if blank is emitted with duration=0, it forces progress to avoid
528- // infinite loops (skip==0 -> skip=1):
529- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2700-L2704
530- // Divergence: NeMo advances `time_idx += skip` first and patches `skip`
531- // after the inner loop; here we apply `max(dur,1)` immediately in the
532- // blank branch.
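+ // A blank with duration 0 must still advance one frame to avoid an
+ // infinite loop (NeMo forces skip = 1 in the same case).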
533455 t += std::max(dur, (int64_t)1);
534456 symbols_on_frame = 0;
535457 } else {
536- // NeMo: emits token at `time_idx` and stores duration separately:
537- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2684-L2693
538- // NeMo later converts (timestamp, token_duration) -> (start_offset, end_offset):
539- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1128-L1156
540- // Divergence: we store (start_offset=t, end_offset=t+dur) directly.
541458 hypothesis.push_back(TokenTimestamp{k, t, t + dur});
542459
543460 // Update decoder state
@@ -584,12 +501,6 @@ std::vector<TokenTimestamp> greedy_decode_executorch(
584501 t += dur;
585502
586503 if (dur == 0) {
587- // NeMo: label looping occurs when `skip == 0` (stay on same encoder frame)
588- // until a non-zero skip is predicted, capped by `max_symbols_per_step`:
589- // - need_loop = (skip == 0):
590- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2695-L2699
591- // - force progress after max symbols:
592- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2715-L2716
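+ // dur == 0 keeps decoding on the same encoder frame ("label looping");
+ // max_symbols_per_step caps it and then forces progress, as in NeMo.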
593504 symbols_on_frame++;
594505 if (symbols_on_frame >= max_symbols_per_step) {
595506 t++;
@@ -776,15 +687,6 @@ int main(int argc, char** argv) {
776687 std::cout << "Transcription tokens: " << text << std::endl;
777688
778689 if (FLAGS_timestamps) {
779- // NeMo: offset->seconds conversion uses
780- // start = start_offset * window_stride * subsampling_factor
781- // end = end_offset * window_stride * subsampling_factor
782- // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L428-L479
783- //
784- // Divergence: NeMo reads `window_stride` from the preprocessor config and
785- // `subsampling_factor` from the encoder module. In ExecuTorch we require
786- // these values to be exported as `constant_methods` (`window_stride` and
787- // `encoder_subsampling_factor`). If unavailable, we print raw offsets.
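+ // Offsets -> seconds uses seconds_per_frame = window_stride * subsampling_factor,
+ // read here from exported constant_methods; e.g. (illustrative) window_stride
+ // 0.01 s * subsampling factor 8 = 0.08 s per encoder frame.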
788690 std::vector<::executorch::runtime::EValue> empty_inputs;
789691 auto window_stride_result = model->execute("window_stride", empty_inputs);
790692 auto subsampling_factor_result =
@@ -806,6 +708,7 @@ int main(int argc, char** argv) {
806708 ET_LOG(
807709 Error,
808710 "Timestamps requested but model metadata is missing. Re-export the model with constant_methods for window_stride and encoder_subsampling_factor.");
711+ return 1;
809712 }
810713
811714 auto words = tokens_to_word_timestamps(