Mask shape mismatch due to multiple workers #114

@sleepymalc

Description

When I tried to run the HF example (specifically, gpt_log.py), I encountered the following problem:

Traceback (most recent call last):
  File "/LoGra/gpt_log.py", line 66, in <module>
    main()
  File "/LoGra/gpt_log.py", line 62, in main
    trainer.extract_log()
  File "/LoGra/_logix/huggingface/patch.py", line 99, in extract_log
    self.train(*args, **kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
    return inner_training_loop(
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/transformers/trainer.py", line 2203, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/LoGra/_logix/huggingface/patch.py", line 173, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/transformers/trainer.py", line 3161, in compute_loss
    outputs = model(**inputs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 193, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 212, in parallel_apply
    return parallel_apply(
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 126, in parallel_apply
    output.reraise()
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/_utils.py", line 733, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
    output = module(*input, **kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1305, in forward
    transformer_outputs = self.transformer(
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1119, in forward
    outputs = block(
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 617, in forward
    attn_outputs = self.attn(
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 328, in forward
    query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/LoGra/_logix/lora/modules.py", line 50, in forward
    result += self.logix_lora_C(self.logix_lora_B(self.logix_lora_A(input)))
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
  File "/home/pbb/miniconda3/envs/LoGra/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1782, in inner
    args_result = hook(self, args)
  File "/LoGra/_logix/logging/logger.py", line 126, in _forward_hook_fn
    activations = activations * mask.unsqueeze(-1)
RuntimeError: The size of tensor a (32) must match the size of tensor b (128) at non-singleton dimension 0

It seems like the mismatch happens because all 4 GPUs on the server are in use: the batch of 128 gets split into chunks of 32 per device, while the mask somehow still covers all 128 samples, as if the masks from the different devices were aggregated. This is just a guess.
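
To make the shape arithmetic concrete, here is a minimal sketch of what I think is happening (the shapes are illustrative and assumed by me, not taken from the actual run): the forward hook multiplies the per-replica activations by a mask that still has the full batch size.

import torch

full_batch, per_replica, seq_len, hidden = 128, 32, 10, 16  # 4 GPUs -> 128 / 4 = 32 per replica

mask = torch.ones(full_batch, seq_len)                    # attention mask kept at the full batch size
activations = torch.randn(per_replica, seq_len, hidden)   # what a single DataParallel replica sees

# Same operation as logger.py line 126 in the traceback:
activations = activations * mask.unsqueeze(-1)
# RuntimeError: The size of tensor a (32) must match the size of tensor b (128)
# at non-singleton dimension 0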

My questions are:

  1. How can I specify the number of workers (devices) to use when running this? (A possible workaround I might try is sketched after these questions.)
  2. If that can't be done, or doesn't fix it, what is the actual cause of the error above?
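
For reference, one workaround I could try (an assumption on my side, not something the docs prescribe) is to make only a single GPU visible before torch is imported, so the HF Trainer does not wrap the model in nn.DataParallel; equivalently, launching with CUDA_VISIBLE_DEVICES=0 python gpt_log.py.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # must be set before importing torch

import torch
print(torch.cuda.device_count())  # should now report 1, so DataParallel is not used

But I would still like to know the intended way to control this through LoGra / the HF example itself.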
