[Example] Backward function of RGMS kernel #77

@yzh119

The forward function of the RGMS kernel is (relation-related information is omitted for simplicity):

$$ Y = AXW $$

We already have an implementation of it written in SparseTIR, using composable formats and Tensor Cores.

The backward function of the RGMS kernel needs to compute the gradients of both $X$ and $W$:
$$\nabla (XW) = A^T \nabla Y$$
$$\nabla X = \nabla (XW) W^T $$
$$\nabla W = X^T \nabla (XW) $$

The three formulas could be computed inside the same kernel, with the intermediate $\nabla (XW)$ stored in shared memory. The same optimizations (composable formats + tensorization) could be applied to the backward kernel as well.
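
As a reference for checking a future backward kernel, the three formulas can be sketched in dense NumPy. This is not SparseTIR code: the function names (`rgms_forward`, `rgms_backward`, `numerical_grad`), the dense 0/1 matrix standing in for the sparse $A$, and the tiny shapes are all assumptions made for illustration, and relation information is ignored as above. The sketch compares the analytic gradients against central finite differences of the scalar loss $\sum_{ij} Y_{ij} (\nabla Y)_{ij}$.

```python
import numpy as np

def rgms_forward(A, X, W):
    # Y = A X W (dense stand-in; a real kernel would use a sparse A)
    return A @ X @ W

def rgms_backward(A, X, W, grad_Y):
    # grad_XW is the intermediate that the issue proposes keeping in shared memory
    grad_XW = A.T @ grad_Y      # nabla(XW) = A^T nabla(Y)
    grad_X = grad_XW @ W.T      # nabla(X)  = nabla(XW) W^T
    grad_W = X.T @ grad_XW      # nabla(W)  = X^T nabla(XW)
    return grad_X, grad_W

def numerical_grad(loss_fn, M, eps=1e-6):
    # Central finite differences of a scalar loss w.r.t. each entry of M.
    # M is perturbed in place and restored after each evaluation.
    grad = np.zeros_like(M)
    for idx in np.ndindex(*M.shape):
        old = M[idx]
        M[idx] = old + eps
        loss_plus = loss_fn()
        M[idx] = old - eps
        loss_minus = loss_fn()
        M[idx] = old
        grad[idx] = (loss_plus - loss_minus) / (2 * eps)
    return grad

# Hypothetical small problem: 5 nodes, 4 input features, 3 output features.
rng = np.random.default_rng(0)
n, d_in, d_out = 5, 4, 3
A = (rng.random((n, n)) < 0.4).astype(np.float64)  # random 0/1 adjacency
X = rng.standard_normal((n, d_in))
W = rng.standard_normal((d_in, d_out))
grad_Y = rng.standard_normal((n, d_out))           # upstream gradient

def loss():
    # Scalar loss whose gradient w.r.t. Y is exactly grad_Y
    return float(np.sum(rgms_forward(A, X, W) * grad_Y))

grad_X, grad_W = rgms_backward(A, X, W, grad_Y)
num_grad_X = numerical_grad(loss, X)
num_grad_W = numerical_grad(loss, W)
```

Note that `grad_XW` is computed once and reused for both outputs, which is the data-reuse argument for fusing the three steps into one kernel.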
