
Add automatic hermitian transformation into cuBLAS from MatX operators#1137

Open
cliffburdick wants to merge 2 commits into main from hermitian_blas

Conversation

@cliffburdick
Collaborator

@cliffburdick cliffburdick commented Mar 10, 2026

When a user writes matmul(hermitianT(A), B), MatX currently creates a temporary tensor, materializes the hermitian transpose into it, and then passes the temporary to cuBLAS. Instead, we can detect this case and apply the transformation inside cuBLAS itself, making the GEMM much faster.

@copy-pr-bot

copy-pr-bot bot commented Mar 10, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cliffburdick
Collaborator Author

/build

@greptile-apps
Contributor

greptile-apps bot commented Mar 10, 2026

Greptile Summary

This PR adds automatic fusion of hermitianT(A) and conj(transpose_matrix(A)) operands directly into the cuBLASLt GEMM descriptor, avoiding a temporary materialization and improving performance significantly. The implementation introduces CanonicalizeMatMulInput overloads that either strip the hermitian wrapper (returning the inner tensor + a forceOp flag) or retain the lazy operator, and threads the flags through GetGemmParams/MatMulCUDAHandle_t all the way into the cublasLt operation descriptor and matrix layout setup.

Key changes:

  • HermitianTransOp::InputOp() and matxUnaryOp::InputOp()/UnaryOp() accessors added to enable wrapper introspection.
  • CanonicalizeMatMulInput template specializations detect hermitian/conj-transpose patterns and surface a forceOp/forcedOp pair.
  • MatMulCUDAParams_t gains applyOpA/B, conjOpA/B, and a/b_col_major fields to decouple storage layout from GEMM operation descriptor.
  • cublasLt descriptor setup now consults applyOpA/B independently of the matrix storage order flags.
  • Confirmed regression: When the output tensor c is column-major, matmul_impl recurses via matmul_impl(transpose_matrix(c), transpose_matrix(b), transpose_matrix(a), ...) as a transpose-swap identity. Because CanonicalizeMatMulInput(HermitianTransOp<...>) strips the wrapper and isSameView skips materialization, the a passed to this recursion is the raw inner tensor — missing the conjugation entirely. Pre-PR this was safe because getCublasSupportedTensor(hermitianT(A)) always returned a new tensor, forcing materialization. The new test does not exercise a column-major output tensor, so this path is currently untested.

Confidence Score: 2/5

  • Not safe to merge without addressing the column-major C correctness regression.
  • The cublasLt fast path works correctly for the common row-major C case, which is what the new test exercises. However, there is a confirmed correctness regression for column-major output tensors when combined with a hermitian input: the hermitian wrapper is stripped before the col-major recursion fires, the conjugation flag is not forwarded into the recursive call, and the result silently computes a plain transpose instead of a conjugate transpose. This could produce numerically wrong outputs in any codebase that writes to a column-major (or transposed) output view together with hermitianT inputs.
  • include/matx/transforms/matmul/matmul_cuda.h — specifically the col-major C recursion at line 1378–1381 and the missing guard for forceOp flags.

Important Files Changed

