The forward computation of the RGMS kernel is as follows (relation-specific information is omitted for simplicity):
$$ Y = AXW $$
We already have an implementation of it in SparseTIR that uses composable formats and tensor cores.
The backward function of the RGMS kernel must compute the gradients of both $X$ and $W$. Treating $XW$ as an intermediate, so that $Y = A(XW)$, the chain rule gives:
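As a plain-Python reference for the semantics only (not SparseTIR code, and not how the tensor-core kernel is scheduled), the sketch below assumes $A$ is stored in CSR form (`indptr`, `indices`, `data`) and computes $XW$ first, then multiplies it by the sparse $A$. The helper names `matmul`, `csr_spmm`, and `rgms_forward` are hypothetical.

```python
# Hypothetical plain-Python reference for the RGMS forward pass Y = A X W.
# A is sparse and stored in CSR form; relation information is ignored.

def matmul(P, Q):
    """Dense product of two lists-of-lists: P @ Q."""
    inner, cols = len(Q), len(Q[0])
    return [[sum(row[t] * Q[t][j] for t in range(inner)) for j in range(cols)]
            for row in P]

def csr_spmm(indptr, indices, data, B):
    """Sparse (CSR) times dense: returns A @ B."""
    cols = len(B[0])
    out = []
    for i in range(len(indptr) - 1):
        acc = [0.0] * cols
        # Gather the rows of B selected by the nonzeros of row i of A.
        for p in range(indptr[i], indptr[i + 1]):
            j, a = indices[p], data[p]
            for c in range(cols):
                acc[c] += a * B[j][c]
        out.append(acc)
    return out

def rgms_forward(indptr, indices, data, X, W):
    # Dense GEMM XW first (the tensor-core-friendly part), then the sparse product.
    H = matmul(X, W)
    return csr_spmm(indptr, indices, data, H)

# A = [[1, 0], [2, 3]] in CSR form.
indptr, indices, data = [0, 1, 3], [0, 0, 1], [1.0, 2.0, 3.0]
X = [[1.0, 2.0], [3.0, 4.0]]
W = [[1.0], [1.0]]
print(rgms_forward(indptr, indices, data, X, W))  # prints [[3.0], [27.0]]
```

By associativity $(AX)W = A(XW)$, so the evaluation order is a performance choice, not a semantic one.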
$$\nabla (XW) = A^T \nabla Y$$
$$\nabla X = \nabla (XW) W^T $$
$$\nabla W = X^T \nabla (XW) $$
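The three formulas can be sanity-checked numerically. The following sketch is illustrative, not the SparseTIR kernel. It uses a dense $A$ for clarity and the scalar loss $L = \sum_{ij} Y_{ij} (\nabla Y)_{ij}$, whose gradient with respect to $Y$ is exactly $\nabla Y$. It then compares the analytic $\nabla X$ and $\nabla W$ against central finite differences. All function names are hypothetical, and `dH` stands for $\nabla (XW)$.

```python
# Plain-Python check of the RGMS backward formulas (dense A; not SparseTIR code).

def matmul(P, Q):
    """Dense product of two lists-of-lists: P @ Q."""
    return [[sum(P[i][t] * Q[t][j] for t in range(len(Q))) for j in range(len(Q[0]))]
            for i in range(len(P))]

def transpose(P):
    return [list(col) for col in zip(*P)]

def rgms_backward(A, X, W, dY):
    dH = matmul(transpose(A), dY)  # grad(XW) = A^T grad(Y)
    dX = matmul(dH, transpose(W))  # grad(X)  = grad(XW) W^T
    dW = matmul(transpose(X), dH)  # grad(W)  = X^T grad(XW)
    return dX, dW

def loss(A, X, W, dY):
    """Scalar L = sum(Y * dY) with Y = A X W, so dL/dY equals dY."""
    Y = matmul(matmul(A, X), W)
    return sum(y * g for yr, gr in zip(Y, dY) for y, g in zip(yr, gr))

def finite_diff(f, M, eps=1e-6):
    """Central-difference gradient of f() w.r.t. each entry of M (perturbed in place, then restored)."""
    G = [[0.0] * len(M[0]) for _ in M]
    for i in range(len(M)):
        for j in range(len(M[0])):
            orig = M[i][j]
            M[i][j] = orig + eps
            up = f()
            M[i][j] = orig - eps
            down = f()
            M[i][j] = orig
            G[i][j] = (up - down) / (2 * eps)
    return G

A = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]]               # m x n = 2 x 3
X = [[0.5, -1.0], [2.0, 1.0], [1.5, 0.0]]             # n x k = 3 x 2
W = [[1.0, 2.0, 0.0, -1.0], [0.5, -0.5, 1.0, 2.0]]    # k x d = 2 x 4
dY = [[1.0, 0.0, -1.0, 2.0], [0.5, 1.0, 0.0, -2.0]]   # m x d = 2 x 4

dX, dW = rgms_backward(A, X, W, dY)
num_dX = finite_diff(lambda: loss(A, X, W, dY), X)
num_dW = finite_diff(lambda: loss(A, X, W, dY), W)
err = max(abs(a - b) for P, Q in ((dX, num_dX), (dW, num_dW))
          for pr, qr in zip(P, Q) for a, b in zip(pr, qr))
print(err < 1e-5)  # prints True
```

Because $L$ is linear in $X$ (for fixed $W$) and in $W$ (for fixed $X$), the finite differences agree with the analytic gradients up to floating-point rounding.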
All three formulas can be computed in a single fused kernel, with the intermediate $\nabla (XW)$ kept in shared memory so that both $\nabla X$ and $\nabla W$ can reuse it without a round trip through global memory. The same optimizations (composable formats and tensorization) can be applied to the backward kernel as well.
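The fusion idea can be sketched sequentially in plain Python. This is a schematic under stated assumptions, not a SparseTIR schedule. It assumes $A^T$ is available in CSR form (equivalently, $A$ in CSC), since computing rows of $\nabla (XW)$ needs row access to $A^T$. Each tile of rows of $\nabla (XW)$ is materialized once in a local buffer `buf`, which plays the role of shared memory, and is then consumed by both the $\nabla X$ update and the $\nabla W$ update. The names `fused_backward` and `tile` are hypothetical. On a GPU, the $\nabla W$ accumulation across thread blocks would additionally need atomics or a separate reduction, which this sequential sketch sidesteps.

```python
# Hypothetical sequential sketch of a fused RGMS backward pass.
# A^T is given in CSR form; `buf` stands in for a shared-memory tile of grad(XW).

def fused_backward(at_indptr, at_indices, at_data, dY, X, W, tile=2):
    n = len(at_indptr) - 1          # rows of A^T = rows of X
    k, d = len(W), len(W[0])
    dX = [[0.0] * k for _ in range(n)]
    dW = [[0.0] * d for _ in range(k)]
    for start in range(0, n, tile):
        rows = range(start, min(start + tile, n))
        # 1) grad(XW) tile = A^T[rows, :] @ dY  (sparse gather from dY)
        buf = []
        for r in rows:
            acc = [0.0] * d
            for p in range(at_indptr[r], at_indptr[r + 1]):
                i, a = at_indices[p], at_data[p]
                for c in range(d):
                    acc[c] += a * dY[i][c]
            buf.append(acc)
        # 2) dX tile = grad(XW) tile @ W^T  (dense, tensor-core-friendly)
        for t, r in enumerate(rows):
            for q in range(k):
                dX[r][q] = sum(buf[t][c] * W[q][c] for c in range(d))
        # 3) dW += X[rows, :]^T @ grad(XW) tile  (accumulated across tiles)
        for t, r in enumerate(rows):
            for q in range(k):
                xq = X[r][q]
                for c in range(d):
                    dW[q][c] += xq * buf[t][c]
    return dX, dW

# A = [[1, 0, 2], [0, 3, 0]], so A^T = [[1, 0], [0, 3], [2, 0]] in CSR form.
at_indptr, at_indices, at_data = [0, 1, 2, 3], [0, 1, 0], [1.0, 3.0, 2.0]
dY = [[1.0, 2.0], [3.0, 4.0]]
X = [[1.0], [2.0], [3.0]]
W = [[1.0, -1.0]]
dX, dW = fused_backward(at_indptr, at_indices, at_data, dY, X, W)
print(dX)  # prints [[-1.0], [-3.0], [-2.0]]
print(dW)  # prints [[25.0, 38.0]]
```

Each tile of $\nabla (XW)$ is computed once and read twice, which is the reuse that motivates keeping it in shared memory rather than writing it back to global memory between two separate kernels.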