diff --git a/src/MaxText/train.py b/src/MaxText/train.py index 3fe2cd485..d090ffd54 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -590,14 +590,14 @@ def run(config, recorder, diagnostic_config): diagnostics_context, maybe_record_goodput(recorder, GoodputEvent.JOB), max_utils.maybe_get_transformer_engine_context(config), - maybe_monitor_goodput(config), ): train_loop(config, recorder) def main(argv: Sequence[str]) -> None: config, recorder, diagnostic_config = initialize(argv) - run(config, recorder, diagnostic_config) + with maybe_monitor_goodput(config): + run(config, recorder, diagnostic_config) if __name__ == "__main__":