diff --git a/include/matx/operators/hermitian.h b/include/matx/operators/hermitian.h index 716a7c13..2e6f31be 100644 --- a/include/matx/operators/hermitian.h +++ b/include/matx/operators/hermitian.h @@ -139,6 +139,8 @@ namespace matx return (dim < (Rank() - 2)) ? op_.Size(dim) : op_.Size((dim == Rank() - 1) ? Rank() - 2 : Rank() - 1); } + __MATX_INLINE__ const auto &InputOp() const noexcept { return op_; } + template __MATX_INLINE__ void PreRun(ShapeType &&shape, Executor &&ex) const noexcept { diff --git a/include/matx/operators/unary_operators.h b/include/matx/operators/unary_operators.h index 7b46e76b..349e41e2 100644 --- a/include/matx/operators/unary_operators.h +++ b/include/matx/operators/unary_operators.h @@ -158,6 +158,9 @@ namespace matx } #endif + __MATX_INLINE__ const auto &InputOp() const noexcept { return in1_; } + __MATX_INLINE__ const auto &UnaryOp() const noexcept { return op_; } + template __MATX_INLINE__ __MATX_HOST__ auto get_capability([[maybe_unused]] InType &in) const { if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { diff --git a/include/matx/transforms/matmul/matmul_cuda.h b/include/matx/transforms/matmul/matmul_cuda.h index 71060058..0a8bdfe4 100644 --- a/include/matx/transforms/matmul/matmul_cuda.h +++ b/include/matx/transforms/matmul/matmul_cuda.h @@ -47,6 +47,9 @@ #include "matx/core/error.h" #include "matx/core/nvtx.h" #include "matx/core/tensor.h" +#include "matx/operators/hermitian.h" +#include "matx/operators/transpose.h" +#include "matx/operators/unary_operators.h" #include "matx/transforms/matmul/matmul_common.h" namespace matx { @@ -69,6 +72,48 @@ namespace detail { // Configurable tensor rank threshold for single batch operation static constexpr int MATMUL_BATCH_RANK_THRESHOLD = 4; +template +struct MatMulInputCanonicalized_t { + Op op; + bool forceOp = false; + cublasOperation_t forcedOp = CUBLAS_OP_N; +}; + +template +__MATX_INLINE__ auto CanonicalizeMatMulInput(const Op &op) +{ + return MatMulInputCanonicalized_t{op, false, CUBLAS_OP_N}; +} + +template +__MATX_INLINE__ auto CanonicalizeMatMulInput(const HermitianTransOp &op) +{ + constexpr auto forced_op = + is_complex_v ? CUBLAS_OP_C : CUBLAS_OP_T; + return MatMulInputCanonicalized_t>{ + op.InputOp(), true, forced_op}; +} + +template +__MATX_INLINE__ auto CanonicalizeMatMulInput( + const matxUnaryOp> &op) +{ + // Keep a stable return type for template deduction and only toggle forceOp. + bool force_op = false; + if constexpr (is_tensor_view_v>) { + const auto &in = op.InputOp(); + // Only remap when conj is applied to a transposed tensor view. A plain + // conj(X) must keep its original layout and should not be force-lowered. + force_op = (in.Stride(in.Rank() - 2) == 1 && in.Size(in.Rank() - 1) != 1); + } + + constexpr auto forced_op = + is_complex_v ? CUBLAS_OP_C : CUBLAS_OP_T; + return MatMulInputCanonicalized_t>>{ + op, force_op, force_op ? forced_op : CUBLAS_OP_N}; +} + + typedef enum { MEM_ORDER_ROW_MAJOR = 0, MEM_ORDER_COL_MAJOR = 1, @@ -138,6 +183,12 @@ struct MatMulCUDAParams_t { MatXDataType_t dtype; cublasOperation_t opA; cublasOperation_t opB; + bool applyOpA = false; + bool applyOpB = false; + bool conjOpA = false; + bool conjOpB = false; + bool a_col_major = false; + bool b_col_major = false; }; template = 2); - MATX_ASSERT(a.Size(TensorTypeA::Rank() - 1) == b.Size(TensorTypeB::Rank() - 2), matxInvalidSize); - MATX_ASSERT(c.Size(RANK - 1) == b.Size(TensorTypeB::Rank() - 1), matxInvalidSize); - MATX_ASSERT(c.Size(RANK - 2) == a.Size(TensorTypeA::Rank() - 2), matxInvalidSize); + const auto is_transpose_op = [](cublasOperation_t op) { + return op == CUBLAS_OP_T || op == CUBLAS_OP_C; + }; + + const bool a_is_transposed = forceOpA && is_transpose_op(forcedOpA); + const bool b_is_transposed = forceOpB && is_transpose_op(forcedOpB); + + const index_t a_rows = + a_is_transposed ? a.Size(TensorTypeA::Rank() - 1) + : a.Size(TensorTypeA::Rank() - 2); + const index_t a_cols = + a_is_transposed ? a.Size(TensorTypeA::Rank() - 2) + : a.Size(TensorTypeA::Rank() - 1); + const index_t b_rows = + b_is_transposed ? b.Size(TensorTypeB::Rank() - 1) + : b.Size(TensorTypeB::Rank() - 2); + const index_t b_cols = + b_is_transposed ? b.Size(TensorTypeB::Rank() - 2) + : b.Size(TensorTypeB::Rank() - 1); + + MATX_ASSERT(a_cols == b_rows, matxInvalidSize); + MATX_ASSERT(c.Size(RANK - 1) == b_cols, matxInvalidSize); + MATX_ASSERT(c.Size(RANK - 2) == a_rows, matxInvalidSize); // Ensure batch dimensions are equal for (int i = 0; i < RANK - 2; i++) { @@ -203,7 +277,7 @@ class MatMulCUDAHandle_t { } // This must come before the things below to properly set class parameters - params_ = GetGemmParams(c, a, b); + params_ = GetGemmParams(c, a, b, forceOpA, forcedOpA, forceOpB, forcedOpB); if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) { // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB for Hopper+: @@ -260,7 +334,10 @@ class MatMulCUDAHandle_t { } static detail::MatMulCUDAParams_t GetGemmParams(TensorTypeC &c, const TensorTypeA &a, - const TensorTypeB &b) + const TensorTypeB &b, bool forceOpA = false, + cublasOperation_t forcedOpA = CUBLAS_OP_N, + bool forceOpB = false, + cublasOperation_t forcedOpB = CUBLAS_OP_N) { /* If a user passes in a tensor where the last two dimensions are transposed we retain the original size parameters, but tell the underlying libraries that the tensors are @@ -367,21 +444,27 @@ class MatMulCUDAHandle_t { if constexpr (is_complex_half_v) { // For half complex we always copy to a new tensor so it is always cublas op N params.opA = CUBLAS_OP_N; + params.a_col_major = false; } else if ( a.Stride(TensorTypeA::Rank()-1) > 1 // last stride > 1 || (a.Stride(TensorTypeA::Rank()-1) == 1 && a.Stride(TensorTypeA::Rank()-2) == 1 && a.Size(TensorTypeA::Rank()-1) != 1)) { // last strides both equal 1 and size > 1 params.opA = CUBLAS_OP_T; + params.a_col_major = true; } else { // otherwise row major params.opA = CUBLAS_OP_N; + params.a_col_major = false; } if constexpr (is_complex_half_v) { // For half complex we always copy to a new tensor so it is always cublas op N params.opB = CUBLAS_OP_N; + params.b_col_major = false; } else if ( b.Stride(TensorTypeB::Rank()-1) > 1 // last stride > 1 || (b.Stride(TensorTypeB::Rank()-1) == 1 && b.Stride(TensorTypeB::Rank()-2) == 1 && b.Size(TensorTypeB::Rank()-1) != 1)) { // last strides both equal 1 and size > 1 params.opB = CUBLAS_OP_T; + params.b_col_major = true; } else { // otherwise row major params.opB = CUBLAS_OP_N; + params.b_col_major = false; } params.a_rows = a.Size(TensorTypeA::Rank() - 2); @@ -418,8 +501,11 @@ class MatMulCUDAHandle_t { params.lda = a.Size(TensorTypeA::Rank()-1); } - params.c_rows = params.a_rows; - params.c_cols = params.b_cols; + // C descriptor dimensions must always match the output tensor view. + // Forced transpose/hermitian modes can change matmul op semantics + // without changing A/B storage descriptors. + params.c_rows = c.Size(RANK - 2); + params.c_cols = c.Size(RANK - 1); params.ldc = c.Stride(RANK - 2); } @@ -436,6 +522,42 @@ class MatMulCUDAHandle_t { } } + if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) { + if (forceOpA) { + if constexpr (is_complex_half_v) { + if (forcedOpA == CUBLAS_OP_C) { + // cublasLt complex-half does not reliably support OP_C. Encode + // hermitian as OP_T and fold conjugation into planar packing. + params.opA = CUBLAS_OP_T; + params.conjOpA = true; + } + else { + params.opA = forcedOpA; + } + } + else { + params.opA = forcedOpA; + } + params.applyOpA = true; + } + + if (forceOpB) { + if constexpr (is_complex_half_v) { + if (forcedOpB == CUBLAS_OP_C) { + params.opB = CUBLAS_OP_T; + params.conjOpB = true; + } + else { + params.opB = forcedOpB; + } + } + else { + params.opB = forcedOpB; + } + params.applyOpB = true; + } + } + return params; } @@ -558,17 +680,21 @@ class MatMulCUDAHandle_t { cublasLtOrder_t rowOrder = CUBLASLT_ORDER_ROW; cublasLtOrder_t colOrder = CUBLASLT_ORDER_COL; - auto op = CUBLAS_OP_N; + const cublasOperation_t transA = + params_.applyOpA ? params_.opA : CUBLAS_OP_N; + const cublasOperation_t transB = + params_.applyOpB ? params_.opB : CUBLAS_OP_N; + // A operation ret = cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op, - sizeof(op)); + operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transA, + sizeof(transA)); MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError); // B operation ret = cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op, - sizeof(op)); + operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transB, + sizeof(transB)); MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError); // Update this later when we're more flexible on compute type @@ -609,7 +735,7 @@ class MatMulCUDAHandle_t { MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError); // Matrix data order - if (params_.opA == CUBLAS_OP_T) { + if (params_.a_col_major) { ret = cublasLtMatrixLayoutSetAttribute( Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &colOrder, sizeof(colOrder)); @@ -621,7 +747,7 @@ class MatMulCUDAHandle_t { } MATX_ASSERT(ret == CUBLAS_STATUS_SUCCESS, matxMatMulError); - if (params_.opB == CUBLAS_OP_T) { + if (params_.b_col_major) { ret = cublasLtMatrixLayoutSetAttribute( Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &colOrder, sizeof(colOrder)); @@ -792,9 +918,22 @@ class MatMulCUDAHandle_t { } auto c_planar = make_tensor(reinterpret_cast(c_hp), c_shape, false); - // Convert A/B to planar layout - (a_planar = planar(a)).run(stream); - (b_planar = planar(b)).run(stream); + // Convert A/B to planar layout. For complex-half hermitian lowering, fold + // conjugation into this conversion so no separate hermitian temporary is + // materialized. + if (params_.conjOpA) { + (a_planar = planar(conj(a))).run(stream); + } + else { + (a_planar = planar(a)).run(stream); + } + + if (params_.conjOpB) { + (b_planar = planar(conj(b))).run(stream); + } + else { + (b_planar = planar(b)).run(stream); + } // update pointers to planar data. // must use Reset because types for planar are different @@ -1089,7 +1228,15 @@ struct MatMulCUDAParamsKeyHash { return std::hash()(k.m) + std::hash()(k.n) + std::hash()(k.k) + std::hash()(k.batch) + std::hash()(k.prov) + - std::hash()((size_t)k.stream); + std::hash()((size_t)k.stream) + + std::hash()(static_cast(k.opA)) + + std::hash()(static_cast(k.opB)) + + std::hash()(static_cast(k.applyOpA)) + + std::hash()(static_cast(k.applyOpB)) + + std::hash()(static_cast(k.conjOpA)) + + std::hash()(static_cast(k.conjOpB)) + + std::hash()(static_cast(k.a_col_major)) + + std::hash()(static_cast(k.b_col_major)); } }; @@ -1109,7 +1256,10 @@ struct MatMulCUDAParamsKeyEq { l.stream == t.stream && l.lda == t.lda && l.ldb == t.ldb && l.ldc == t.ldc && l.batch == t.batch && l.prov == t.prov && l.dtype == t.dtype && l.opA == t.opA && - l.opB == t.opB && l.rank == t.rank; + l.opB == t.opB && l.rank == t.rank && + l.applyOpA == t.applyOpA && l.applyOpB == t.applyOpB && + l.conjOpA == t.conjOpA && l.conjOpB == t.conjOpB && + l.a_col_major == t.a_col_major && l.b_col_major == t.b_col_major; } }; @@ -1185,9 +1335,16 @@ void matmul_impl(TensorTypeC C, const TensorTypeA A, static_assert(is_a_complex || is_b_complex, "If C is complex then either A or B should be complex "); } - // promote A and B to the type of C - auto A_ = as_type(A); - auto B_ = as_type(B); + // Canonicalize first so wrappers like hermitianT/conj(transpose(...)) are + // preserved even after type promotion to C's value type. + auto aCanonIn = detail::CanonicalizeMatMulInput(A); + auto bCanonIn = detail::CanonicalizeMatMulInput(B); + auto A_ = as_type(aCanonIn.op); + auto B_ = as_type(bCanonIn.op); + auto aCanon = detail::MatMulInputCanonicalized_t{ + A_, aCanonIn.forceOp, aCanonIn.forcedOp}; + auto bCanon = detail::MatMulInputCanonicalized_t{ + B_, bCanonIn.forceOp, bCanonIn.forcedOp}; static_assert(detail::CompatibleGemmCUDATypes(), "Combination of A/B/C types are not supported"); @@ -1195,19 +1352,19 @@ void matmul_impl(TensorTypeC C, const TensorTypeA A, // CublasLt does not support operators and certain transpose modes. // Grab a suppported tensor here and copy in if necessary. auto c = getCublasSupportedTensor(C, stream); - auto a = getCublasSupportedTensor(A_, stream); - auto b = getCublasSupportedTensor(B_, stream); + auto a = getCublasSupportedTensor(aCanon.op, stream); + auto b = getCublasSupportedTensor(bCanon.op, stream); typedef decltype(c) ctype; typedef decltype(a) atype; typedef decltype(b) btype; - if(!is_matx_transform_op() && !a.isSameView(A_)) { - (a = A_).run(stream); + if(!is_matx_transform_op() && !a.isSameView(aCanon.op)) { + (a = aCanon.op).run(stream); } - if(!is_matx_transform_op() && !b.isSameView(B_)) { - (b = B_).run(stream); + if(!is_matx_transform_op() && !b.isSameView(bCanon.op)) { + (b = bCanon.op).run(stream); } if(beta != 0 && !c.isSameView(C)) { @@ -1226,7 +1383,10 @@ void matmul_impl(TensorTypeC C, const TensorTypeA A, { // Get parameters required by these tensors auto params = - detail::MatMulCUDAHandle_t::GetGemmParams(c, a, b); + detail::MatMulCUDAHandle_t::GetGemmParams( + c, a, b, + (PROV == PROVIDER_TYPE_CUBLASLT) && aCanon.forceOp, aCanon.forcedOp, + (PROV == PROVIDER_TYPE_CUBLASLT) && bCanon.forceOp, bCanon.forcedOp); params.stream = stream; using cache_val_type = detail::MatMulCUDAHandle_t; @@ -1236,7 +1396,10 @@ void matmul_impl(TensorTypeC C, const TensorTypeA A, cache_id, params, [&]() { - return std::make_shared(c, a, b); + return std::make_shared( + c, a, b, + (PROV == PROVIDER_TYPE_CUBLASLT) && aCanon.forceOp, aCanon.forcedOp, + (PROV == PROVIDER_TYPE_CUBLASLT) && bCanon.forceOp, bCanon.forcedOp); }, [&](std::shared_ptr cache_type) { cache_type->Exec(c, a, b, stream, alpha, beta); diff --git a/test/00_transform/MatMul.cu b/test/00_transform/MatMul.cu index de008f12..9c7c7d37 100644 --- a/test/00_transform/MatMul.cu +++ b/test/00_transform/MatMul.cu @@ -199,6 +199,50 @@ TYPED_TEST(MatMulTestFloatTypes, SmallRectBTranspose) MATX_EXIT_HANDLER(); } +TYPED_TEST(MatMulTestFloatNonHalfTypes, HermitianOperandFusionEquivalent) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + if constexpr (!detail::CheckMatMulSupport()) { + GTEST_SKIP(); + } else if constexpr (!is_complex_v) { + GTEST_SKIP(); + } else { + constexpr index_t m = 4; + constexpr index_t k = 8; + constexpr index_t n = 16; + + tensor_t a{{k, m}}; + tensor_t b{{k, n}}; + tensor_t c_expr{{m, n}}; + tensor_t c_herm{{m, n}}; + tensor_t c_temp{{m, n}}; + tensor_t a_temp{{m, k}}; + + this->pb->template InitAndRunTVGenerator( + "00_transforms", "matmul_operators", "run_a_transpose", {m, k, n}); + + this->pb->NumpyToTensorView(a, "a"); + this->pb->NumpyToTensorView(b, "b"); + + (c_expr = matmul(conj(transpose_matrix(a)), b)).run(this->exec); + (c_herm = matmul(hermitianT(a), b)).run(this->exec); + (a_temp = conj(transpose_matrix(a))).run(this->exec); + (c_temp = matmul(a_temp, b)).run(this->exec); + + this->exec.sync(); + + for (index_t i = 0; i < m; i++) { + for (index_t j = 0; j < n; j++) { + EXPECT_TRUE(MatXUtils::MatXTypeCompare(c_temp(i, j), c_expr(i, j), this->thresh)); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(c_temp(i, j), c_herm(i, j), this->thresh)); + } + } + } + MATX_EXIT_HANDLER(); +} + TYPED_TEST(MatMulTestFloatNonHalfTypes, SmallRectCTranspose) { MATX_ENTER_HANDLER();