Skip to content

SSM raises the following error when running locally on tip of tree (#412)

Description

@klei22
Training... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   0% -:--:--
Traceback (most recent call last):
  File "train.py", line 1185, in <module>
    main()
  File "train.py", line 1175, in main
    trainer.train()
  File "train.py", line 990, in train
    self.log_metrics(losses, running_mfu, current_epoch, self.tokens_trained, current_dataset)
  File "train.py", line 803, in log_metrics
    self.export_model_graph()
  File "train.py", line 785, in export_model_graph
    self.writer.add_graph(self.model, (dummy_input, dummy_targets, dummy_iter_num))
  File "writer.py", line 841, in add_graph
    graph(model, input_to_model, verbose, use_strict_trace)
  File "_pytorch_graph.py", line 337, in graph
    raise e
  File "_pytorch_graph.py", line 331, in graph
    trace = torch.jit.trace(model, args, strict=use_strict_trace)
  File "_trace.py", line 1002, in trace
    traced_func = _trace_impl(
  File "_trace.py", line 698, in _trace_impl
    return trace_module(
  File "_trace.py", line 1278, in trace_module
    module._c._create_method_from_trace(
  File "module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "model.py", line 409, in forward
    x = block(x, iter_num)
  File "module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "model.py", line 174, in forward
    return custom_forward(x)
  File "model.py", line 167, in custom_forward
    x = x + self.attn(self.ln_1(x), iter_num)
  File "module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "attention_variations.py", line 506, in forward
    outputs = selective_scan_fn(
  File "selective_scan_interface.py", line 110, in selective_scan_fn
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
  File "function.py", line 575, in apply
    return super().apply(*args, **kwargs)
  File "selective_scan_interface.py", line 44, in forward
    out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
RuntimeError: Expected B.scalar_type() == (!is_variable_B ? weight_type : input_type) to be true, but got false.  

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Fields

    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions