Skip to content

Commit 71e47ac

Browse files
gauravmishra authored and SeqIO committed
internal
PiperOrigin-RevId: 468794403
1 parent e5c51c5 commit 71e47ac

4 files changed

Lines changed: 4 additions & 65 deletions

File tree

seqio/beam_utils.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -359,11 +359,8 @@ class GetStats(beam.PTransform):
359359
prefixed by the identifiers.
360360
"""
361361

362-
def __init__(self,
363-
output_features: Mapping[str, seqio.Feature],
364-
task_ids: Optional[Mapping[str, int]] = None):
362+
def __init__(self, output_features: Mapping[str, seqio.Feature]):
365363
self._output_features = output_features
366-
self._task_ids = task_ids or {}
367364

368365
def expand(self, pcoll):
369366
example_counts = (
@@ -399,15 +396,6 @@ def _merge_dicts(dicts):
399396
merged_dict.update(d)
400397
return merged_dict
401398

402-
stats = [example_counts, total_tokens, max_tokens, char_length]
403-
if self._task_ids:
404-
task_ids_dict = {"task_ids": self._task_ids}
405-
task_ids = (
406-
pcoll
407-
| "sample_for_task_ids" >> beam.combiners.Sample.FixedSizeGlobally(1)
408-
| "create_task_ids" >> beam.Map(lambda _: task_ids_dict))
409-
stats.append(task_ids)
410-
411-
return (stats
399+
return ([example_counts, total_tokens, max_tokens, char_length]
412400
| "flatten_counts" >> beam.Flatten()
413401
| "merge_stats" >> beam.CombineGlobally(_merge_dicts))

seqio/beam_utils_test.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -222,49 +222,6 @@ def test_get_stats_tokenized_dataset(self):
222222
"targets_chars": 12,
223223
}]))
224224

225-
def test_get_stats_task_ids(self):
226-
# These examples are assumed to be decoded by
227-
# `seqio.test_utils.sentencepiece_vocab()`.
228-
input_examples = [{
229-
# Decoded as "ea", i.e., length 2 string
230-
"inputs": np.array([4, 5]),
231-
# Decoded as "ea test", i.e., length 7 string
232-
"targets": np.array([4, 5, 10]),
233-
}, {
234-
# Decoded as "e", i.e., length 1 string
235-
"inputs": np.array([4]),
236-
# Decoded as "asoil", i.e., length 5 string. "1" is an EOS id.
237-
"targets": np.array([5, 6, 7, 8, 9, 1])
238-
}]
239-
240-
output_features = seqio.test_utils.FakeTaskTest.DEFAULT_OUTPUT_FEATURES
241-
with TestPipeline() as p:
242-
pcoll = (
243-
p
244-
| beam.Create(input_examples)
245-
| beam_utils.GetStats(
246-
output_features=output_features,
247-
task_ids={
248-
"task_name_1": 1,
249-
"task_name_2": 2
250-
}))
251-
252-
util.assert_that(
253-
pcoll,
254-
util.equal_to([{
255-
"inputs_tokens": 3, # 4 and 3 from the first and second exmaples.
256-
"targets_tokens": 8,
257-
"inputs_max_tokens": 2,
258-
"targets_max_tokens": 5,
259-
"examples": 2,
260-
"inputs_chars": 3,
261-
"targets_chars": 12,
262-
"task_ids": {
263-
"task_name_1": 1,
264-
"task_name_2": 2
265-
}
266-
}]))
267-
268225
def test_count_characters_tokenized_dataset(self):
269226
# These examples are assumed to be decoded by
270227
# `seqio.test_utils.sentencepiece_vocab()`.

seqio/test_data/cached_task_with_provenance/stats.train.json

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,5 @@
55
"inputs_tokens": 36,
66
"targets_chars": 29,
77
"targets_max_tokens": 6,
8-
"targets_tokens": 18,
9-
"task_ids": {
10-
"task_name": 1
11-
}
8+
"targets_tokens": 18
129
}

seqio/test_data/cached_task_with_provenance/stats.validation.json

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,5 @@
55
"inputs_tokens": 23,
66
"targets_chars": 37,
77
"targets_max_tokens": 21,
8-
"targets_tokens": 36,
9-
"task_ids": {
10-
"task_name": 1
11-
}
8+
"targets_tokens": 36
129
}

0 commit comments

Comments (0)