dtype_to_str: unsupported dtype raises ValueError for fp16/float8_e5m2 #4

@pandacooming

Description

Problem

In tile_kernels/testing/bench.py, the dtype_to_str() function is incomplete:

def dtype_to_str(dtype: torch.dtype) -> str:
    mapping = {
        torch.float32: 'fp32',
        torch.bfloat16: 'bf16',
        torch.float8_e4m3fn: 'e4m3',
        torch.int8: 'e2m1',  # int8 represents FP4 e2m1 format
    }
    if dtype not in mapping:
        raise ValueError(f'Unsupported dtype: {dtype}. Only fp32, bf16, e4m3, and int8(e2m1) are supported')
    return mapping[dtype]

Missing mappings:

  • torch.float16 → 'fp16'
  • torch.float8_e5m2 → 'e5m2'

When torch.float16 or torch.float8_e5m2 is passed, it raises a ValueError.

Context

TileKernels includes quantization kernels (tile_kernels/quant/) that use the float16 and float8_e5m2 dtypes. dtype_to_str is called when formatting benchmark output (via _format_value → make_param_id), so any benchmark that uses one of these dtypes fails with a ValueError instead of producing a human-readable label.

Expected Fix

Add the two missing mappings to the mapping dict and update the error message accordingly:

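def dtype_to_str(dtype: torch.dtype) -> str:
    mapping = {
        torch.float32: 'fp32',
        torch.float16: 'fp16',
        torch.bfloat16: 'bf16',
        torch.float8_e4m3fn: 'e4m3',
        torch.float8_e5m2: 'e5m2',
        torch.int8: 'e2m1',  # int8 represents FP4 e2m1 format
    }
    if dtype not in mapping:
        raise ValueError(f'Unsupported dtype: {dtype}. Only fp32, fp16, bf16, e4m3, e5m2, and int8(e2m1) are supported')
    return mapping[dtype]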
