diff --git a/exir/operator/util.py b/exir/operator/util.py index 23dc3edd302..fd900e6f635 100644 --- a/exir/operator/util.py +++ b/exir/operator/util.py @@ -55,7 +55,9 @@ def gen_out_variant_schema(func_op_schema: str) -> str: torch.ops.quantized_decomposed.choose_qparams.tensor, ] try: - import torchao # noqa: F401 + # Import quant_primitives directly to ensure custom ops are registered + # before accessing them via torch.ops.torchao + import torchao.quantization.quant_primitives # noqa: F401 _QUANT_PRIMITIVES.extend( [ @@ -64,5 +66,7 @@ def gen_out_variant_schema(func_op_schema: str) -> str: torch.ops.torchao.choose_qparams_affine.default, ] ) -except ImportError: +except (ImportError, AttributeError): + # ImportError: torchao or quant_primitives not installed + # AttributeError: torchao installed but operators not registered (e.g., older version) pass