We need a `gpu` argument in the StableHLO lowering rules for that, or we need to rewrite the custom call in PJRT. Also, note the transpose trick via the layout that is used in JAX.
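A minimal pure-Python sketch of the layout trick. It assumes the trick is the one JAX's linalg lowerings use. A LAPACK/cuSOLVER-style kernel expects column-major matrices. Rather than emitting an explicit transpose before the custom call, the lowering declares a column-major layout for the operand through XLA's minor-to-major layout tuples, for example via the `operand_layouts`/`result_layouts` arguments of JAX's internal custom-call helper. XLA then either reuses the same bytes or inserts the copy itself. The helper names `minor_to_major`, `read` and `transpose` below are hypothetical and exist only for illustration.

```python
def minor_to_major(ndim, column_major_trailing=False):
    """Hypothetical helper: build an XLA-style minor-to-major layout tuple.

    Row-major (the default) lists dims from the last (most minor) to the first.
    With column_major_trailing=True the two trailing (matrix) dims are swapped,
    so each matrix is stored in Fortran order while batch dims stay row-major.
    """
    if column_major_trailing and ndim < 2:
        raise ValueError("need at least 2 dims for a matrix layout")
    row_major = tuple(range(ndim - 1, -1, -1))
    if not column_major_trailing:
        return row_major
    nb = ndim - 2  # number of leading batch dimensions
    return (nb, nb + 1) + tuple(range(nb - 1, -1, -1))


def read(buf, rows, cols, order):
    """Hypothetical helper: view a flat buffer as a rows x cols matrix.

    order="C" is row-major, order="F" is column-major.
    """
    if order == "C":
        return [[buf[i * cols + j] for j in range(cols)] for i in range(rows)]
    return [[buf[j * rows + i] for j in range(cols)] for i in range(rows)]


def transpose(m):
    """Plain list-of-lists transpose, used only to check the claim."""
    return [list(r) for r in zip(*m)]


# The same six values, read two ways: a 2x3 row-major matrix A, and a
# 3x2 column-major matrix. The second view is exactly A^T, so declaring a
# column-major layout gives the kernel the transpose without moving data.
buf = [1, 2, 3, 4, 5, 6]
a = read(buf, 2, 3, "C")
a_t_view = read(buf, 3, 2, "F")
assert a_t_view == transpose(a)

print(minor_to_major(2))                              # row-major 2-D layout
print(minor_to_major(3, column_major_trailing=True))  # batched column-major
```

The layout tuple is metadata only, so the choice between an explicit `transpose` op and a declared layout is left to XLA's layout assignment. That is why this route is usually cheaper than transposing in the lowering rule.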