
[Tracking Issue] Support Axis Union/Intersection #72

@yzh119


The Problem

Currently, SparseTIR cannot lower code to co-iteration structures. Whenever we want to add or multiply two sparse tensors/vectors, we have to create an extra axis that represents the union/intersection of their axes.

Here is an example of sparse matrix-sparse vector multiplication (SpMSpV), C = A * B:

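# Axes of the sparse matrix A
I = T.dense_fixed(m)
J = T.sparse_variable(I, (n, nnz), indptr_j, indices_j)
# Axes of the sparse vector B, stored as a 1 x n sparse matrix
IV = T.dense_fixed(1)
JV = T.sparse_variable(IV, (n, nnz), indptr_jv, indices_jv)
# Hand-built axis: for each row i, the intersection of J[i] and JV[0]
J_and = T.sparse_variable(I, (n, nnz), indptr_j_and, indices_j_and)
A = T.match_sparse_buffer(a, (I, J))
B = T.match_sparse_buffer(b, (IV, JV))
# Dense output vector (was used below but never declared)
C = T.match_sparse_buffer(c, (I,))
# "SR": i is a spatial axis, j is a reduction axis
with T.iter([I, J_and], "SR", "spmspv") as [i, j]:
    with T.init():
        C[i] = T.float32(0)
    C[i] = C[i] + A[i, j] * B[0, j]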

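Because co-iteration is not supported yet, SparseTIR generates several binary-search blocks to index A and B: j iterates over J_and rather than over J or JV, so the position of each j inside J[i] and inside JV[0] must be located by binary search, and the resulting mid arrays are then used to access A and B inside the for-loop structure.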
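To make that overhead concrete, here is a plain-Python sketch of what the binary-search strategy amounts to. It is not SparseTIR's generated code; it assumes CSR-style arrays (`indptr_*`/`indices_*` as in the example, with sorted indices per row) plus hypothetical value arrays `a_data` and `b_data`, and the helper names `find` and `spmspv_binary_search` are made up for illustration.

```python
from bisect import bisect_left


def find(indices, lo, hi, target):
    """Binary-search `target` in the sorted slice indices[lo:hi]; return its position or -1."""
    pos = bisect_left(indices, target, lo, hi)
    return pos if pos < hi and indices[pos] == target else -1


def spmspv_binary_search(m, indptr_j, indices_j, a_data,
                         indptr_jv, indices_jv, b_data,
                         indptr_j_and, indices_j_and):
    c = [0.0] * m
    for i in range(m):
        # j ranges over the precomputed intersection axis J_and[i]
        for p in range(indptr_j_and[i], indptr_j_and[i + 1]):
            j = indices_j_and[p]
            # "mid" positions: where j sits inside J[i] and inside JV[0]
            mid_a = find(indices_j, indptr_j[i], indptr_j[i + 1], j)
            mid_b = find(indices_jv, indptr_jv[0], indptr_jv[1], j)
            # J_and must be a true intersection, so both searches succeed
            assert mid_a >= 0 and mid_b >= 0
            c[i] += a_data[mid_a] * b_data[mid_b]
    return c
```

For A = [[1, 0, 2], [0, 3, 0]] and b = [0, 5, 7], calling `spmspv_binary_search(2, [0, 2, 3], [0, 2, 1], [1.0, 2.0, 3.0], [0, 2], [1, 2], [5.0, 7.0], [0, 1, 2], [2, 1])` returns `[14.0, 15.0]`; every product costs two binary searches on top of the extra J_and arrays.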

Once axis union/intersection and co-iteration structure generation are supported, we can declare J_and (and, likewise, a union axis J_or) as:

J_and = T.intersection([J, JV], indptr_j_and, indices_j_and)
J_or = T.union([J, JV], indptr_j_or, indices_j_or)
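The `indptr_j_and`/`indices_j_and` and `indptr_j_or`/`indices_j_or` arrays still describe concrete index sets. As a rough sketch, assuming the intersection/union is taken row by row between J[i] and the single vector row JV[0], their contents could be computed like this; `build_axis` is a hypothetical helper, not part of SparseTIR.

```python
def build_axis(m, indptr_j, indices_j, jv_indices, op):
    """Return (indptr, indices) of the per-row axis J[i] op JV[0], op in {"and", "or"}."""
    if op not in ("and", "or"):
        raise ValueError(f"unsupported op: {op!r}")
    jv = set(jv_indices)
    indptr, indices = [0], []
    for i in range(m):
        row = set(indices_j[indptr_j[i]:indptr_j[i + 1]])
        merged = row & jv if op == "and" else row | jv
        # Keep indices sorted within each row, as CSR-style axes expect
        indices.extend(sorted(merged))
        indptr.append(len(indices))
    return indptr, indices
```

With J = {0, 2} in row 0, J = {1} in row 1 and JV[0] = {1, 2}, the intersection axis is `([0, 1, 2], [2, 1])` and the union axis is `([0, 3, 5], [0, 1, 2, 1, 2])`. Sets keep the sketch short; a real implementation would more likely merge the already-sorted rows directly.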

Sparse iterations over union/intersection axes could then be lowered to co-iteration structures by the sparse iteration lowering pass.
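For intuition, a co-iteration structure for the SpMSpV example might look like the two-pointer `while` loop below, which needs neither the precomputed J_and arrays nor any binary search. This is a plain-Python sketch assuming sorted indices within each row; the function and the value arrays `a_data`/`b_data` are hypothetical, and the actual lowered TIR may look quite different.

```python
def spmspv_coiterate(m, indptr_j, indices_j, a_data, indptr_jv, indices_jv, b_data):
    c = [0.0] * m
    for i in range(m):
        p, p_end = indptr_j[i], indptr_j[i + 1]  # cursor into J[i]
        q, q_end = indptr_jv[0], indptr_jv[1]    # cursor into JV[0]
        # Intersection co-iteration: walk both sorted index lists in lockstep
        while p < p_end and q < q_end:
            ja, jb = indices_j[p], indices_jv[q]
            if ja == jb:
                # j is in both J[i] and JV[0]: the positions p, q are known directly
                c[i] += a_data[p] * b_data[q]
                p += 1
                q += 1
            elif ja < jb:
                p += 1
            else:
                q += 1
    return c
```

On the same A = [[1, 0, 2], [0, 3, 0]] and b = [0, 5, 7], this returns `[14.0, 15.0]`. Each row costs time linear in |J[i]| + |JV[0]|, and the positions used to load A and B fall out of the loop cursors.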

Milestone

  • Support co-iteration structures (either with a While construct or with a new statement in TIR).
  • Support T.intersection/T.union, and possibly more general axis compositions (e.g., for SpGEMM).
  • Modify the sparse iteration lowering pass to generate co-iteration structures for these axes.
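One reading of the SpGEMM remark: in a row-wise (Gustavson-style) formulation, the nonzero columns of row i of C are the union of the rows B[k, :] over every k in A[i, :], i.e. a union over a data-dependent number of axes rather than a fixed pair. The plain-Python sketch below illustrates this with a hash-map accumulator, assuming CSR inputs; it is an illustration of the access pattern, not a design proposed in this issue, and all names are hypothetical.

```python
def spgemm_rowwise(m, indptr_a, indices_a, a_data, indptr_b, indices_b, b_data):
    """Row-wise SpGEMM: C[i, :] = sum over k in A[i, :] of A[i, k] * B[k, :] (CSR in, CSR out)."""
    indptr_c, indices_c, c_data = [0], [], []
    for i in range(m):
        # Keys of `acc` end up as the union of the B rows selected by A[i, :]
        acc = {}
        for p in range(indptr_a[i], indptr_a[i + 1]):
            k, a_ik = indices_a[p], a_data[p]
            for q in range(indptr_b[k], indptr_b[k + 1]):
                j = indices_b[q]
                acc[j] = acc.get(j, 0.0) + a_ik * b_data[q]
        for j in sorted(acc):
            indices_c.append(j)
            c_data.append(acc[j])
        indptr_c.append(len(indices_c))
    return indptr_c, indices_c, c_data
```

For A = [[1, 2], [0, 3]] and B = [[4, 0], [0, 5]], the result is `([0, 2, 3], [0, 1, 1], [4.0, 10.0, 15.0])`, i.e. C = [[4, 10], [0, 15]]. Row 0 of C already needs the union of two B rows, which a binary J ∪ JV axis cannot express.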

Labels: enhancement (New feature or request), help wanted (Extra attention is needed)
