Add automatic hermitian transformation into cuBLAS from MatX operators #1137

cliffburdick wants to merge 2 commits into `main` from
Conversation
/build
Greptile Summary

This PR adds automatic fusion of `hermitianT` / `conj(transpose())` operators into the cuBLAS conjugate-transpose operation (`CUBLAS_OP_C`), so the GEMM can consume the operand directly instead of materializing a temporary.

Key changes:
Confidence Score: 2/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["matmul_impl(C, A, B)"] --> B["CanonicalizeMatMulInput(A/B)"]
    B --> C{"Is A hermitianT?"}
    C -->|Yes| D["Strip wrapper → inner tensor\nforceOp=true, forcedOp=CUBLAS_OP_C"]
    C -->|No| E{"Is A conj(transpose(x))?"}
    E -->|Yes| F["Keep full lazy op\nforceOp=true, forcedOp=CUBLAS_OP_C"]
    E -->|No| G["forceOp=false, forcedOp=CUBLAS_OP_N"]
    D --> H["getCublasSupportedTensor(inner_tensor)"]
    F --> I["getCublasSupportedTensor(conj_transpose_op)"]
    G --> J["getCublasSupportedTensor(A_)"]
    H --> K{"isSameView?"}
    I --> L{"isSameView?"}
    J --> M{"isSameView?"}
    K -->|Yes - skip| N["a = inner tensor (NO conjugation applied)"]
    K -->|No| O["a = materialize inner tensor"]
    L -->|No - always new tensor| P["a = materialize conj(transpose) ✓"]
    M --> Q["a = A or copy"]
    N --> R{"c is col-major?"}
    P --> R
    Q --> R
    R -->|Yes - BUG for hermitianT| S["recurse: matmul_impl(T(c), T(b), T(a))\nforceOp NOT forwarded\nhermitian conjugation LOST ❌"]
    R -->|No| T["GetGemmParams with forceOpA/B\n(gated on PROV==CUBLASLT)"]
    T --> U["Set CUBLASLT_MATMUL_DESC_TRANSA/B\napplyOpA ? opA : CUBLAS_OP_N"]
```
/build
When a user writes `matmul(hermitianT(A), B)`, MatX currently creates a temporary tensor, writes the hermitian transpose into it, and then passes it to cuBLAS. Instead, we can detect this case and use the built-in cuBLAS transformation (`CUBLAS_OP_C`), making the GEMM much faster.