From a55e894116c0672258db936dbc7702440d4b5c27 Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Thu, 27 Feb 2025 14:23:40 -0500 Subject: [PATCH] add SpGEMM update function --- include/spblas/algorithms/multiply.hpp | 3 ++ include/spblas/algorithms/multiply_impl.hpp | 47 +++++++++++++++++++++ test/gtest/spgemm_test.cpp | 45 ++++++++++++++++++++ 3 files changed, 95 insertions(+) diff --git a/include/spblas/algorithms/multiply.hpp b/include/spblas/algorithms/multiply.hpp index f15748e..a6cfed9 100644 --- a/include/spblas/algorithms/multiply.hpp +++ b/include/spblas/algorithms/multiply.hpp @@ -26,4 +26,7 @@ void multiply_compute(operation_info_t& info, A&& a, B&& b, C&& c); template void multiply_fill(operation_info_t& info, A&& a, B&& b, C&& c); +template +void multiply_fill_update(operation_info_t& info, A&& a, B&& b, C&& c); + } // namespace spblas diff --git a/include/spblas/algorithms/multiply_impl.hpp b/include/spblas/algorithms/multiply_impl.hpp index 174bc08..386452e 100644 --- a/include/spblas/algorithms/multiply_impl.hpp +++ b/include/spblas/algorithms/multiply_impl.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include @@ -147,6 +148,45 @@ operation_info_t multiply_compute(A&& a, B&& b, C&& c) { return operation_info_t{__backend::shape(c), nnz}; } +// C = AB +// SpGEMM (Gustavson's Algorithm) on existing C values +template + requires(__backend::row_iterable && __backend::row_iterable && + __detail::is_csr_view_v) +void multiply_update(A&& a, B&& b, C&& c) { + log_trace(""); + if (__backend::shape(a)[0] != __backend::shape(c)[0] || + __backend::shape(b)[1] != __backend::shape(c)[1] || + __backend::shape(a)[1] != __backend::shape(b)[0]) { + throw std::invalid_argument( + "multiply: matrix dimensions are incompatible."); + } + + using T = tensor_scalar_t; + using I = tensor_index_t; + using O = tensor_offset_t; + + auto c_base = __detail::get_ultimate_base(c); + const auto c_rowptr = c_base.rowptr(); + const auto c_colind = c_base.colind(); + const auto c_values = c_base.values(); + + for (auto&& [i, a_row] : __backend::rows(a)) { + std::unordered_map c_columns; + const auto c_begin = c_rowptr[i]; + const auto c_end = c_rowptr[i + 1]; + for (auto c_nz : __ranges::views::iota(c_begin, c_end)) { + c_columns.emplace(c_colind[c_nz], c_nz); + c_values[c_nz] = 0; + } + for (auto&& [k, a_v] : a_row) { + for (auto&& [j, b_v] : __backend::lookup_row(b, k)) { + c_values[c_columns[j]] += a_v * b_v; + } + } + } +} + template requires(__backend::row_iterable && __backend::row_iterable && __detail::is_csr_view_v) @@ -163,4 +203,11 @@ void multiply_fill(operation_info_t info, A&& a, B&& b, C&& c) { multiply(a, b, c); } +// C = AB after multiply_fill(info, A, B, C) was called previously +template +void multiply_fill_update(operation_info_t info, A&& a, B&& b, C&& c) { + log_trace(""); + multiply_update(a, b, c); +} + } // namespace spblas diff --git a/test/gtest/spgemm_test.cpp b/test/gtest/spgemm_test.cpp index cff2322..74ee325 100644 --- a/test/gtest/spgemm_test.cpp +++ b/test/gtest/spgemm_test.cpp @@ -70,6 +70,51 @@ TEST(CsrView, SpGEMM) { } } +TEST(CsrView, SpGEMMUpdate) { + using T = float; + using I = spblas::index_t; + using O = spblas::offset_t; + + for (auto&& [m, k, nnz] : util::dims) { + for (auto&& n : {m, k}) { + auto [a_values, a_rowptr, a_colind, a_shape, a_nnz] = + spblas::generate_csr(m, k, nnz); + + auto [b_values, b_rowptr, b_colind, b_shape, b_nnz] = + spblas::generate_csr(k, n, nnz); + + spblas::csr_view a(a_values, a_rowptr, a_colind, a_shape, a_nnz); + spblas::csr_view b(b_values, b_rowptr, b_colind, b_shape, b_nnz); + + std::vector c_rowptr(m + 1); + + spblas::csr_view c(nullptr, c_rowptr.data(), nullptr, {m, n}, 0); + + auto info = spblas::multiply_compute(a, b, c); + + std::vector c_values(info.result_nnz()); + std::vector c_ref_values(info.result_nnz()); + std::vector c_colind(info.result_nnz()); + + spblas::csr_view c_ref(c_ref_values.data(), c_rowptr.data(), + c_colind.data(), {m, n}, info.result_nnz()); + c.update(c_values, c_rowptr, c_colind); + + spblas::__ranges::transform(a_values, a_values.begin(), + [](auto value) { return 1 / value; }); + spblas::__ranges::transform(b_values, b_values.begin(), + [](auto value) { return 1 / value; }); + + spblas::multiply_fill(info, a, b, c_ref); + spblas::multiply_fill_update(info, a, b, c); + + for (auto i : spblas::__ranges::views::iota(O{}, c.size())) { + EXPECT_EQ_(c_values[i], c_ref_values[i]); + } + } + } +} + TEST(CsrView, SpGEMM_AScaled) { using T = float; using I = spblas::index_t;