Filename Overview
include/matx/transforms/matmul/matmul_cuda.h Core of the PR — adds CanonicalizeMatMulInput overloads for HermitianTransOp and conj(transpose), forwards forceOp flags through GetGemmParams/MatMulCUDAHandle_t, and patches cublasLt descriptor setup to use the fused operation. Contains a correctness regression: when output C is column-major, the recursive col-major-swap path at line 1380 receives a raw (un-conjugated) inner tensor because the hermitian wrapper was stripped during canonicalization and the forceOp flag is not carried into the recursive call.
include/matx/operators/hermitian.h Adds InputOp() accessor to expose the inner operand of HermitianTransOp, enabling the matmul canonicalization to unwrap the wrapper without copying. Change is minimal and correct.
include/matx/operators/unary_operators.h Adds InputOp() and UnaryOp() accessors to matxUnaryOp, needed so the conj(transpose) canonicalization can inspect the inner operand and operation type. Change is minimal and correct.
test/00_transform/MatMul.cu Adds HermitianOperandFusionEquivalent test verifying that conj(transpose(A))@b, hermitianT(A)@b, and an explicit materialized hermitian all produce the same result. Test is correct but does not cover the column-major output-tensor path or batched inputs, leaving the regression at line 1380 undetected.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["matmul_impl(C, A, B)"] --> B["CanonicalizeMatMulInput(A/B)"]
    B --> C{"Is A hermitianT?"}
    C -->|Yes| D["Strip wrapper → inner tensor\nforceOp=true, forcedOp=CUBLAS_OP_C"]
    C -->|No| E{"Is A conj(transpose(x))?"}
    E -->|Yes| F["Keep full lazy op\nforceOp=true, forcedOp=CUBLAS_OP_C"]
    E -->|No| G["forceOp=false, forcedOp=CUBLAS_OP_N"]

    D --> H["getCublasSupportedTensor(inner_tensor)"]
    F --> I["getCublasSupportedTensor(conj_transpose_op)"]
    G --> J["getCublasSupportedTensor(A_)"]

    H --> K{"isSameView?"}
    I --> L{"isSameView?"}
    J --> M{"isSameView?"}

    K -->|Yes - skip| N["a = inner tensor (NO conjugation applied)"]
    K -->|No| O["a = materialize inner tensor"]
    L -->|No - always new tensor| P["a = materialize conj(transpose) ✓"]
    M --> Q["a = A or copy"]

    N --> R{"c is col-major?"}
    P --> R
    Q --> R

    R -->|Yes - BUG for hermitianT| S["recurse: matmul_impl(T(c), T(b), T(a))\nforceOp NOT forwarded\nhermitian conjugation LOST ❌"]
    R -->|No| T["GetGemmParams with forceOpA/B\n(gated on PROV==CUBLASLT)"]
    T --> U["Set CUBLASLT_MATMUL_DESC_TRANSA/B\napplyOpA ? opA : CUBLAS_OP_N"]

Comments Outside Diff (1)

  1. include/matx/transforms/matmul/matmul_cuda.h, lines 1378–1381

    Hermitian conjugation lost in column-major C recursion

    When the output tensor c is column-major, the code at line 1380 recurses by swapping and transposing A and B. At this point, a is the raw inner tensor returned by CanonicalizeMatMulInput(HermitianTransOp<...>) — the hermitian wrapper was stripped, the conjugation flag (aCanon.forceOp) is not forwarded into the recursive call, and a.isSameView(aCanon.op) was true so no materialization happened.

    The recursion therefore computes matmul(T(B), T(A_inner)) where A_inner is the plain data without conjugation — silently giving A^T instead of A^H for complex types.

    Pre-PR behavior was correct: getCublasSupportedTensor(hermitianT(A), stream) returned a new allocation, !isSameView triggered (a = hermitianT(A)).run(stream), so a held the fully materialized hermitian before this recursion fired.

    The simplest fix is to detect the forced-op case and fall through to the main GEMM path (avoiding the recursion) or materialize hermitian/conj data into a temporary before the column-major branch, e.g.:

    // If forced ops are active, the inner tensor has already been extracted and
    // the column-major identity trick would silently drop the conjugation.
    // Fall through to the standard path which correctly applies forceOp.
    if (!aCanon.forceOp && !bCanon.forceOp &&
        c.Stride(c.Rank()-2) == 1 && c.Stride(c.Rank()-1) > 1) {
      matmul_impl(transpose_matrix(c), transpose_matrix(b), transpose_matrix(a), exec, alpha, beta);
    } else

    Note: the conj(transpose_matrix(A)) specialization is not affected because its aCanon.op retains the full lazy operator, so isSameView is false and the data is materialized before this branch.

Last reviewed commit: 0a81358

@cliffburdick
Collaborator Author

/build
