Skip to content

Commit a63d3f9

Browse files
SeqIO TeamSeqIO
authored andcommitted
Allows ("targets", "predictions", "inputs") positional arguments for metric_fns
PiperOrigin-RevId: 803170566
1 parent b898b9f commit a63d3f9

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

seqio/dataset_providers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,12 +1227,14 @@ def _all_metric_fns(
12271227
score_fns.append(metric_fn)
12281228
elif pos_args == ("targets", "predictions", "aux_values"):
12291229
predict_with_aux_fns.append(metric_fn)
1230+
elif pos_args == ("targets", "predictions", "inputs"):
1231+
predict_fns.append(metric_fn)
12301232
else:
12311233
raise ValueError(
1232-
"Metric functions must have positional arguments matching either "
1233-
"('targets', 'scores'), ('targets', 'predictions') or "
1234-
"('targets', 'predictions', 'aux_values'). "
1235-
f"Got: {pos_args}"
1234+
"Metric functions must have positional arguments matching either"
1235+
" ('targets', 'scores'), ('targets', 'predictions'), ('targets',"
1236+
" 'predictions', 'aux_values') or ('targets', 'predictions',"
1237+
f" 'inputs').Got: {pos_args}"
12361238
)
12371239
return predict_fns, score_fns, predict_with_aux_fns
12381240

0 commit comments

Comments
 (0)