Description
Expected behavior
The TVM ONNX frontend should correctly handle the `approximate` attribute of the `Gelu` operator (introduced in Opset 20):

- If `approximate="none"` (the default), it should map to `relax.op.nn.gelu`, using the exact CDF formula:

  $$y = 0.5x\left(1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right)$$

- If `approximate="tanh"`, it should map to `relax.op.nn.gelu_tanh`, using the Tanh approximation:

  $$y = 0.5x\left(1 + \tanh\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715x^3\right)\right)\right)$$
Actual behavior
The TVM ONNX frontend currently ignores the `approximate` attribute and hardcodes the mapping to `R.nn.gelu` (the exact version).
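A fix would presumably branch on the attribute before choosing the Relax op. Below is a standalone sketch of that dispatch, using plain-math stand-ins for the two ops; the function names and the `attrs` dict here are illustrative only, not TVM's actual converter API:

```python
import math

def gelu_exact(x):
    # Stand-in for relax.op.nn.gelu: exact CDF form
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # Stand-in for relax.op.nn.gelu_tanh: tanh approximation
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def convert_gelu(x, attrs):
    # Opset 20 defaults 'approximate' to "none", i.e. the exact form
    if attrs.get("approximate", "none") == "tanh":
        return gelu_tanh(x)
    return gelu_exact(x)

print(convert_gelu(-1.0, {"approximate": "tanh"}))  # tanh path, ~ -0.1588
print(convert_gelu(-1.0, {}))                       # exact path, ~ -0.158655
```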
Hardcoding the exact form leads to a systematic numerical mismatch whenever the source model expects the Tanh-based approximation. For example, at $x = -1.0$:

- Exact (TVM current): $\approx -0.158655$
- Tanh Approx (Expected): $\approx -0.158808$
- Delta: $\approx 1.5 \times 10^{-4}$ (exceeds `float32` tolerance for identical algorithms)
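The gap above can be checked with the two formulas alone, independent of TVM and ONNX (a pure-stdlib sketch):

```python
import math

x = -1.0

# Exact CDF form: y = 0.5 * x * (1 + erf(x / sqrt(2)))
exact = 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

# Tanh approximation: y = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
approx = 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

delta = abs(exact - approx)
print(f"exact={exact:.6f} approx={approx:.6f} delta={delta:.2e}")
# delta is on the order of 1e-4, well above float32 rounding error for identical algorithms
```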
Observed Relax IR (incorrectly mapped to `gelu`):

```python
@R.function
def main(X: R.Tensor((4,), dtype="float32")) -> R.Tensor((4,), dtype="float32"):
    R.func_attr({"num_input": 1})
    with R.dataflow():
        # Bug: Should be R.nn.gelu_tanh when approximate="tanh"
        gv: R.Tensor((4,), dtype="float32") = R.nn.gelu(X)
        R.output(gv)
    return gv
```

Environment
- OS: Linux-5.15.0-139-generic-x86_64-with-glibc2.31
- TVM Version: 0.24.dev0
- ONNX Version: 1.20.1
- ONNX Runtime Version: 1.24.1
- NumPy Version: 2.4.2
Steps to reproduce
```python
import onnx
from onnx import helper, TensorProto
import numpy as np
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
import onnxruntime as ort


def reproduce_gelu_tanh_mismatch():
    """
    Reproduce the issue where the TVM ONNX frontend ignores the 'approximate'
    attribute of the Gelu operator (Opset 20).
    """
    # 1. Create an ONNX model with Gelu (approximate='tanh').
    # According to ONNX Opset 20, this should trigger the Tanh-based approximation.
    dtype_proto = TensorProto.FLOAT
    dtype_np = 'float32'
    node_def = helper.make_node(
        'Gelu',
        inputs=['X'],
        outputs=['Y'],
        approximate='tanh'
    )
    graph_def = helper.make_graph(
        [node_def],
        'gelu-tanh-repro',
        [helper.make_tensor_value_info('X', dtype_proto, [4])],
        [helper.make_tensor_value_info('Y', dtype_proto, [4])],
    )
    opset_info = helper.make_opsetid("", 20)
    model_def = helper.make_model(
        graph_def,
        ir_version=9,
        opset_imports=[opset_info]
    )

    # 2. Prepare input data
    x_np = np.array([-1.0, 0.0, 1.0, 2.0], dtype=dtype_np)

    # 3. Get reference output from ONNX Runtime
    sess = ort.InferenceSession(model_def.SerializeToString())
    onnx_outputs = sess.run(None, {'X': x_np})[0]

    # 4. Get TVM Relax output
    tvm_mod = from_onnx(model_def)
    print("--- Generated Relax IR ---")
    print(tvm_mod["main"].script())
    target = tvm.target.Target("llvm")
    dev = tvm.cpu(0)
    ex = relax.build(tvm_mod, target)
    vm = relax.VirtualMachine(ex, dev)
    tvm_outputs = vm["main"](x_np).numpy()

    # 5. Compare results
    print(f"Reference (ORT): {onnx_outputs}")
    print(f"TVM Relax:       {tvm_outputs}")
    try:
        np.testing.assert_allclose(
            onnx_outputs,
            tvm_outputs,
            rtol=1e-5, atol=1e-5,
            err_msg="Numerical mismatch! TVM likely ignored approximate='tanh'."
        )
        print("\nResult: Success (no mismatch detected)")
    except AssertionError as e:
        print("\nResult: Bug Confirmed (numerical mismatch detected)")
        print(e)


if __name__ == "__main__":
    reproduce_gelu_tanh_mismatch()
```

Triage
- relax:frontend:onnx
- needs-triage