Skip to content

Commit 9d072c5

Browse files
committed
Cleanup comments
1 parent 565e78c commit 9d072c5

1 file changed

Lines changed: 17 additions & 114 deletions

File tree

examples/models/parakeet/main.cpp

Lines changed: 17 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -47,34 +47,15 @@ using ::executorch::runtime::EValue;
4747

4848
namespace {
4949

50-
// TDT duration values
50+
// TDT duration values (hardcoded for simplicity, comes from model config in NeMo implementation)
51+
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/models/rnnt_models.py#L230-L238
5152
const std::vector<int> DURATIONS = {0, 1, 2, 3, 4};
52-
// NeMo: TDT maps a duration-class argmax -> "skip" (advance) in encoder frames.
53-
// - Viable duration choices come from loss/config into decoding_cfg.durations:
54-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/models/rnnt_models.py#L230-L238
55-
// - Greedy TDT decoding uses: skip = self.durations[d_k]
56-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2665-L2669
57-
// Divergence: we hardcode {0,1,2,3,4} here (matches current Parakeet TDT
58-
// export). If the exported model's duration set changes, this must be updated
59-
// to match or timestamps/decoding will drift.
6053

6154
struct TokenTimestamp {
6255
int64_t id;
6356
int64_t start_offset; // encoder frame index
6457
int64_t end_offset; // encoder frame index
6558
};
66-
// NeMo: TDT timing is represented on the Hypothesis as two parallel lists:
67-
// - `Hypothesis.timestamp` holds the encoder frame index where each non-blank
68-
// token was emitted.
69-
// - `Hypothesis.token_duration` holds the predicted duration/skip for that token.
70-
// These are later converted to `{start_offset, end_offset}` via
71-
// RNNTDecoding._compute_offsets_tdt().
72-
// - Where NeMo records `timestamp` + `token_duration` during greedy TDT decode:
73-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2604-L2693
74-
// - Where NeMo converts them into offsets:
75-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1128-L1156
76-
// Divergence: this ExecuTorch example stores `{start_offset,end_offset}` directly
77-
// as it decodes. We don't preserve NeMo's intermediate `timestep` list.
7859

7960
struct SubwordTimestamp {
8061
std::string text;
@@ -91,13 +72,6 @@ struct WordTimestamp {
9172
double start_sec;
9273
double end_sec;
9374
};
94-
// NeMo: word-level timestamps are built from per-token offsets with
95-
// `get_words_offsets()` and then converted from offsets->seconds by
96-
// `process_timestamp_outputs()`.
97-
// - Word grouping: https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224
98-
// - Offsets->seconds: https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L428-L479
99-
// Note: NeMo's `timestamp['char']` for Parakeet is per-subword token offsets
100-
// (not true per-character).
10175

10276
struct SegmentTimestamp {
10377
std::string text;
@@ -106,20 +80,8 @@ struct SegmentTimestamp {
10680
double start_sec;
10781
double end_sec;
10882
};
109-
// NeMo: segment-level timestamps are built from word offsets with
110-
// `get_segment_offsets()`.
111-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227-L327
112-
// Divergence: we only segment on terminal punctuation (.,!,?) and do not
113-
// implement NeMo's optional `segment_gap_threshold` behavior.
11483

11584
bool is_ascii_punctuation_only(const std::string& s) {
116-
// NeMo: TDT punctuation timestamp refinement is applied when a punctuation
117-
// token appears long after the previous token; NeMo "pins" punctuation to the
118-
// previous token's end offset.
119-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
120-
// Divergence: NeMo checks membership in a model-specific `supported_punctuation`
121-
// set (can include non-ASCII). Here we approximate by checking ASCII
122-
// `std::ispunct()` on bytes.
12385
if (s.empty()) {
12486
return false;
12587
}
@@ -132,11 +94,6 @@ bool is_ascii_punctuation_only(const std::string& s) {
13294
}
13395

13496
size_t ltrim_ascii_whitespace(const std::string& s) {
135-
// NeMo: word boundaries for BPE/WPE are detected via tokenizer-type-specific
136-
// logic in `get_words_offsets()` (word delimiter char, special markers, etc).
137-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L79-L99
138-
// Divergence: we treat leading *ASCII whitespace* in the decoded piece as the
139-
// only word boundary signal.
14097
size_t i = 0;
14198
while (i < s.size() && std::isspace(static_cast<unsigned char>(s[i]))) {
14299
i++;
@@ -148,12 +105,8 @@ std::vector<SubwordTimestamp> tokens_to_subword_timestamps(
148105
const std::vector<TokenTimestamp>& tokens,
149106
tokenizers::Tokenizer* tokenizer,
150107
double seconds_per_encoder_frame) {
151-
// NeMo reference: TDT per-token "char" timestamps are computed in
152-
// `compute_rnnt_timestamps()` via `_compute_offsets_tdt()` and
153-
// `_refine_timestamps_tdt()`:
154-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L991
155-
// NeMo: "char" timestamps for Parakeet-TDT correspond to per-subword token
156-
// offsets, with a TDT punctuation refinement step.
108+
// NeMo reference of TDT per-token "char" timestamp computation:
109+
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L991
157110
std::vector<SubwordTimestamp> subwords;
158111
if (!tokenizer) {
159112
return subwords;
@@ -201,14 +154,7 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
201154
tokenizers::Tokenizer* tokenizer,
202155
double seconds_per_encoder_frame) {
203156
// NeMo reference for word grouping (subword/char offsets -> word offsets):
204-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224
205-
//
206-
// Divergences from NeMo:
207-
// - NeMo builds words from decoded token offsets (and handles tokenizer types);
208-
// here we build words by incrementally decoding each token and using leading
209-
// ASCII whitespace as the boundary.
210-
// - NeMo returns `Hypothesis.timestamp['char']` in addition to word/segment;
211-
// this example also emits per-subword timestamps.
157+
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224
212158
std::vector<WordTimestamp> words;
213159
if (!tokenizer || tokens.empty()) {
214160
return words;
@@ -251,15 +197,14 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
251197
continue;
252198
}
253199

200+
// TDT sometimes emits punctuation long after preceding token. Thus, pin to previous token.
201+
// NeMo applies the same correction:
202+
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
203+
// Divergence: NeMo consults `supported_punctuation` from the model; here we
204+
// approximate punctuation detection (ASCII-only) via `is_ascii_punctuation_only()`.
254205
TokenTimestamp adjusted = token_ts;
255206
const bool is_punct = is_ascii_punctuation_only(trimmed_piece);
256207
if (is_punct && has_prev_end_offset) {
257-
// TDT can sometimes emit punctuation long after the preceding word. Pin
258-
// punctuation timing to the previous token end.
259-
// NeMo: RNNTDecoding._refine_timestamps_tdt() applies the same correction:
260-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
261-
// Divergence: NeMo consults `supported_punctuation` from the model; here we
262-
// approximate punctuation detection (ASCII-only) via `is_ascii_punctuation_only()`.
263208
adjusted.start_offset = prev_end_offset;
264209
adjusted.end_offset = prev_end_offset;
265210
}
@@ -269,11 +214,9 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
269214
current_start_offset = adjusted.start_offset;
270215
current_end_offset = adjusted.end_offset;
271216
} else if (had_leading_ws && !is_punct) {
272-
// NeMo: `get_words_offsets()` decides when a new word starts using
273-
// tokenizer-aware rules (delimiter markers, WPE "##" prefixes, etc):
274-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L79-L99
275-
// Divergence: our boundary rule is strictly "decoded piece had leading
276-
// ASCII whitespace and is not punctuation".
217+
// NeMo builds words from decoded token offsets w/ tokenizer-aware rules:
218+
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L79-L99
219+
// Here we simplify, building words per-token and using leading whitespace as the boundary.
277220
emit_word();
278221
current_word = trimmed_piece;
279222
current_start_offset = adjusted.start_offset;
@@ -295,9 +238,7 @@ std::vector<SegmentTimestamp> words_to_segment_timestamps(
295238
const std::vector<WordTimestamp>& words,
296239
double seconds_per_encoder_frame) {
297240
// NeMo reference for segment grouping (word offsets -> segment offsets):
298-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227-L327
299-
// Divergence: we only segment on terminal punctuation (.,!,?) and do not
300-
// implement NeMo's optional `segment_gap_threshold` splitting.
241+
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227-L327
301242
std::vector<SegmentTimestamp> segments;
302243
if (words.empty()) {
303244
return segments;
@@ -338,9 +279,9 @@ std::vector<SegmentTimestamp> words_to_segment_timestamps(
338279

339280
if (!word.text.empty()) {
340281
char last = word.text.back();
282+
// NeMo Divergence: we only segment on terminal punctuation (.,!,?) rather than configurable
283+
// segment_delimiter_tokens. Also no `segment_gap_threshold` splitting.
341284
if (last == '.' || last == '!' || last == '?') {
342-
// NeMo: segment delimiters are configurable (default includes '.', '?', '!'):
343-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L287-L296
344285
emit_segment();
345286
}
346287
}
@@ -359,19 +300,6 @@ std::vector<TokenTimestamp> greedy_decode_executorch(
359300
int64_t num_rnn_layers = 2,
360301
int64_t pred_hidden = 640,
361302
int64_t max_symbols_per_step = 10) {
362-
// NeMo reference for greedy TDT decoding (where the *token timing* originates):
363-
// - Core greedy TDT loop (token argmax + duration argmax + skip advance):
364-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2627-L2717
365-
// - NeMo records per-token `timestamp` + `token_duration` here:
366-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2684-L2693
367-
//
368-
// Divergences from NeMo:
369-
// - We implement a single-item (B=1) decoder loop directly over ExecuTorch
370-
// exported methods (`joint_project_*`, `decoder_predict`, `joint`).
371-
// - We take argmax directly on raw logits (no log_softmax); this matches
372-
// NeMo's argmax choice but we do not compute scores/confidence.
373-
// - NeMo's loop structure uses an explicit inner loop for `skip==0` label
374-
// looping; here we emulate it with `dur==0` and `symbols_on_frame`.
375303
std::vector<TokenTimestamp> hypothesis;
376304
int64_t num_token_classes = vocab_size + 1;
377305

@@ -524,20 +452,9 @@ std::vector<TokenTimestamp> greedy_decode_executorch(
524452
int64_t dur = DURATIONS[dur_idx];
525453

526454
if (k == blank_id) {
527-
// NeMo: if blank is emitted with duration=0, it forces progress to avoid
528-
// infinite loops (skip==0 -> skip=1):
529-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2700-L2704
530-
// Divergence: NeMo advances `time_idx += skip` first and patches `skip`
531-
// after the inner loop; here we apply `max(dur,1)` immediately in the
532-
// blank branch.
533455
t += std::max(dur, (int64_t)1);
534456
symbols_on_frame = 0;
535457
} else {
536-
// NeMo: emits token at `time_idx` and stores duration separately:
537-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2684-L2693
538-
// NeMo later converts (timestamp, token_duration) -> (start_offset, end_offset):
539-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1128-L1156
540-
// Divergence: we store (start_offset=t, end_offset=t+dur) directly.
541458
hypothesis.push_back(TokenTimestamp{k, t, t + dur});
542459

543460
// Update decoder state
@@ -584,12 +501,6 @@ std::vector<TokenTimestamp> greedy_decode_executorch(
584501
t += dur;
585502

586503
if (dur == 0) {
587-
// NeMo: label looping occurs when `skip == 0` (stay on same encoder frame)
588-
// until a non-zero skip is predicted, capped by `max_symbols_per_step`:
589-
// - need_loop = (skip == 0):
590-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2695-L2699
591-
// - force progress after max symbols:
592-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2715-L2716
593504
symbols_on_frame++;
594505
if (symbols_on_frame >= max_symbols_per_step) {
595506
t++;
@@ -776,15 +687,6 @@ int main(int argc, char** argv) {
776687
std::cout << "Transcription tokens: " << text << std::endl;
777688

778689
if (FLAGS_timestamps) {
779-
// NeMo: offset->seconds conversion uses
780-
// start = start_offset * window_stride * subsampling_factor
781-
// end = end_offset * window_stride * subsampling_factor
782-
// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L428-L479
783-
//
784-
// Divergence: NeMo reads `window_stride` from the preprocessor config and
785-
// `subsampling_factor` from the encoder module. In ExecuTorch we require
786-
// these values to be exported as `constant_methods` (`window_stride` and
787-
// `encoder_subsampling_factor`). If unavailable, we print raw offsets.
788690
std::vector<::executorch::runtime::EValue> empty_inputs;
789691
auto window_stride_result = model->execute("window_stride", empty_inputs);
790692
auto subsampling_factor_result =
@@ -806,6 +708,7 @@ int main(int argc, char** argv) {
806708
ET_LOG(
807709
Error,
808710
"Timestamps requested but model metadata is missing. Re-export the model with constant_methods for window_stride and encoder_subsampling_factor.");
711+
return 1;
809712
}
810713

811714
auto words = tokens_to_word_timestamps(

0 commit comments

Comments
 (0)