
Commit 4b4559b

SeqIO Team authored and committed
task registration for fewshot binding
PiperOrigin-RevId: 586826864
1 parent 515d917 commit 4b4559b

1 file changed: seqio/dataset_providers.py
Lines changed: 105 additions & 13 deletions
@@ -130,7 +130,7 @@ def splits(self) -> Sequence[str]:
   def get_dataset(
       self,
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       seed: Optional[int] = None,
@@ -173,7 +173,7 @@ def add_provider(cls, name: str, provider):
     task_registry_provenance_tracking.maybe_record_provenance(
         frame=inspect.currentframe(),
         name=name,
-        provider_type=provider.__class__.__name__,
+        provider_type=provider.__class__.__name__,  # pylint:disable=attribute-error
     )

   @classmethod
@@ -325,7 +325,7 @@ def list_shards(self, split: str) -> Sequence[str]:
   @abc.abstractmethod
   def get_dataset(
       self,  # pytype: disable=signature-mismatch  # overriding-default-value-checks
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       shuffle: bool = True,
       seed: Optional[int] = None,
       shard_info: Optional[ShardInfo] = None,
@@ -432,7 +432,7 @@ def __repr__(self):

   def get_dataset(
       self,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       shuffle: bool = True,
       seed: Optional[int] = None,
       shard_info: Optional[ShardInfo] = None,
@@ -550,7 +550,7 @@ def get_dataset(
       num_epochs: Optional[int] = 1,  # Unused
   ) -> tf.data.Dataset:
     if split is None:
-      split = tfds.Split.TRAIN
+      split = tfds.Split.TRAIN  # pylint:disable=attribute-error
     return self.tfds_dataset.load(
         split, shuffle_files=shuffle, seed=seed, shard_info=shard_info
     )
@@ -639,7 +639,7 @@ def __repr__(self):

   def get_dataset(
       self,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       shuffle: bool = True,
       seed: Optional[int] = None,
       shard_info: Optional[ShardInfo] = None,
@@ -694,7 +694,7 @@ def list_shards(self, split: str) -> Sequence[str]:
         return _list_files(pattern=filepattern)

     if not any(glob.has_magic(f) for f in filepattern):
-      return filepattern
+      return filepattern  # pytype: disable=bad-return-type
     else:
       return _list_files(pattern=filepattern)

@@ -1512,7 +1512,7 @@ def assert_cached(self) -> None:
     ), f"'{self.name}' does not exist in any of the task cache directories."

   def get_cached_stats(
-      self, split: str = tfds.Split.TRAIN
+      self, split: str = tfds.Split.TRAIN  # pylint:disable=attribute-error
   ) -> Mapping[str, Union[int, float]]:
     """Returns basic statistics for cached dataset."""
     self.assert_cached()
@@ -1526,10 +1526,10 @@ def get_cached_stats(
       self._stats[split] = json.load(f)
     return self._stats[split]

-  def get_dataset(
+  def get_dataset(  # pylint: disable=arguments-renamed
       self,  # pytype: disable=signature-mismatch  # overriding-default-value-checks
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       shuffle_buffer_size: Optional[int] = None,  # Unique to Task
@@ -1614,7 +1614,7 @@ def get_dataset(
         )
       else:
         ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed)
-        ds = ds.shard(shard_info.num_shards, shard_info.index)
+        ds = ds.shard(shard_info.num_shards, shard_info.index)  # pylint:disable=attribute-error

     num_shards = shard_info.num_shards if shard_info else 1
     if try_in_mem_cache and (
@@ -1915,7 +1915,7 @@ def get_task_dataset(
       task: Task,
       output_feature_keys: Set[str],
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       seed: Optional[int] = None,
@@ -1947,7 +1947,7 @@ def _get_all_mixing_rates(self, tasks):
   def get_dataset(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
       self,
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pylint:disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       seed: Optional[int] = None,
@@ -2115,6 +2115,98 @@ def _get_submixture_rate(self, mix: "Mixture") -> float:
     return float(rate)


+def get_dataset_iterator_from_tasks(
+    tasks: Union[
+        Sequence[SubtaskOrName], Sequence[Tuple[SubtaskOrName, MixtureRate]]
+    ],
+    sources: Sequence[grain.TfDataSource],
+    proportions: Sequence[float],
+    shard_info: Optional[ShardInfo],
+    seed: Optional[int],
+    num_epochs: Optional[int],
+    strict_transformations: bool,
+    shuffle: bool,
+    batch_size: Optional[int],
+    sequence_length: Optional[Mapping[str, int]],
+    trim_output_features: bool,
+    output_features: Mapping[str, str],
+    feature_converter: FeatureConverter,
+) -> grain.TfGrainDatasetIterator:
+  """Returns a deterministic DatasetIterator for the mixture."""
+  if shard_info is None:
+    shard_options = grain.NoSharding()
+  else:
+    shard_options = grain.ShardOptions(
+        shard_index=shard_info.index, shard_count=shard_info.num_shards
+    )
+
+  if num_epochs and num_epochs != 1:
+    raise ValueError(
+        "Epochs are not supported for mixtures. A mixture "
+        "always repeats indefinitely over its tasks."
+    )
+
+  if sequence_length is not None:
+    # Avoid the index being dropped. In case of example packing we even need
+    # to pack it (but it should never be the limiting factor).
+    sequence_length = dict(sequence_length)
+    sequence_length[grain.INDEX] = max(sequence_length.values())
+
+  extra_args = {
+      "sequence_length": sequence_length,
+      "output_features": output_features,
+  }
+  add_kwargs = lambda t: utils.add_kwargs_to_transform(t, **extra_args)
+
+  transformations_per_source = []
+  for task in tasks:
+    transformations_per_source.append(
+        [add_kwargs(t) for t in task.preprocessors]  # pytype: disable=attribute-error
+    )  # pylint: disable=protected-access
+  # Transformations applied after combining all data sources.
+  transformations = [
+      seqio_preprocessors.ReshapeFeatures({grain.INDEX: [-1]}),
+      seqio_preprocessors.DropFeatures(
+          set(grain.META_FEATURES) - {grain.INDEX}
+      ),
+  ]
+  if trim_output_features:
+    transformations.append(seqio_preprocessors._TrimDataset())  # pylint: disable=protected-access
+  if hasattr(feature_converter, "get_grain_transforms"):
+    transformations += feature_converter.get_grain_transforms(
+        batch_size=batch_size, task_feature_lengths=sequence_length
+    )
+  elif strict_transformations:
+    raise NotImplementedError(
+        f"FeatureConverter {feature_converter} does "
+        "not implement get_grain_transforms()."
+    )
+  else:
+    transformations += [
+        functools.partial(
+            feature_converter, task_feature_lengths=sequence_length
+        )
+    ]
+  transformations = [add_kwargs(t) for t in transformations]
+
+  sampler = grain.TfMixtureIndexSampler(
+      [len(s) for s in sources],
+      shard_options=shard_options,
+      proportions=proportions,
+      shuffle=shuffle,
+      seed=seed,
+  )
+  data_loader = grain.TfMixtureDataLoader(
+      sources=sources,
+      sampler=sampler,
+      transformations_per_source=transformations_per_source,
+      transformations=transformations,
+      iterator_options=grain.IteratorOptions(drop_grain_meta_features=True),
+      strict_transformations=strict_transformations,
+  )
+  return iter(data_loader)  # pytype: disable=bad-return-type
+
+


 def _log_padding_fractions(dataset, sequence_length, num_examples=100):
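
For orientation, here is a minimal usage sketch (not part of the commit) of how the new get_dataset_iterator_from_tasks helper might be wired up. Only its signature is taken from the diff; the task names, the make_grain_source() helper, and the concrete argument values are hypothetical placeholders, since how callers obtain a grain.TfDataSource per task is not shown in this change.

# Hypothetical usage sketch -- not from the commit.
import seqio
from seqio import dataset_providers

task_a = seqio.TaskRegistry.get("my_task_a")  # assumed, already-registered tasks
task_b = seqio.TaskRegistry.get("my_task_b")

# make_grain_source() is a stand-in for however the caller builds a
# grain.TfDataSource for each task; this commit does not define it.
sources = [make_grain_source(task_a), make_grain_source(task_b)]

iterator = dataset_providers.get_dataset_iterator_from_tasks(
    tasks=[task_a, task_b],
    sources=sources,
    proportions=[0.5, 0.5],            # equal mixing rates
    shard_info=seqio.ShardInfo(index=0, num_shards=1),
    seed=42,
    num_epochs=None,                   # mixtures repeat indefinitely
    strict_transformations=False,
    shuffle=True,
    batch_size=32,
    sequence_length={"inputs": 512, "targets": 128},
    trim_output_features=True,
    output_features=task_a.output_features,  # forwarded to preprocessors via add_kwargs
    feature_converter=seqio.EncDecFeatureConverter(pack=True),
)

for batch in iterator:  # grain.TfGrainDatasetIterator
  ...                   # each element is a batch of converted features

As the added code shows, mixing is delegated to grain's TfMixtureIndexSampler and TfMixtureDataLoader, which is what makes the returned iterator deterministic and shardable via ShardOptions.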
