Skip to content

Commit bd02658

Browse files
texasmichelleSeqIO
authored andcommitted
Internal change.
PiperOrigin-RevId: 590631662
1 parent 8012bf6 commit bd02658

3 files changed

Lines changed: 9 additions & 4 deletions

File tree

seqio/dataset_providers_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,7 +1178,7 @@ def test_tasks(self):
11781178
self.add_task("task2", self.function_source)
11791179
MixtureRegistry.add("test_mix1", [("task1", 1), ("task2", 1)])
11801180
mix = MixtureRegistry.get("test_mix1")
1181-
self.assertEqual(len(mix.tasks), 2)
1181+
self.assertLen(mix.tasks, 2)
11821182

11831183
for task in mix.tasks:
11841184
self.verify_task_matches_fake_datasets(task.name, use_cached=False)
@@ -1200,7 +1200,7 @@ def test_task_objs(self):
12001200

12011201
MixtureRegistry.add("test_mix1", [(task1, 1), (task2, 1)])
12021202
mix = MixtureRegistry.get("test_mix1")
1203-
self.assertEqual(len(mix.tasks), 2)
1203+
self.assertLen(mix.tasks, 2)
12041204

12051205
for task in mix.tasks:
12061206
self.verify_task_matches_fake_datasets(task=task, use_cached=False)
@@ -1221,7 +1221,7 @@ def test_task_objs_default_rate(self):
12211221
)
12221222
MixtureRegistry.add("test_mix1", [task1, task2], default_rate=1.0)
12231223
mix = MixtureRegistry.get("test_mix1")
1224-
self.assertEqual(len(mix.tasks), 2)
1224+
self.assertLen(mix.tasks, 2)
12251225

12261226
for task in mix.tasks:
12271227
self.verify_task_matches_fake_datasets(task=task, use_cached=False)
@@ -1250,7 +1250,7 @@ def test_tasks_with_tunable_rates(self):
12501250
)
12511251

12521252
mix = MixtureRegistry.get("test_mix2")
1253-
self.assertEqual(len(mix.tasks), 3)
1253+
self.assertLen(mix.tasks, 3)
12541254

12551255
automl_context = pg.hyper.DynamicEvaluationContext(require_hyper_name=True)
12561256
with automl_context.collect():

seqio/experimental.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Experimental utilities for SeqIO."""
16+
1617
import functools
1718
import inspect
1819
from typing import Callable, Iterable, Mapping, Optional, Sequence
@@ -86,6 +87,7 @@ def _no_op_mixture_registry_get(*args, **kwargs):
8687

8788

8889
def disable_registry():
90+
"""Disables the seqio TaskRegistry and MixtureRegistry."""
8991
_enfore_empty_registries()
9092
dataset_providers.TaskRegistry.add = _no_op_task_registry_add
9193
dataset_providers.TaskRegistry.add_provider = _no_op_task_registry_add

seqio/experimental_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"""Tests for seqio.preprocessors."""
1616

1717
import contextlib
18+
from unittest import mock
19+
1820
from absl.testing import absltest
1921
from seqio import dataset_providers
2022
from seqio import experimental
@@ -984,5 +986,6 @@ def test_mixture_registry_get_error(self):
984986
MixtureRegistry.get('dummy_mixture')
985987

986988

989+
987990
if __name__ == '__main__':
988991
absltest.main()

0 commit comments

Comments
 (0)