Fix torch.compile #1

@lapp0

Description

Reproducer

import distily

distily.run.benchmark(
    teacher_model_name_or_path="gpt2",
    output_dir="distily_verify_compile",
    hub_model_id="distily/distily_verify_compile",
    push_to_hub=True,
    report_to="tensorboard",
    dataset_sample_size=4000,
    gradient_accumulation_steps=1,
    harness_benchmarks=[],
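    # sweep torch_compile on/off to compare compiled vs. eager runs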
    params=[
        {"torch_compile": True},
        {"torch_compile": False},
    ]
)

Error
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/distily/run.py", line 66, in benchmark
    res = train(*parsed_args_tuple)
  File "/opt/conda/lib/python3.10/site-packages/distily/run.py", line 86, in train
    trainer.train()
  File "/opt/conda/lib/python3.10/site-packages/distily/distillation_trainer.py", line 92, in train
    train_output = super().train(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1929, in train
    return inner_training_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2205, in _inner_training_loop
    self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 2761, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/opt/conda/lib/python3.10/site-packages/distily/distillation_trainer.py", line 135, in evaluate
    super().evaluate(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3666, in evaluate
    output = eval_loop(
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 3857, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 4075, in prediction_step
    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
  File "/opt/conda/lib/python3.10/site-packages/distily/distillation_trainer.py", line 103, in compute_loss
    loss_dict = self.distillation_objective(self.teacher_model, model, inputs)
  File "/opt/conda/lib/python3.10/site-packages/distily/objectives/objectives.py", line 106, in __call__
    logits_loss = self._calc_loss(out_s.logits, out_t.logits, self.logits_loss_component, device)
  File "/opt/conda/lib/python3.10/site-packages/distily/objectives/objectives.py", line 135, in _calc_loss
    loss = loss_component.get_loss(feat_s, feat_t)
  File "/opt/conda/lib/python3.10/site-packages/distily/objectives/loss.py", line 47, in kl_divergence_loss
    teacher_prob = F.softmax(feat_t, dim=-1)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 1885, in softmax
    ret = input.softmax(dim)
RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/opt/conda/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1337, in torch_dynamo_resume_in_forward_at_1315
    lm_logits = self.lm_head(hidden_states). To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.
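Possible fix

The error message itself suggests two remediations: clone the tensor outside of torch.compile(), or call torch.compiler.cudagraph_mark_step_begin() before each model invocation. Below is a minimal sketch of both, shaped after the distillation_objective(self.teacher_model, model, inputs) call in the traceback; the function and variable names are illustrative assumptions, not Distily's actual code:

import torch

def forward_both(teacher_model, student_model, inputs):
    # Option 1: mark the start of a new CUDAGraphs step before each
    # compiled-model call, so outputs of the previous invocation are no
    # longer treated as live graph memory.
    torch.compiler.cudagraph_mark_step_begin()
    out_t = teacher_model(**inputs)

    torch.compiler.cudagraph_mark_step_begin()
    out_s = student_model(**inputs)

    # Option 2: clone the logits immediately, outside the compiled region,
    # so the loss reads stable copies rather than CUDAGraphs-managed
    # buffers that a subsequent run may overwrite in place.
    return out_s.logits.clone(), out_t.logits.clone()

Either change on its own should be enough to clear the RuntimeError; cloning is the lighter-touch fix when only the logits are consumed downstream.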

Implications

Completing this issue will allow us to benchmark torch.compile and integrate it into Distily.
