@@ -130,7 +130,7 @@ def splits(self) -> Sequence[str]:
130130 def get_dataset (
131131 self ,
132132 sequence_length : Optional [Mapping [str , int ]] = None ,
133- split : str = tfds .Split .TRAIN ,
133+ split : str = tfds .Split .TRAIN , # pylint:disable=attribute-error
134134 use_cached : bool = False ,
135135 shuffle : bool = True ,
136136 seed : Optional [int ] = None ,
@@ -173,7 +173,7 @@ def add_provider(cls, name: str, provider):
173173 task_registry_provenance_tracking .maybe_record_provenance (
174174 frame = inspect .currentframe (),
175175 name = name ,
176- provider_type = provider .__class__ .__name__ ,
176+ provider_type = provider .__class__ .__name__ , # pylint:disable=attribute-error
177177 )
178178
179179 @classmethod
@@ -325,7 +325,7 @@ def list_shards(self, split: str) -> Sequence[str]:
325325 @abc .abstractmethod
326326 def get_dataset (
327327 self , # pytype: disable=signature-mismatch # overriding-default-value-checks
328- split : str = tfds .Split .TRAIN ,
328+ split : str = tfds .Split .TRAIN , # pylint:disable=attribute-error
329329 shuffle : bool = True ,
330330 seed : Optional [int ] = None ,
331331 shard_info : Optional [ShardInfo ] = None ,
@@ -432,7 +432,7 @@ def __repr__(self):
432432
433433 def get_dataset (
434434 self ,
435- split : str = tfds .Split .TRAIN ,
435+ split : str = tfds .Split .TRAIN , # pylint:disable=attribute-error
436436 shuffle : bool = True ,
437437 seed : Optional [int ] = None ,
438438 shard_info : Optional [ShardInfo ] = None ,
@@ -550,7 +550,7 @@ def get_dataset(
550550 num_epochs : Optional [int ] = 1 , # Unused
551551 ) -> tf .data .Dataset :
552552 if split is None :
553- split = tfds .Split .TRAIN
553+ split = tfds .Split .TRAIN # pylint:disable=attribute-error
554554 return self .tfds_dataset .load (
555555 split , shuffle_files = shuffle , seed = seed , shard_info = shard_info
556556 )
@@ -639,7 +639,7 @@ def __repr__(self):
639639
640640 def get_dataset (
641641 self ,
642- split : str = tfds .Split .TRAIN ,
642+ split : str = tfds .Split .TRAIN , # pylint:disable=attribute-error
643643 shuffle : bool = True ,
644644 seed : Optional [int ] = None ,
645645 shard_info : Optional [ShardInfo ] = None ,
@@ -694,7 +694,7 @@ def list_shards(self, split: str) -> Sequence[str]:
694694 return _list_files (pattern = filepattern )
695695
696696 if not any (glob .has_magic (f ) for f in filepattern ):
697- return filepattern
697+ return filepattern # pytype: disable=bad-return-type
698698 else :
699699 return _list_files (pattern = filepattern )
700700
@@ -1512,7 +1512,7 @@ def assert_cached(self) -> None:
15121512 ), f"'{ self .name } ' does not exist in any of the task cache directories."
15131513
15141514 def get_cached_stats (
1515- self , split : str = tfds .Split .TRAIN
1515+ self , split : str = tfds .Split .TRAIN # pylint:disable=attribute-error
15161516 ) -> Mapping [str , Union [int , float ]]:
15171517 """Returns basic statistics for cached dataset."""
15181518 self .assert_cached ()
@@ -1526,10 +1526,10 @@ def get_cached_stats(
15261526 self ._stats [split ] = json .load (f )
15271527 return self ._stats [split ]
15281528
1529- def get_dataset (
1529+ def get_dataset ( # pylint: disable=arguments-renamed
15301530 self , # pytype: disable=signature-mismatch # overriding-default-value-checks
15311531 sequence_length : Optional [Mapping [str , int ]] = None ,
1532- split : str = tfds .Split .TRAIN ,
1532+ split : str = tfds .Split .TRAIN , # pylint:disable=attribute-error
15331533 use_cached : bool = False ,
15341534 shuffle : bool = True ,
15351535 shuffle_buffer_size : Optional [int ] = None , # Unique to Task
@@ -1614,7 +1614,7 @@ def get_dataset(
16141614 )
16151615 else :
16161616 ds = source .get_dataset (split = split , shuffle = shuffle , seed = seed )
1617- ds = ds .shard (shard_info .num_shards , shard_info .index )
1617+ ds = ds .shard (shard_info .num_shards , shard_info .index ) # pylint:disable=attribute-error
16181618
16191619 num_shards = shard_info .num_shards if shard_info else 1
16201620 if try_in_mem_cache and (
@@ -1915,7 +1915,7 @@ def get_task_dataset(
19151915 task : Task ,
19161916 output_feature_keys : Set [str ],
19171917 sequence_length : Optional [Mapping [str , int ]] = None ,
1918- split : str = tfds .Split .TRAIN ,
1918+ split : str = tfds .Split .TRAIN , # pylint:disable=attribute-error
19191919 use_cached : bool = False ,
19201920 shuffle : bool = True ,
19211921 seed : Optional [int ] = None ,
@@ -1947,7 +1947,7 @@ def _get_all_mixing_rates(self, tasks):
19471947 def get_dataset ( # pytype: disable=signature-mismatch # overriding-parameter-type-checks
19481948 self ,
19491949 sequence_length : Optional [Mapping [str , int ]] = None ,
1950- split : str = tfds .Split .TRAIN ,
1950+ split : str = tfds .Split .TRAIN , # pylint:disable=attribute-error
19511951 use_cached : bool = False ,
19521952 shuffle : bool = True ,
19531953 seed : Optional [int ] = None ,
@@ -2115,6 +2115,98 @@ def _get_submixture_rate(self, mix: "Mixture") -> float:
21152115 return float (rate )
21162116
21172117
def get_dataset_iterator_from_tasks(
    tasks: Union[
        Sequence[SubtaskOrName], Sequence[Tuple[SubtaskOrName, MixtureRate]]
    ],
    sources: Sequence[grain.TfDataSource],
    proportions: Sequence[float],
    shard_info: Optional[ShardInfo],
    seed: Optional[int],
    num_epochs: Optional[int],
    strict_transformations: bool,
    shuffle: bool,
    batch_size: Optional[int],
    sequence_length: Optional[Mapping[str, int]],
    trim_output_features: bool,
    output_features: Mapping[str, str],
    feature_converter: FeatureConverter,
) -> grain.TfGrainDatasetIterator:
  """Returns a deterministic DatasetIterator for the mixture.

  Builds one transformation chain per task (its preprocessors), mixes the
  per-task sources with a `grain.TfMixtureIndexSampler`, applies the shared
  post-mixing transformations (meta-feature handling, optional trimming, and
  the feature converter), and returns an iterator over the resulting loader.

  Args:
    tasks: tasks (or (task, rate) pairs) in the mixture, aligned with
      `sources`; each task's `preprocessors` become its per-source
      transformations.
    sources: one grain data source per task.
    proportions: mixing proportions, aligned with `sources`.
    shard_info: optional shard to read; `None` disables sharding.
    seed: optional seed for the index sampler.
    num_epochs: must be `None` or 1 — a mixture repeats indefinitely.
    strict_transformations: if True, require `feature_converter` to implement
      `get_grain_transforms()` rather than falling back to calling the
      converter directly.
    shuffle: whether the sampler shuffles.
    batch_size: optional batch size forwarded to the feature converter.
    sequence_length: optional per-feature lengths; a length for the grain
      index meta-feature is added so it is never dropped.
    trim_output_features: whether to trim features to `sequence_length`.
    output_features: task output features forwarded to transformations.
    feature_converter: converts task features to model features.

  Raises:
    ValueError: if `num_epochs` is set to a value other than 1.
    NotImplementedError: if `strict_transformations` is True and the feature
      converter does not implement `get_grain_transforms()`.
  """
  if shard_info is None:
    shard_options = grain.NoSharding()
  else:
    shard_options = grain.ShardOptions(
        shard_index=shard_info.index, shard_count=shard_info.num_shards
    )

  if num_epochs and num_epochs != 1:
    raise ValueError(
        "Epochs are not supported for mixtures. A mixture "
        "always repeats indefinitely over its tasks."
    )

  if sequence_length is not None:
    # Avoid index being dropped. In case of example packing we even need to
    # pack it (but it should never be the limiting factor).
    sequence_length = dict(sequence_length)
    sequence_length[grain.INDEX] = max(sequence_length.values())

  extra_args = {
      "sequence_length": sequence_length,
      "output_features": output_features,
  }
  add_kwargs = lambda t: utils.add_kwargs_to_transform(t, **extra_args)

  transformations_per_source = []
  for task in tasks:
    transformations_per_source.append(
        [add_kwargs(t) for t in task.preprocessors]  # pytype: disable=attribute-error
    )
  # Transformations applied after combining all data sources.
  transformations = [
      seqio_preprocessors.ReshapeFeatures({grain.INDEX: [-1]}),
      seqio_preprocessors.DropFeatures(
          set(grain.META_FEATURES) - {grain.INDEX}
      ),
  ]
  if trim_output_features:
    transformations.append(seqio_preprocessors._TrimDataset())  # pylint: disable=protected-access
  if hasattr(feature_converter, "get_grain_transforms"):
    transformations += feature_converter.get_grain_transforms(
        batch_size=batch_size, task_feature_lengths=sequence_length
    )
  elif strict_transformations:
    raise NotImplementedError(
        f"FeatureConverter {feature_converter} does "
        "not implement get_grain_transforms()."
    )
  else:
    # Fallback: apply the converter directly as a single transformation.
    transformations += [
        functools.partial(
            feature_converter, task_feature_lengths=sequence_length
        )
    ]
  transformations = [add_kwargs(t) for t in transformations]

  sampler = grain.TfMixtureIndexSampler(
      [len(s) for s in sources],
      shard_options=shard_options,
      proportions=proportions,
      shuffle=shuffle,
      seed=seed,
  )
  data_loader = grain.TfMixtureDataLoader(
      sources=sources,
      sampler=sampler,
      transformations_per_source=transformations_per_source,
      transformations=transformations,
      iterator_options=grain.IteratorOptions(drop_grain_meta_features=True),
      strict_transformations=strict_transformations,
  )
  return iter(data_loader)  # pytype: disable=bad-return-type
2208+
2209+
21182210
21192211
21202212def _log_padding_fractions (dataset , sequence_length , num_examples = 100 ):
0 commit comments