feat: Enforce ml_dtypes.bfloat16 for BF16 I/O in Python client #897

Open
yinggeh wants to merge 1 commit into main from yinggeh/tri-801-deprecate-bf16-to-fp32-conversion-in-python-client-library

@yinggeh yinggeh commented May 15, 2026

Summary

NumPy has no native BF16 dtype. The Python client currently converts between FP32 and BF16, and that conversion magnifies accuracy loss. This PR enforces ml_dtypes.bfloat16 for BF16 inputs and outputs so that values stay in native BF16.

Example

For 0.01 + 0.01, converting between BF16 and FP32 loses precision through truncation, which biases the output further from the true result 0.02 than the native BF16 computation does.

0.01(BF16) + 0.01(BF16) => 0.0200195 (BF16)
0.01(FP32->BF16) + 0.01(FP32->BF16) => 0.01989746 (BF16->FP32)
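
The two figures above can be reproduced with a minimal pure-Python sketch. It assumes that the native BF16 path rounds to nearest-even and that the FP32->BF16 conversion path drops the low 16 bits of the FP32 encoding, which is what the numbers in the example are consistent with. The helper names (`bf16_round`, `bf16_truncate`) are hypothetical and are not part of the client library or of ml_dtypes.

```python
import struct


def f32_bits(x: float) -> int:
    """Return the IEEE-754 binary32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(bits: int) -> float:
    """Interpret a 32-bit pattern as an IEEE-754 binary32 value."""
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def bf16_truncate(x: float) -> float:
    """FP32 -> BF16 by dropping the low 16 bits (assumed conversion path)."""
    return bits_to_f32(f32_bits(x) & 0xFFFF0000)


def bf16_round(x: float) -> float:
    """FP32 -> BF16 with round-to-nearest-even (assumed native behavior).

    NaN and overflow handling are omitted to keep the sketch short.
    """
    bits = f32_bits(x)
    bits += 0x7FFF + ((bits >> 16) & 1)  # bias so that ties go to even
    return bits_to_f32(bits & 0xFFFF0000)


# 0.01 stored in BF16 under each conversion rule.
rounded = bf16_round(0.01)
truncated = bf16_truncate(0.01)

# Add in BF16: the sum is rounded back to BF16 (exact here, since it is a doubling).
rounded_sum = bf16_round(rounded + rounded)
truncated_sum = bf16_round(truncated + truncated)

print(rounded_sum)    # compare with the first figure in the example
print(truncated_sum)  # compare with the second figure in the example
```

Under these assumptions the rounded path lands about 2e-5 above 0.02, while the truncated path lands about 1e-4 below it.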

@yinggeh yinggeh changed the title feat: Require ml_dtypes.bfloat16 for BF16 in Python client (TRI-801) feat: Require ml_dtypes.bfloat16 for BF16 in Python client May 15, 2026
@yinggeh yinggeh requested review from mc-nv, mudit-eng, pskiran1 and whoisj and removed request for mc-nv May 15, 2026 03:14
@yinggeh yinggeh self-assigned this May 15, 2026
@yinggeh yinggeh added the enhancement New feature or request label May 15, 2026
@yinggeh yinggeh changed the title feat: Require ml_dtypes.bfloat16 for BF16 in Python client feat: Use ml_dtypes.bfloat16 for BF16 I/O in Python client (TRI-801) May 15, 2026
@yinggeh yinggeh changed the title feat: Use ml_dtypes.bfloat16 for BF16 I/O in Python client (TRI-801) feat: Use ml_dtypes.bfloat16 for BF16 I/O in Python client May 15, 2026
@yinggeh yinggeh changed the title feat: Use ml_dtypes.bfloat16 for BF16 I/O in Python client feat: Enforce ml_dtypes.bfloat16 for BF16 I/O in Python client May 15, 2026
dtype = np_to_triton_dtype(input_tensor.dtype)
if self._input.datatype != dtype:
    error_message = (
        "got unexpected datatype {} from numpy array, expected {}.".format(

Why not use an interpolated string instead?

f"got unexpected datatype {dtype} from numpy array, expected {self._input.datatype}."
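
As a small illustration of the suggestion, the sketch below checks that the two spellings produce the same message. The values `actual` and `expected` are hypothetical stand-ins for the datatype strings compared in the diff, and the argument order passed to `.format` is an assumption, since the original call is cut off in the diff context.

```python
# Hypothetical datatype strings; the real values come from the tensor and the model config.
actual = "FP32"
expected = "BF16"

# Style used in the diff (argument order assumed, the call is truncated in the review view).
with_format = "got unexpected datatype {} from numpy array, expected {}.".format(
    actual, expected
)

# Style proposed in the review comment.
with_fstring = f"got unexpected datatype {actual} from numpy array, expected {expected}."

print(with_format == with_fstring)
```

The f-string keeps each value next to the placeholder it fills, which makes an accidental swap of the two arguments easier to spot.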

dtype = np_to_triton_dtype(input_tensor.dtype)
if self._datatype != dtype:
    error_message = (
        "got unexpected datatype {} from numpy array, expected {}.".format(

again, why not an interpolated string?
