
[Bug][Frontend][ONNX] Gelu operator misses support for 'approximate="tanh"' attribute #18750

@huenwei-arch

Description


Expected behavior

The TVM ONNX frontend should correctly handle the approximate attribute for the Gelu operator (introduced in Opset 20).

  • If approximate="none" (default), it should map to relax.op.nn.gelu, using the exact CDF formula:
    $$y = 0.5x\left(1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right)$$
  • If approximate="tanh", it should map to relax.op.nn.gelu_tanh, using the Tanh approximation:
    $$y = 0.5x\left(1 + \tanh\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715x^3\right)\right)\right)$$
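For checking backend outputs by hand, the two formulas above translate directly to NumPy. This is a minimal sketch: `gelu_exact` and `gelu_tanh` are hypothetical helper names (not TVM or ONNX APIs), and the erf term uses the standard-library `math.erf` through `np.vectorize` to avoid a SciPy dependency.

```python
import math

import numpy as np

# Hypothetical reference helpers (not part of TVM or ONNX): NumPy versions of
# the two Gelu formulas, evaluated in float64.
_erf = np.vectorize(math.erf, otypes=[np.float64])


def gelu_exact(x):
    """approximate="none": y = 0.5 * x * (1 + erf(x / sqrt(2)))."""
    x = np.asarray(x, dtype=np.float64)
    return 0.5 * x * (1.0 + _erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    """approximate="tanh": y = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))."""
    x = np.asarray(x, dtype=np.float64)
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + np.tanh(inner))


if __name__ == "__main__":
    xs = np.array([-1.0, 0.0, 1.0, 2.0])
    print("exact:", gelu_exact(xs))
    print("tanh: ", gelu_tanh(xs))
```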

Actual behavior

The TVM ONNX frontend currently ignores the approximate attribute and hardcodes the mapping to R.nn.gelu (the exact version).

This leads to a systematic numerical mismatch when the source model expects the Tanh-based approximation. For example, at $x = -1.0$:

  • Exact (TVM current): $\approx -0.158655$
  • Tanh Approx (Expected): $\approx -0.158808$
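  • Delta: $\approx 1.5 \times 10^{-4}$, well above the $10^{-5}$ tolerance (rtol = atol = 1e-5) used in the reproduction script, which two float32 implementations of the same formula should easily meet.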
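The figures above can be re-derived with the standard library alone. The sketch below evaluates both formulas at the inputs used in the reproduction script and checks each gap against the same rtol/atol of 1e-5 that `np.testing.assert_allclose` applies there (`|actual - desired| <= atol + rtol * |desired|`); the helper names are illustrative only.

```python
import math


# Illustrative scalar versions of the two Gelu formulas (not TVM/ONNX APIs).
def gelu_exact(x: float) -> float:
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


# Same inputs and tolerances as the reproduction script.
inputs = [-1.0, 0.0, 1.0, 2.0]
atol = rtol = 1e-5

deltas = {x: abs(gelu_exact(x) - gelu_tanh(x)) for x in inputs}
for x, d in deltas.items():
    ref = gelu_tanh(x)  # the "desired" value when approximate="tanh"
    within = d <= atol + rtol * abs(ref)
    print(f"x={x:+.1f}  exact={gelu_exact(x):+.6f}  tanh={ref:+.6f}  "
          f"|delta|={d:.2e}  within_tol={within}")
```

Only x = 0 passes, since both formulas return exactly 0 there. Because GELU(x) − GELU(−x) = x holds for both variants, the gap at x = 1 equals the gap at x = −1.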

Observed Relax IR (incorrectly mapped to gelu):

```python
@R.function
def main(X: R.Tensor((4,), dtype="float32")) -> R.Tensor((4,), dtype="float32"):
    R.func_attr({"num_input": 1})
    with R.dataflow():
        # Bug: Should be R.nn.gelu_tanh when approximate="tanh"
        gv: R.Tensor((4,), dtype="float32") = R.nn.gelu(X)
        R.output(gv)
    return gv
```
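A fix would need the Gelu converter to read `approximate` and dispatch on it. The sketch below isolates only that decision in plain Python so it runs without TVM. The name `select_gelu_op` is hypothetical, it returns the Relax op name as a string instead of building the call, and it assumes the string attribute may reach the converter either as `bytes` (the raw ONNX `STRING` attribute) or as `str`.

```python
from typing import Mapping, Union

# Hypothetical dispatch table: maps the ONNX `approximate` value to the Relax op
# named in this issue. This is an illustration, not TVM frontend code.
_GELU_OPS = {
    "none": "relax.op.nn.gelu",       # exact erf-based formula (ONNX default)
    "tanh": "relax.op.nn.gelu_tanh",  # tanh approximation
}


def select_gelu_op(attr: Mapping[str, Union[str, bytes]]) -> str:
    """Return the Relax op name for an ONNX Gelu node's attribute dict."""
    approximate = attr.get("approximate", "none")
    # Assumption: string attributes may arrive as bytes; normalize to str.
    if isinstance(approximate, bytes):
        approximate = approximate.decode("utf-8")
    try:
        return _GELU_OPS[approximate]
    except KeyError:
        raise ValueError(
            f"Unsupported value for Gelu attribute 'approximate': {approximate!r}"
        ) from None


if __name__ == "__main__":
    print(select_gelu_op({}))                        # attribute absent: exact gelu
    print(select_gelu_op({"approximate": b"tanh"}))  # bytes form: gelu_tanh
```

In the actual frontend, the same branch would wrap `inputs[0]` in `relax.op.nn.gelu_tanh(...)` or `relax.op.nn.gelu(...)` rather than returning a name; rejecting unknown values explicitly avoids silently falling back to the exact formula again.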

Environment

OS:
Linux-5.15.0-139-generic-x86_64-with-glibc2.31

TVM Version:
0.24.dev0

ONNX Version:
1.20.1

ONNX Runtime Version:
1.24.1

NumPy Version:
2.4.2

Steps to reproduce

```python
import onnx
from onnx import helper, TensorProto
import numpy as np
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
import onnxruntime as ort

def reproduce_gelu_tanh_mismatch():
    """
    Reproduce the issue where the TVM ONNX frontend ignores the 'approximate'
    attribute of the Gelu operator (Opset 20).
    """
    # 1. Create an ONNX model with Gelu (approximate='tanh')
    # According to ONNX Opset 20, this should trigger the Tanh-based approximation.
    dtype_proto = TensorProto.FLOAT
    dtype_np = 'float32'

    node_def = helper.make_node(
        'Gelu',
        inputs=['X'],
        outputs=['Y'],
        approximate='tanh'
    )

    graph_def = helper.make_graph(
        [node_def],
        'gelu-tanh-repro',
        [helper.make_tensor_value_info('X', dtype_proto, [4])],
        [helper.make_tensor_value_info('Y', dtype_proto, [4])],
    )

    opset_info = helper.make_opsetid("", 20)
    model_def = helper.make_model(
        graph_def,
        ir_version=9,
        opset_imports=[opset_info]
    )

    # 2. Prepare input data
    x_np = np.array([-1.0, 0.0, 1.0, 2.0], dtype=dtype_np)

    # 3. Get reference output from ONNX Runtime (CPU provider pinned explicitly)
    sess = ort.InferenceSession(
        model_def.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    onnx_outputs = sess.run(None, {'X': x_np})[0]

    # 4. Get TVM Relax output
    tvm_mod = from_onnx(model_def)

    print("--- Generated Relax IR ---")
    print(tvm_mod["main"].script())

    target = tvm.target.Target("llvm")
    dev = tvm.cpu(0)
    ex = relax.build(tvm_mod, target)
    vm = relax.VirtualMachine(ex, dev)

    tvm_outputs = vm["main"](x_np).numpy()

    # 5. Compare results (actual = TVM output, desired = ORT reference)
    print(f"Reference (ORT): {onnx_outputs}")
    print(f"TVM Relax:       {tvm_outputs}")

    try:
        np.testing.assert_allclose(
            tvm_outputs,
            onnx_outputs,
            rtol=1e-5, atol=1e-5,
            err_msg="Numerical mismatch! TVM likely ignored approximate='tanh'."
        )
        print("\nResult: Success (No mismatch detected)")
    except AssertionError as e:
        print("\nResult: Bug Confirmed (Numerical mismatch detected)")
        print(e)

if __name__ == "__main__":
    reproduce_gelu_tanh_mismatch()
```

Triage

  • relax:frontend:onnx
  • needs-triage

cc @KJlaccHoeUM9l

Metadata


  • Assignees: No one assigned
  • Labels: needs-triage, type: bug
  • Type: No type
  • Projects: No projects
  • Milestone: No milestone
  • Relationships: None yet
  • Development: No branches or pull requests
