From 43238e355aa3e344aa73889dcb4000414ffe7833 Mon Sep 17 00:00:00 2001 From: Troy Alderson <58866654+tfalders@users.noreply.github.com> Date: Mon, 12 May 2025 14:21:03 -0600 Subject: [PATCH 1/2] Correction for bufferSize methods (#357) * Corrected bufferSize output * Corrected tests * Updated changelog * Addressed review comment * Bump .so version (cherry picked from commit 5e0e46099ebbfed33800890f18423a09675936db) --- CHANGELOG.md | 29 + clients/include/testing_gesvd.hpp | 15 +- clients/include/testing_gesvda.hpp | 13 +- clients/include/testing_gesvdj.hpp | 13 +- clients/include/testing_orgbr_ungbr.hpp | 10 +- clients/include/testing_orgqr_ungqr.hpp | 10 +- clients/include/testing_orgtr_ungtr.hpp | 10 +- clients/include/testing_ormqr_unmqr.hpp | 10 +- clients/include/testing_ormtr_unmtr.hpp | 10 +- clients/include/testing_syevd_heevd.hpp | 13 +- clients/include/testing_syevdx_heevdx.hpp | 14 +- clients/include/testing_syevj_heevj.hpp | 13 +- clients/include/testing_sygvd_hegvd.hpp | 15 +- clients/include/testing_sygvdx_hegvdx.hpp | 14 +- clients/include/testing_sygvj_hegvj.hpp | 13 +- clients/include/testing_sytrd_hetrd.hpp | 13 +- clients/include/testing_sytrf.hpp | 15 +- library/CMakeLists.txt | 2 +- library/src/amd_detail/hipsolver.cpp | 692 +++++++++++++++++----- 19 files changed, 697 insertions(+), 227 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e0dc4481..8922a1c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,35 @@ Full documentation for hipSOLVER is available at the [hipSOLVER Documentation](https://rocm.docs.amd.com/projects/hipSOLVER/en/latest/index.html). +## (Unreleased) hipSOLVER + +### Added +### Changed +### Removed +### Optimized +### Resolved issues + +* Corrected the value of `lwork` returned by various `bufferSize` functions to be consistent with NVIDIA cuSOLVER. The following functions will + now return `lwork` such that the workspace size (in bytes) is `sizeof(T) * lwork`, rather than `lwork`: + * hipsolverXorgbr_bufferSize, hipsolverXorgqr_bufferSize, hipsolverXorgtr_bufferSize, hipsolverXormqr_bufferSize, hipsolverXormtr_bufferSize, + hipsolverXgesvd_bufferSize, hipsolverXgesvdj_bufferSize, hipsolverXgesvdBatched_bufferSize, hipsolverXgesvdaStridedBatched_bufferSize, + hipsolverXsyevd_bufferSize, hipsolverXsyevdx_bufferSize, hipsolverXsyevj_bufferSize, hipsolverXsyevjBatched_bufferSize, + hipsolverXsygvd_bufferSize, hipsolverXsygvdx_bufferSize, hipsolverXsygvj_bufferSize, hipsolverXsytrd_bufferSize, hipsolverXsytrf_bufferSize + +### Known issues +### Upcoming changes + + +## hipSOLVER 2.5.0 for ROCm 6.5.0 + +### Upcoming changes + +* With the rocSOLVER backend, the bufferSize methods are currently outputting lwork such that the required workspace +size (in bytes) is lwork. In ROCm 7.0 this will change to make the rocSOLVER backend consistent with cuSOLVER. The +changed bufferSize methods will then return lwork such that the required workspace size (in bytes) is sizeof(T) * lwork, +where T is the used precision. This change will break ABI backward compatibility. + + ## hipSOLVER 2.4.0 for ROCm 6.4.0 ### Added diff --git a/clients/include/testing_gesvd.hpp b/clients/include/testing_gesvd.hpp index 2041523c..48a7ed3e 100644 --- a/clients/include/testing_gesvd.hpp +++ b/clients/include/testing_gesvd.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -322,7 +322,8 @@ void testing_gesvd_bad_arg() // int size_W; // hipsolver_gesvd_bufferSize(API, handle, left_svect, right_svect, m, n, dA.data(), lda, &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -350,7 +351,8 @@ void testing_gesvd_bad_arg() int size_W; hipsolver_gesvd_bufferSize( API, handle, left_svect, right_svect, m, n, dA.data(), lda, &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -1061,11 +1063,12 @@ void testing_gesvd(Arguments& argus) int size_W, w1, w2; hipsolver_gesvd_bufferSize(API, handle, leftv, rightv, m, n, (T*)nullptr, lda, &w1); hipsolver_gesvd_bufferSize(API, handle, leftvT, rightvT, mT, nT, (T*)nullptr, lda, &w2); - size_W = max(w1, w2); + size_W = max(w1, w2); + size_t bytes_W = sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -1089,7 +1092,7 @@ void testing_gesvd(Arguments& argus) device_strided_batch_vector dinfo(1, 1, 1, bc); device_strided_batch_vector dVT(size_VT, 1, stVT, bc); device_strided_batch_vector dUT(size_UT, 1, stUT, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_VT) CHECK_HIP_ERROR(dVT.memcheck()); if(size_UT) diff --git a/clients/include/testing_gesvda.hpp b/clients/include/testing_gesvda.hpp index a149e4a2..00ee1c23 100644 --- a/clients/include/testing_gesvda.hpp +++ b/clients/include/testing_gesvda.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -281,7 +281,8 @@ void testing_gesvda_bad_arg() // stV, // &size_W, // bc); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -344,7 +345,8 @@ void testing_gesvda_bad_arg() stV, &size_W, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -846,10 +848,11 @@ void testing_gesvda(Arguments& argus) stV, &size_W, bc); + size_t bytes_W = sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -870,7 +873,7 @@ void testing_gesvda(Arguments& argus) device_strided_batch_vector dV(size_V, 1, stV, bc); device_strided_batch_vector dU(size_U, 1, stU, bc); device_strided_batch_vector dinfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_S) CHECK_HIP_ERROR(dS.memcheck()); if(size_V) diff --git a/clients/include/testing_gesvdj.hpp b/clients/include/testing_gesvdj.hpp index 06381a60..e6f37d65 100644 --- a/clients/include/testing_gesvdj.hpp +++ b/clients/include/testing_gesvdj.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -278,7 +278,8 @@ void testing_gesvdj_bad_arg() // &size_W, // params, // bc); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -337,7 +338,8 @@ void testing_gesvdj_bad_arg() &size_W, params, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -888,10 +890,11 @@ void testing_gesvdj(Arguments& argus) &size_W, params, bc); + size_t bytes_W = sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -912,7 +915,7 @@ void testing_gesvdj(Arguments& argus) device_strided_batch_vector dV(size_V, 1, stV, bc); device_strided_batch_vector dU(size_U, 1, stU, bc); device_strided_batch_vector dinfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_S) CHECK_HIP_ERROR(dS.memcheck()); if(size_V) diff --git a/clients/include/testing_orgbr_ungbr.hpp b/clients/include/testing_orgbr_ungbr.hpp index f82ccdfd..7bebb76f 100644 --- a/clients/include/testing_orgbr_ungbr.hpp +++ b/clients/include/testing_orgbr_ungbr.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -88,7 +88,8 @@ void testing_orgbr_ungbr_bad_arg() int size_W; hipsolver_orgbr_ungbr_bufferSize( API, handle, side, m, n, k, dA.data(), lda, dIpiv.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -376,10 +377,11 @@ void testing_orgbr_ungbr(Arguments& argus) int size_W; hipsolver_orgbr_ungbr_bufferSize( API, handle, side, m, n, k, (T*)nullptr, lda, (T*)nullptr, &size_W); + size_t bytes_W = sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -392,7 +394,7 @@ void testing_orgbr_ungbr(Arguments& argus) device_strided_batch_vector dA(size_A, 1, size_A, 1); device_strided_batch_vector dIpiv(size_P, 1, size_P, 1); device_strided_batch_vector dInfo(1, 1, 1, 1); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_A) CHECK_HIP_ERROR(dA.memcheck()); if(size_P) diff --git a/clients/include/testing_orgqr_ungqr.hpp b/clients/include/testing_orgqr_ungqr.hpp index f3ff28ab..cea417a8 100644 --- a/clients/include/testing_orgqr_ungqr.hpp +++ b/clients/include/testing_orgqr_ungqr.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -79,7 +79,8 @@ void testing_orgqr_ungqr_bad_arg() int size_W; hipsolver_orgqr_ungqr_bufferSize(API, handle, m, n, k, dA.data(), lda, dIpiv.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -280,10 +281,11 @@ void testing_orgqr_ungqr(Arguments& argus) // memory size query is necessary int size_W; hipsolver_orgqr_ungqr_bufferSize(API, handle, m, n, k, (T*)nullptr, lda, (T*)nullptr, &size_W); + size_t bytes_W = sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -296,7 +298,7 @@ void testing_orgqr_ungqr(Arguments& argus) device_strided_batch_vector dA(size_A, 1, size_A, 1); device_strided_batch_vector dIpiv(size_P, 1, size_P, 1); device_strided_batch_vector dInfo(1, 1, 1, 1); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_A) CHECK_HIP_ERROR(dA.memcheck()); if(size_P) diff --git a/clients/include/testing_orgtr_ungtr.hpp b/clients/include/testing_orgtr_ungtr.hpp index fed1c860..6b3551f4 100644 --- a/clients/include/testing_orgtr_ungtr.hpp +++ b/clients/include/testing_orgtr_ungtr.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -80,7 +80,8 @@ void testing_orgtr_ungtr_bad_arg() int size_W; hipsolver_orgtr_ungtr_bufferSize(API, handle, uplo, n, dA.data(), lda, dIpiv.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -286,10 +287,11 @@ void testing_orgtr_ungtr(Arguments& argus) // memory size query is necessary int size_W; hipsolver_orgtr_ungtr_bufferSize(API, handle, uplo, n, (T*)nullptr, lda, (T*)nullptr, &size_W); + size_t bytes_W = sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -302,7 +304,7 @@ void testing_orgtr_ungtr(Arguments& argus) device_strided_batch_vector dA(size_A, 1, size_A, 1); device_strided_batch_vector dIpiv(size_P, 1, size_P, 1); device_strided_batch_vector dInfo(1, 1, 1, 1); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_A) CHECK_HIP_ERROR(dA.memcheck()); if(size_P) diff --git a/clients/include/testing_ormqr_unmqr.hpp b/clients/include/testing_ormqr_unmqr.hpp index 91177c41..f5eac415 100644 --- a/clients/include/testing_ormqr_unmqr.hpp +++ b/clients/include/testing_ormqr_unmqr.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -186,7 +186,8 @@ void testing_ormqr_unmqr_bad_arg() int size_W; hipsolver_ormqr_unmqr_bufferSize( API, handle, side, trans, m, n, k, dA.data(), lda, dIpiv.data(), dC.data(), ldc, &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -528,10 +529,11 @@ void testing_ormqr_unmqr(Arguments& argus) (T*)nullptr, ldc, &size_W); + size_t bytes_W = sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -546,7 +548,7 @@ void testing_ormqr_unmqr(Arguments& argus) device_strided_batch_vector dIpiv(size_P, 1, size_P, 1); device_strided_batch_vector dA(size_A, 1, size_A, 1); device_strided_batch_vector dInfo(1, 1, 1, 1); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_A) CHECK_HIP_ERROR(dA.memcheck()); if(size_P) diff --git a/clients/include/testing_ormtr_unmtr.hpp b/clients/include/testing_ormtr_unmtr.hpp index 2b76c9e6..5c9025f5 100644 --- a/clients/include/testing_ormtr_unmtr.hpp +++ b/clients/include/testing_ormtr_unmtr.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -237,7 +237,8 @@ void testing_ormtr_unmtr_bad_arg() dC.data(), ldc, &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -582,10 +583,11 @@ void testing_ormtr_unmtr(Arguments& argus) (T*)nullptr, ldc, &size_W); + size_t bytes_W = sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -600,7 +602,7 @@ void testing_ormtr_unmtr(Arguments& argus) device_strided_batch_vector dIpiv(size_P, 1, size_P, 1); device_strided_batch_vector dA(size_A, 1, size_A, 1); device_strided_batch_vector dInfo(1, 1, 1, 1); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_A) CHECK_HIP_ERROR(dA.memcheck()); if(size_P) diff --git a/clients/include/testing_syevd_heevd.hpp b/clients/include/testing_syevd_heevd.hpp index 091a13b5..f09569ac 100644 --- a/clients/include/testing_syevd_heevd.hpp +++ b/clients/include/testing_syevd_heevd.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -123,7 +123,8 @@ void testing_syevd_heevd_bad_arg() // int size_W; // hipsolver_syevd_heevd_bufferSize( // API, handle, evect, uplo, n, dA.data(), lda, dD.data(), &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -155,7 +156,8 @@ void testing_syevd_heevd_bad_arg() int size_W; hipsolver_syevd_heevd_bufferSize( API, handle, evect, uplo, n, dA.data(), lda, dD.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -579,10 +581,11 @@ void testing_syevd_heevd(Arguments& argus) int size_W; hipsolver_syevd_heevd_bufferSize( API, handle, evect, uplo, n, (T*)nullptr, lda, (S*)nullptr, &size_W); + size_t bytes_W = sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -595,7 +598,7 @@ void testing_syevd_heevd(Arguments& argus) // device device_strided_batch_vector dD(size_D, 1, stD, bc); device_strided_batch_vector dinfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_D) CHECK_HIP_ERROR(dD.memcheck()); CHECK_HIP_ERROR(dinfo.memcheck()); diff --git a/clients/include/testing_syevdx_heevdx.hpp b/clients/include/testing_syevdx_heevdx.hpp index a3fe44cf..0bc0e25a 100644 --- a/clients/include/testing_syevdx_heevdx.hpp +++ b/clients/include/testing_syevdx_heevdx.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -271,7 +271,8 @@ void testing_syevdx_heevdx_bad_arg() // hNev.data(), // dW.data(), // &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -323,7 +324,8 @@ void testing_syevdx_heevdx_bad_arg() hNev.data(), dW.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -844,10 +846,11 @@ void testing_syevdx_heevdx(Arguments& argus) (int*)nullptr, (S*)nullptr, &size_Work); + size_t bytes_Work = sizeof(T) * size_Work; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_Work); + rocsolver_bench_inform(inform_mem_query, bytes_Work); return; } @@ -862,7 +865,8 @@ void testing_syevdx_heevdx(Arguments& argus) // device device_strided_batch_vector dW(size_W, 1, stW, bc); device_strided_batch_vector dinfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_Work, 1, size_Work, 1); // size_W accounts for bc + device_strided_batch_vector dWork( + bytes_Work, 1, bytes_Work, 1); // bytes_Work accounts for bc if(size_W) CHECK_HIP_ERROR(dW.memcheck()); CHECK_HIP_ERROR(dinfo.memcheck()); diff --git a/clients/include/testing_syevj_heevj.hpp b/clients/include/testing_syevj_heevj.hpp index b8cfd292..61ae9e79 100644 --- a/clients/include/testing_syevj_heevj.hpp +++ b/clients/include/testing_syevj_heevj.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -181,7 +181,8 @@ void testing_syevj_heevj_bad_arg() // int size_W; // hipsolver_syevj_heevj_bufferSize( // API, handle, evect, uplo, n, dA.data(), lda, dD.data(), &size_W, params, bc); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -214,7 +215,8 @@ void testing_syevj_heevj_bad_arg() int size_W; hipsolver_syevj_heevj_bufferSize( API, STRIDED, handle, evect, uplo, n, dA.data(), lda, dD.data(), &size_W, params, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -694,10 +696,11 @@ void testing_syevj_heevj(Arguments& argus) int size_W; hipsolver_syevj_heevj_bufferSize( API, STRIDED, handle, evect, uplo, n, (T*)nullptr, lda, (S*)nullptr, &size_W, params, bc); + size_t bytes_W = sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -712,7 +715,7 @@ void testing_syevj_heevj(Arguments& argus) // device device_strided_batch_vector dD(size_D, 1, stD, bc); device_strided_batch_vector dinfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_D) CHECK_HIP_ERROR(dD.memcheck()); CHECK_HIP_ERROR(dinfo.memcheck()); diff --git a/clients/include/testing_sygvd_hegvd.hpp b/clients/include/testing_sygvd_hegvd.hpp index 4fcec2df..07bec8f1 100644 --- a/clients/include/testing_sygvd_hegvd.hpp +++ b/clients/include/testing_sygvd_hegvd.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -248,7 +248,8 @@ void testing_sygvd_hegvd_bad_arg() // ldb, // dD.data(), // &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -286,7 +287,8 @@ void testing_sygvd_hegvd_bad_arg() int size_W; hipsolver_sygvd_hegvd_bufferSize( API, handle, itype, evect, uplo, n, dA.data(), lda, dB.data(), ldb, dD.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -858,10 +860,11 @@ void testing_sygvd_hegvd(Arguments& argus) ldb, (S*)nullptr, &size_W); + size_t bytes_W = sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -879,7 +882,7 @@ void testing_sygvd_hegvd(Arguments& argus) // device_batch_vector dB(size_B, 1, bc); // device_strided_batch_vector dD(size_D, 1, stD, bc); // device_strided_batch_vector dInfo(1, 1, 1, bc); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc // if(size_A) // CHECK_HIP_ERROR(dA.memcheck()); // if(size_B) @@ -963,7 +966,7 @@ void testing_sygvd_hegvd(Arguments& argus) device_strided_batch_vector dB(size_B, 1, stB, bc); device_strided_batch_vector dD(size_D, 1, stD, bc); device_strided_batch_vector dInfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_A) CHECK_HIP_ERROR(dA.memcheck()); if(size_B) diff --git a/clients/include/testing_sygvdx_hegvdx.hpp b/clients/include/testing_sygvdx_hegvdx.hpp index 4e7301c3..b73be597 100644 --- a/clients/include/testing_sygvdx_hegvdx.hpp +++ b/clients/include/testing_sygvdx_hegvdx.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -372,7 +372,8 @@ void testing_sygvdx_hegvdx_bad_arg() // hNev.data(), // dW.data(), // &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -437,7 +438,8 @@ void testing_sygvdx_hegvdx_bad_arg() hNev.data(), dW.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -1170,10 +1172,11 @@ void testing_sygvdx_hegvdx(Arguments& argus) (int*)nullptr, (S*)nullptr, &size_Work); + size_t bytes_Work = sizeof(T) * size_Work; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_Work); + rocsolver_bench_inform(inform_mem_query, bytes_Work); return; } @@ -1188,7 +1191,8 @@ void testing_sygvdx_hegvdx(Arguments& argus) // device device_strided_batch_vector dW(size_W, 1, stW, bc); device_strided_batch_vector dInfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_Work, 1, size_Work, 1); // size_W accounts for bc + device_strided_batch_vector dWork( + bytes_Work, 1, bytes_Work, 1); // bytes_Work accounts for bc if(size_W) CHECK_HIP_ERROR(dW.memcheck()); CHECK_HIP_ERROR(dInfo.memcheck()); diff --git a/clients/include/testing_sygvj_hegvj.hpp b/clients/include/testing_sygvj_hegvj.hpp index cbfd60d3..32f9f637 100644 --- a/clients/include/testing_sygvj_hegvj.hpp +++ b/clients/include/testing_sygvj_hegvj.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -259,7 +259,8 @@ void testing_sygvj_hegvj_bad_arg() // dD.data(), // &size_W, // params); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -309,7 +310,8 @@ void testing_sygvj_hegvj_bad_arg() dD.data(), &size_W, params); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -919,10 +921,11 @@ void testing_sygvj_hegvj(Arguments& argus) (S*)nullptr, &size_W, params); + size_t bytes_W = sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -937,7 +940,7 @@ void testing_sygvj_hegvj(Arguments& argus) // device device_strided_batch_vector dD(size_D, 1, stD, bc); device_strided_batch_vector dInfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_D) CHECK_HIP_ERROR(dD.memcheck()); CHECK_HIP_ERROR(dInfo.memcheck()); diff --git a/clients/include/testing_sytrd_hetrd.hpp b/clients/include/testing_sytrd_hetrd.hpp index fbd46e7b..83b5ed95 100644 --- a/clients/include/testing_sytrd_hetrd.hpp +++ b/clients/include/testing_sytrd_hetrd.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -211,7 +211,8 @@ void testing_sytrd_hetrd_bad_arg() // int size_W; // hipsolver_sytrd_hetrd_bufferSize( // API, handle, uplo, n, dA.data(), lda, dD.data(), dE.data(), dTau.data(), &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -250,7 +251,8 @@ void testing_sytrd_hetrd_bad_arg() int size_W; hipsolver_sytrd_hetrd_bufferSize( API, handle, uplo, n, dA.data(), lda, dD.data(), dE.data(), dTau.data(), &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -768,10 +770,11 @@ void testing_sytrd_hetrd(Arguments& argus) int size_W; hipsolver_sytrd_hetrd_bufferSize( API, handle, uplo, n, (T*)nullptr, lda, (S*)nullptr, (S*)nullptr, (T*)nullptr, &size_W); + size_t bytes_W = sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -787,7 +790,7 @@ void testing_sytrd_hetrd(Arguments& argus) device_strided_batch_vector dE(size_E, 1, stE, bc); device_strided_batch_vector dTau(size_tau, 1, stP, bc); device_strided_batch_vector dInfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_D) CHECK_HIP_ERROR(dD.memcheck()); if(size_E) diff --git a/clients/include/testing_sytrf.hpp b/clients/include/testing_sytrf.hpp index 6705a238..19cf6f84 100644 --- a/clients/include/testing_sytrf.hpp +++ b/clients/include/testing_sytrf.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -101,7 +101,8 @@ void testing_sytrf_bad_arg() // int size_W; // hipsolver_sytrf_bufferSize(API, handle, n, dA.data(), lda, &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); + // size_t bytes_W = sizeof(T) * size_W; + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -131,7 +132,8 @@ void testing_sytrf_bad_arg() int size_W; hipsolver_sytrf_bufferSize(API, handle, n, dA.data(), lda, &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); + size_t bytes_W = sizeof(T) * size_W; + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -509,10 +511,11 @@ void testing_sytrf(Arguments& argus) // memory size query is necessary int size_W; hipsolver_sytrf_bufferSize(API, handle, n, (T*)nullptr, lda, &size_W); + size_t bytes_W = sizeof(T) * size_W; if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, bytes_W); return; } @@ -528,7 +531,7 @@ void testing_sytrf(Arguments& argus) // device_batch_vector dA(size_A, 1, bc); // device_strided_batch_vector dIpiv(size_P, 1, stP, bc); // device_strided_batch_vector dInfo(1, 1, 1, bc); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc // if(size_A) // CHECK_HIP_ERROR(dA.memcheck()); // CHECK_HIP_ERROR(dInfo.memcheck()); @@ -594,7 +597,7 @@ void testing_sytrf(Arguments& argus) device_strided_batch_vector dA(size_A, 1, stA, bc); device_strided_batch_vector dIpiv(size_P, 1, stP, bc); device_strided_batch_vector dInfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // bytes_W accounts for bc if(size_A) CHECK_HIP_ERROR(dA.memcheck()); CHECK_HIP_ERROR(dInfo.memcheck()); diff --git a/library/CMakeLists.txt b/library/CMakeLists.txt index 522201da..b7ffbf7d 100644 --- a/library/CMakeLists.txt +++ b/library/CMakeLists.txt @@ -1,5 +1,5 @@ # ######################################################################## -# Copyright (C) 2016-2024 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2016-2025 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/library/src/amd_detail/hipsolver.cpp b/library/src/amd_detail/hipsolver.cpp index dae75071..8d75f8f0 100644 --- a/library/src/amd_detail/hipsolver.cpp +++ b/library/src/amd_detail/hipsolver.cpp @@ -1098,6 +1098,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1142,6 +1144,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1186,6 +1190,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1230,6 +1236,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1257,12 +1265,16 @@ hipsolverStatus_t hipsolverSorgbr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverSorgbr_bufferSize((rocblas_handle)handle, side, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1289,12 +1301,16 @@ hipsolverStatus_t hipsolverDorgbr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverDorgbr_bufferSize((rocblas_handle)handle, side, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1321,12 +1337,16 @@ hipsolverStatus_t hipsolverCungbr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverCungbr_bufferSize((rocblas_handle)handle, side, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1359,12 +1379,16 @@ hipsolverStatus_t hipsolverZungbr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverZungbr_bufferSize((rocblas_handle)handle, side, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1401,6 +1425,8 @@ try rocsolver_sorgqr((rocblas_handle)handle, m, n, k, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1431,6 +1457,8 @@ try rocsolver_dorgqr((rocblas_handle)handle, m, n, k, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1467,6 +1495,8 @@ try rocsolver_cungqr((rocblas_handle)handle, m, n, k, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1503,6 +1533,8 @@ try rocsolver_zungqr((rocblas_handle)handle, m, n, k, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1529,12 +1561,16 @@ hipsolverStatus_t hipsolverSorgqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverSorgqr_bufferSize((rocblas_handle)handle, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1560,12 +1596,16 @@ hipsolverStatus_t hipsolverDorgqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverDorgqr_bufferSize((rocblas_handle)handle, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1591,12 +1631,16 @@ hipsolverStatus_t hipsolverCungqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverCungqr_bufferSize((rocblas_handle)handle, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1627,12 +1671,16 @@ hipsolverStatus_t hipsolverZungqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverZungqr_bufferSize((rocblas_handle)handle, m, n, k, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1673,6 +1721,8 @@ try (rocblas_handle)handle, hipsolver::hip2rocblas_fill(uplo), n, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1708,6 +1758,8 @@ try (rocblas_handle)handle, hipsolver::hip2rocblas_fill(uplo), n, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1743,6 +1795,8 @@ try (rocblas_handle)handle, hipsolver::hip2rocblas_fill(uplo), n, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1778,6 +1832,8 @@ try (rocblas_handle)handle, hipsolver::hip2rocblas_fill(uplo), n, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -1803,12 +1859,16 @@ hipsolverStatus_t hipsolverSorgtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverSorgtr_bufferSize((rocblas_handle)handle, uplo, n, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1833,12 +1893,16 @@ hipsolverStatus_t hipsolverDorgtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverDorgtr_bufferSize((rocblas_handle)handle, uplo, n, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1863,12 +1927,16 @@ hipsolverStatus_t hipsolverCungtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverCungtr_bufferSize((rocblas_handle)handle, uplo, n, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1897,12 +1965,16 @@ hipsolverStatus_t hipsolverZungtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverZungtr_bufferSize((rocblas_handle)handle, uplo, n, A, lda, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -1957,6 +2029,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2007,6 +2081,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2057,6 +2133,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2107,6 +2185,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2137,12 +2217,16 @@ hipsolverStatus_t hipsolverSormqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSormqr_bufferSize( (rocblas_handle)handle, side, trans, m, n, k, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -2181,12 +2265,16 @@ hipsolverStatus_t hipsolverDormqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDormqr_bufferSize( (rocblas_handle)handle, side, trans, m, n, k, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -2225,12 +2313,16 @@ hipsolverStatus_t hipsolverCunmqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverCunmqr_bufferSize( (rocblas_handle)handle, side, trans, m, n, k, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -2269,12 +2361,16 @@ hipsolverStatus_t hipsolverZunmqr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZunmqr_bufferSize( (rocblas_handle)handle, side, trans, m, n, k, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -2334,6 +2430,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2384,6 +2482,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2434,6 +2534,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2484,6 +2586,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -2514,12 +2618,16 @@ hipsolverStatus_t hipsolverSormtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSormtr_bufferSize( (rocblas_handle)handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -2558,12 +2666,16 @@ hipsolverStatus_t hipsolverDormtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDormtr_bufferSize( (rocblas_handle)handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -2602,12 +2714,16 @@ hipsolverStatus_t hipsolverCunmtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverCunmtr_bufferSize( (rocblas_handle)handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -2646,12 +2762,16 @@ hipsolverStatus_t hipsolverZunmtr(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZunmtr_bufferSize( (rocblas_handle)handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -4019,6 +4139,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4071,6 +4193,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4123,6 +4247,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4175,6 +4301,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4216,13 +4344,15 @@ try work = rwork + std::min(m, n); } - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverSgesvd_bufferSize((rocblas_handle)handle, jobu, jobv, m, n, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); if(!rwork && std::min(m, n) > 1) { @@ -4282,13 +4412,15 @@ try work = rwork + std::min(m, n); } - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverDgesvd_bufferSize((rocblas_handle)handle, jobu, jobv, m, n, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); if(!rwork && std::min(m, n) > 1) { @@ -4348,13 +4480,15 @@ try work = (hipFloatComplex*)(rwork + std::min(m, n)); } - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverCgesvd_bufferSize((rocblas_handle)handle, jobu, jobv, m, n, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); if(!rwork && std::min(m, n) > 1) { @@ -4414,13 +4548,15 @@ try work = (hipDoubleComplex*)(rwork + std::min(m, n)); } - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverZgesvd_bufferSize((rocblas_handle)handle, jobu, jobv, m, n, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); if(!rwork && std::min(m, n) > 1) { @@ -4506,6 +4642,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4572,6 +4710,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4638,6 +4778,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4704,6 +4846,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -4742,12 +4886,16 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSgesvdj_bufferSize( (rocblas_handle)handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -4805,12 +4953,16 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDgesvdj_bufferSize( (rocblas_handle)handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -4868,12 +5020,16 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverCgesvdj_bufferSize( (rocblas_handle)handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -4931,12 +5087,16 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZgesvdj_bufferSize( (rocblas_handle)handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -5028,6 +5188,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5099,6 +5261,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5170,6 +5334,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5241,6 +5407,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5279,7 +5447,10 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSgesvdjBatched_bufferSize((rocblas_handle)handle, @@ -5296,7 +5467,8 @@ try &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -5359,7 +5531,10 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDgesvdjBatched_bufferSize((rocblas_handle)handle, @@ -5376,7 +5551,8 @@ try &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -5439,7 +5615,10 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverCgesvdjBatched_bufferSize((rocblas_handle)handle, @@ -5456,7 +5635,8 @@ try &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -5519,7 +5699,10 @@ try // prepare workspace if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZgesvdjBatched_bufferSize((rocblas_handle)handle, @@ -5536,7 +5719,8 @@ try &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverGesvdjInfo* params = (hipsolverGesvdjInfo*)info; @@ -5644,6 +5828,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_nsv, size_ifail); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5726,6 +5912,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_nsv, size_ifail); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5808,6 +5996,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_nsv, size_ifail); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5889,6 +6079,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_nsv, size_ifail); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -5943,7 +6135,8 @@ try if(std::min(m, n) * batch_count > 0) work = (float*)(ifail + std::min(m, n) * batch_count); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { @@ -5965,7 +6158,8 @@ try strideV, &lwork, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(int) * batch_count, @@ -6051,7 +6245,8 @@ try if(std::min(m, n) * batch_count > 0) work = (double*)(ifail + std::min(m, n) * batch_count); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { @@ -6073,7 +6268,8 @@ try strideV, &lwork, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(int) * batch_count, @@ -6159,7 +6355,8 @@ try if(std::min(m, n) * batch_count > 0) work = (hipFloatComplex*)(ifail + std::min(m, n) * batch_count); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { @@ -6181,7 +6378,8 @@ try strideV, &lwork, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(int) * batch_count, @@ -6267,7 +6465,8 @@ try if(std::min(m, n) * batch_count > 0) work = (hipDoubleComplex*)(ifail + std::min(m, n) * batch_count); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { @@ -6289,7 +6488,8 @@ try strideV, &lwork, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(int) * batch_count, @@ -8421,6 +8621,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -8473,6 +8675,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -8525,6 +8729,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -8577,6 +8783,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -8611,13 +8819,15 @@ try if(n > 0) work = E + n; - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverSsyevd_bufferSize((rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(float) * n); if(!mem) @@ -8661,13 +8871,15 @@ try if(n > 0) work = E + n; - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverDsyevd_bufferSize((rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(double) * n); if(!mem) @@ -8711,13 +8923,15 @@ try if(n > 0) work = (hipFloatComplex*)(E + n); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverCheevd_bufferSize((rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(float) * n); if(!mem) @@ -8761,13 +8975,15 @@ try if(n > 0) work = (hipDoubleComplex*)(E + n); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverZheevd_bufferSize((rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(double) * n); if(!mem) @@ -8834,6 +9050,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -8890,6 +9108,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -8946,6 +9166,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9002,6 +9224,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9034,12 +9258,16 @@ hipsolverStatus_t hipsolverSsyevdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSsyevdx_bufferSize( (rocblas_handle)handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -9083,12 +9311,16 @@ hipsolverStatus_t hipsolverDsyevdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDsyevdx_bufferSize( (rocblas_handle)handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -9132,12 +9364,16 @@ hipsolverStatus_t hipsolverCheevdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverCheevdx_bufferSize( (rocblas_handle)handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -9181,12 +9417,16 @@ hipsolverStatus_t hipsolverZheevdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZheevdx_bufferSize( (rocblas_handle)handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -9251,6 +9491,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9303,6 +9545,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9355,6 +9599,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9407,6 +9653,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9439,12 +9687,16 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSsyevj_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -9490,12 +9742,16 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDsyevj_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -9541,12 +9797,16 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverCheevj_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -9592,12 +9852,16 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZheevj_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -9668,6 +9932,8 @@ try batch_count)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9724,6 +9990,8 @@ try batch_count)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9780,6 +10048,8 @@ try batch_count)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9836,6 +10106,8 @@ try batch_count)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -9869,12 +10141,16 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSsyevjBatched_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -9925,12 +10201,16 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDsyevjBatched_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -9981,12 +10261,16 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverCheevjBatched_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -10037,12 +10321,16 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZheevjBatched_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info, batch_count)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -10119,6 +10407,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10177,6 +10467,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10235,6 +10527,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10293,6 +10587,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10330,13 +10626,15 @@ try if(n > 0) work = E + n; - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverSsygvd_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(float) * n); if(!mem) @@ -10386,13 +10684,15 @@ try if(n > 0) work = E + n; - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverDsygvd_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(double) * n); if(!mem) @@ -10442,13 +10742,15 @@ try if(n > 0) work = (hipFloatComplex*)(E + n); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverChegvd_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(float) * n); if(!mem) @@ -10498,13 +10800,15 @@ try if(n > 0) work = (hipDoubleComplex*)(E + n); - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverZhegvd_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(double) * n); if(!mem) @@ -10580,6 +10884,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10642,6 +10948,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10704,6 +11012,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10766,6 +11076,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -10801,7 +11113,10 @@ hipsolverStatus_t hipsolverSsygvdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSsygvdx_bufferSize((rocblas_handle)handle, @@ -10821,7 +11136,8 @@ try nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -10871,7 +11187,10 @@ hipsolverStatus_t hipsolverDsygvdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDsygvdx_bufferSize((rocblas_handle)handle, @@ -10891,7 +11210,8 @@ try nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -10941,7 +11261,10 @@ hipsolverStatus_t hipsolverChegvdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverChegvdx_bufferSize((rocblas_handle)handle, @@ -10961,7 +11284,8 @@ try nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -11011,7 +11335,10 @@ hipsolverStatus_t hipsolverZhegvdx(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZhegvdx_bufferSize((rocblas_handle)handle, @@ -11031,7 +11358,8 @@ try nev, W, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status( @@ -11104,6 +11432,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11161,6 +11491,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11218,6 +11550,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11275,6 +11609,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11310,13 +11646,16 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverSsygvj_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork, info)); - CHECK_ROCBLAS_ERROR( - hipsolverManageWorkspace((rocblas_handle)handle, lwork + sizeof(float) * n)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -11367,13 +11706,16 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverDsygvj_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork, info)); - CHECK_ROCBLAS_ERROR( - hipsolverManageWorkspace((rocblas_handle)handle, lwork + sizeof(float) * n)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -11424,13 +11766,16 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverChegvj_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork, info)); - CHECK_ROCBLAS_ERROR( - hipsolverManageWorkspace((rocblas_handle)handle, lwork + sizeof(float) * n)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -11481,13 +11826,16 @@ try return HIPSOLVER_STATUS_INVALID_VALUE; if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR(hipsolverZhegvj_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork, info)); - CHECK_ROCBLAS_ERROR( - hipsolverManageWorkspace((rocblas_handle)handle, lwork + sizeof(float) * n)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } hipsolverSyevjInfo* params = (hipsolverSyevjInfo*)info; @@ -11548,6 +11896,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11592,6 +11942,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11636,6 +11988,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11680,6 +12034,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11707,12 +12063,16 @@ hipsolverStatus_t hipsolverSsytrd(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverSsytrd_bufferSize((rocblas_handle)handle, uplo, n, A, lda, D, E, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -11739,12 +12099,16 @@ hipsolverStatus_t hipsolverDsytrd(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverDsytrd_bufferSize((rocblas_handle)handle, uplo, n, A, lda, D, E, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -11771,12 +12135,16 @@ hipsolverStatus_t hipsolverChetrd(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverChetrd_bufferSize((rocblas_handle)handle, uplo, n, A, lda, D, E, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -11809,12 +12177,16 @@ hipsolverStatus_t hipsolverZhetrd(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverZhetrd_bufferSize((rocblas_handle)handle, uplo, n, A, lda, D, E, tau, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1)); @@ -11851,6 +12223,8 @@ try (rocblas_handle)handle, rocblas_fill_upper, n, nullptr, lda, nullptr, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(float); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11881,6 +12255,8 @@ try (rocblas_handle)handle, rocblas_fill_upper, n, nullptr, lda, nullptr, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(double); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11911,6 +12287,8 @@ try (rocblas_handle)handle, rocblas_fill_upper, n, nullptr, lda, nullptr, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_float_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11941,6 +12319,8 @@ try (rocblas_handle)handle, rocblas_fill_upper, n, nullptr, lda, nullptr, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); + sz /= sizeof(rocblas_double_complex); + if(status != HIPSOLVER_STATUS_SUCCESS) return status; if(sz > INT_MAX) @@ -11966,12 +12346,16 @@ hipsolverStatus_t hipsolverSsytrf(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverSsytrf_bufferSize((rocblas_handle)handle, n, A, lda, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(float) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status(rocsolver_ssytrf( @@ -11994,12 +12378,16 @@ hipsolverStatus_t hipsolverDsytrf(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverDsytrf_bufferSize((rocblas_handle)handle, n, A, lda, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(double) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status(rocsolver_dsytrf( @@ -12022,12 +12410,16 @@ hipsolverStatus_t hipsolverCsytrf(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverCsytrf_bufferSize((rocblas_handle)handle, n, A, lda, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_float_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status(rocsolver_csytrf((rocblas_handle)handle, @@ -12055,12 +12447,16 @@ hipsolverStatus_t hipsolverZsytrf(hipsolverHandle_t handle, try { if(work && lwork) - CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, lwork)); + { + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); + } else { CHECK_HIPSOLVER_ERROR( hipsolverZsytrf_bufferSize((rocblas_handle)handle, n, A, lda, &lwork)); - CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lwork)); + size_t sz = sizeof(rocblas_double_complex) * lwork; + CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } return hipsolver::rocblas2hip_status(rocsolver_zsytrf((rocblas_handle)handle, From 28ea8627d216c07212a4d9ed618b003a844164c7 Mon Sep 17 00:00:00 2001 From: Troy Alderson <58866654+tfalders@users.noreply.github.com> Date: Wed, 28 May 2025 13:18:05 -0600 Subject: [PATCH 2/2] Use HIPSOLVER_BUFFERSIZE_RETURN_BYTES to restore original bufferSize behavior (#398) * Use HIPSOLVER_BUFFERSIZE_RETURN_BYTES to restore original behavior * Updated changelog * Check HIPSOLVER_BUFFERSIZE_RETURN_BYTES in tests * Addressed review comments * Updated changelog (cherry picked from commit 4636922e58d8450ebbb9376ec3a8c2f0a087fbba) --- CHANGELOG.md | 13 +- clients/include/testing_gesvd.hpp | 11 +- clients/include/testing_gesvda.hpp | 9 +- clients/include/testing_gesvdj.hpp | 9 +- clients/include/testing_orgbr_ungbr.hpp | 6 +- clients/include/testing_orgqr_ungqr.hpp | 6 +- clients/include/testing_orgtr_ungtr.hpp | 6 +- clients/include/testing_ormqr_unmqr.hpp | 6 +- clients/include/testing_ormtr_unmtr.hpp | 6 +- clients/include/testing_syevd_heevd.hpp | 9 +- clients/include/testing_syevdx_heevdx.hpp | 10 +- clients/include/testing_syevj_heevj.hpp | 9 +- clients/include/testing_sygvd_hegvd.hpp | 9 +- clients/include/testing_sygvdx_hegvdx.hpp | 10 +- clients/include/testing_sygvj_hegvj.hpp | 9 +- clients/include/testing_sytrd_hetrd.hpp | 9 +- clients/include/testing_sytrf.hpp | 9 +- library/src/amd_detail/hipsolver.cpp | 794 ++++++++++++++++------ 18 files changed, 678 insertions(+), 262 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8922a1c0..2701cea4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,13 +5,19 @@ Full documentation for hipSOLVER is available at the [hipSOLVER Documentation](h ## (Unreleased) hipSOLVER ### Added + +* Added compatibility-only functions + * csrlsvqr + * hipsolverSpCcsrlsvqr, hipsolverSpZcsrlsvqr + ### Changed ### Removed ### Optimized ### Resolved issues * Corrected the value of `lwork` returned by various `bufferSize` functions to be consistent with NVIDIA cuSOLVER. The following functions will - now return `lwork` such that the workspace size (in bytes) is `sizeof(T) * lwork`, rather than `lwork`: + now return `lwork` such that the workspace size (in bytes) is `sizeof(T) * lwork`, rather than `lwork`. To restore the original behavior, set + environment variable `HIPSOLVER_BUFFERSIZE_RETURN_BYTES`. * hipsolverXorgbr_bufferSize, hipsolverXorgqr_bufferSize, hipsolverXorgtr_bufferSize, hipsolverXormqr_bufferSize, hipsolverXormtr_bufferSize, hipsolverXgesvd_bufferSize, hipsolverXgesvdj_bufferSize, hipsolverXgesvdBatched_bufferSize, hipsolverXgesvdaStridedBatched_bufferSize, hipsolverXsyevd_bufferSize, hipsolverXsyevdx_bufferSize, hipsolverXsyevj_bufferSize, hipsolverXsyevjBatched_bufferSize, @@ -41,6 +47,11 @@ where T is the used precision. This change will break ABI backward compatibility ### Upcoming changes * With the rocSOLVER backend, the bufferSize methods are currently outputting `lwork` such that the required workspace size (in bytes) is `lwork`. In ROCm 7.0 this will change to make the rocSOLVER backend consistent with cuSOLVER. The changed bufferSize methods will then return `lwork` so that the required workspace size (in bytes) is `sizeof(T) * lwork`, where T is the precision being used. This change will break ABI backward compatibility. +* With the rocSOLVER backend, the bufferSize methods are currently outputting lwork such that the required workspace + size (in bytes) is lwork. In ROCm 7.0 this will change to make the rocSOLVER backend consistent with cuSOLVER. The + changed bufferSize methods will then return lwork such that the required workspace size (in bytes) is sizeof(T) * lwork, + where T is the used precision. This change will break ABI backward compatibility. + ## hipSOLVER 2.3.0 for ROCm 6.3.0 diff --git a/clients/include/testing_gesvd.hpp b/clients/include/testing_gesvd.hpp index 48a7ed3e..8edffe58 100644 --- a/clients/include/testing_gesvd.hpp +++ b/clients/include/testing_gesvd.hpp @@ -322,7 +322,7 @@ void testing_gesvd_bad_arg() // int size_W; // hipsolver_gesvd_bufferSize(API, handle, left_svect, right_svect, m, n, dA.data(), lda, &size_W); - // size_t bytes_W = sizeof(T) * size_W; + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -351,7 +351,9 @@ void testing_gesvd_bad_arg() int size_W; hipsolver_gesvd_bufferSize( API, handle, left_svect, right_svect, m, n, dA.data(), lda, &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -1063,8 +1065,9 @@ void testing_gesvd(Arguments& argus) int size_W, w1, w2; hipsolver_gesvd_bufferSize(API, handle, leftv, rightv, m, n, (T*)nullptr, lda, &w1); hipsolver_gesvd_bufferSize(API, handle, leftvT, rightvT, mT, nT, (T*)nullptr, lda, &w2); - size_W = max(w1, w2); - size_t bytes_W = sizeof(T) * size_W; + size_W = max(w1, w2); + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { diff --git a/clients/include/testing_gesvda.hpp b/clients/include/testing_gesvda.hpp index 00ee1c23..cc2bc4a8 100644 --- a/clients/include/testing_gesvda.hpp +++ b/clients/include/testing_gesvda.hpp @@ -281,7 +281,7 @@ void testing_gesvda_bad_arg() // stV, // &size_W, // bc); - // size_t bytes_W = sizeof(T) * size_W; + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -345,7 +345,9 @@ void testing_gesvda_bad_arg() stV, &size_W, bc); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -848,7 +850,8 @@ void testing_gesvda(Arguments& argus) stV, &size_W, bc); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { diff --git a/clients/include/testing_gesvdj.hpp b/clients/include/testing_gesvdj.hpp index e6f37d65..408b7e16 100644 --- a/clients/include/testing_gesvdj.hpp +++ b/clients/include/testing_gesvdj.hpp @@ -278,7 +278,7 @@ void testing_gesvdj_bad_arg() // &size_W, // params, // bc); - // size_t bytes_W = sizeof(T) * size_W; + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -338,7 +338,9 @@ void testing_gesvdj_bad_arg() &size_W, params, bc); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -890,7 +892,8 @@ void testing_gesvdj(Arguments& argus) &size_W, params, bc); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { diff --git a/clients/include/testing_orgbr_ungbr.hpp b/clients/include/testing_orgbr_ungbr.hpp index 7bebb76f..388c895a 100644 --- a/clients/include/testing_orgbr_ungbr.hpp +++ b/clients/include/testing_orgbr_ungbr.hpp @@ -88,7 +88,8 @@ void testing_orgbr_ungbr_bad_arg() int size_W; hipsolver_orgbr_ungbr_bufferSize( API, handle, side, m, n, k, dA.data(), lda, dIpiv.data(), &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -377,7 +378,8 @@ void testing_orgbr_ungbr(Arguments& argus) int size_W; hipsolver_orgbr_ungbr_bufferSize( API, handle, side, m, n, k, (T*)nullptr, lda, (T*)nullptr, &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { diff --git a/clients/include/testing_orgqr_ungqr.hpp b/clients/include/testing_orgqr_ungqr.hpp index cea417a8..8ebdc7bf 100644 --- a/clients/include/testing_orgqr_ungqr.hpp +++ b/clients/include/testing_orgqr_ungqr.hpp @@ -79,7 +79,8 @@ void testing_orgqr_ungqr_bad_arg() int size_W; hipsolver_orgqr_ungqr_bufferSize(API, handle, m, n, k, dA.data(), lda, dIpiv.data(), &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -281,7 +282,8 @@ void testing_orgqr_ungqr(Arguments& argus) // memory size query is necessary int size_W; hipsolver_orgqr_ungqr_bufferSize(API, handle, m, n, k, (T*)nullptr, lda, (T*)nullptr, &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { diff --git a/clients/include/testing_orgtr_ungtr.hpp b/clients/include/testing_orgtr_ungtr.hpp index 6b3551f4..fe466301 100644 --- a/clients/include/testing_orgtr_ungtr.hpp +++ b/clients/include/testing_orgtr_ungtr.hpp @@ -80,7 +80,8 @@ void testing_orgtr_ungtr_bad_arg() int size_W; hipsolver_orgtr_ungtr_bufferSize(API, handle, uplo, n, dA.data(), lda, dIpiv.data(), &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -287,7 +288,8 @@ void testing_orgtr_ungtr(Arguments& argus) // memory size query is necessary int size_W; hipsolver_orgtr_ungtr_bufferSize(API, handle, uplo, n, (T*)nullptr, lda, (T*)nullptr, &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { diff --git a/clients/include/testing_ormqr_unmqr.hpp b/clients/include/testing_ormqr_unmqr.hpp index f5eac415..47bdbc33 100644 --- a/clients/include/testing_ormqr_unmqr.hpp +++ b/clients/include/testing_ormqr_unmqr.hpp @@ -186,7 +186,8 @@ void testing_ormqr_unmqr_bad_arg() int size_W; hipsolver_ormqr_unmqr_bufferSize( API, handle, side, trans, m, n, k, dA.data(), lda, dIpiv.data(), dC.data(), ldc, &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -529,7 +530,8 @@ void testing_ormqr_unmqr(Arguments& argus) (T*)nullptr, ldc, &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { diff --git a/clients/include/testing_ormtr_unmtr.hpp b/clients/include/testing_ormtr_unmtr.hpp index 5c9025f5..8fbe3281 100644 --- a/clients/include/testing_ormtr_unmtr.hpp +++ b/clients/include/testing_ormtr_unmtr.hpp @@ -237,7 +237,8 @@ void testing_ormtr_unmtr_bad_arg() dC.data(), ldc, &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -583,7 +584,8 @@ void testing_ormtr_unmtr(Arguments& argus) (T*)nullptr, ldc, &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { diff --git a/clients/include/testing_syevd_heevd.hpp b/clients/include/testing_syevd_heevd.hpp index f09569ac..32668c84 100644 --- a/clients/include/testing_syevd_heevd.hpp +++ b/clients/include/testing_syevd_heevd.hpp @@ -123,7 +123,7 @@ void testing_syevd_heevd_bad_arg() // int size_W; // hipsolver_syevd_heevd_bufferSize( // API, handle, evect, uplo, n, dA.data(), lda, dD.data(), &size_W); - // size_t bytes_W = sizeof(T) * size_W; + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -156,7 +156,9 @@ void testing_syevd_heevd_bad_arg() int size_W; hipsolver_syevd_heevd_bufferSize( API, handle, evect, uplo, n, dA.data(), lda, dD.data(), &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -581,7 +583,8 @@ void testing_syevd_heevd(Arguments& argus) int size_W; hipsolver_syevd_heevd_bufferSize( API, handle, evect, uplo, n, (T*)nullptr, lda, (S*)nullptr, &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { diff --git a/clients/include/testing_syevdx_heevdx.hpp b/clients/include/testing_syevdx_heevdx.hpp index 0bc0e25a..3645922c 100644 --- a/clients/include/testing_syevdx_heevdx.hpp +++ b/clients/include/testing_syevdx_heevdx.hpp @@ -271,7 +271,7 @@ void testing_syevdx_heevdx_bad_arg() // hNev.data(), // dW.data(), // &size_W); - // size_t bytes_W = sizeof(T) * size_W; + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -324,7 +324,9 @@ void testing_syevdx_heevdx_bad_arg() hNev.data(), dW.data(), &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -846,7 +848,9 @@ void testing_syevdx_heevdx(Arguments& argus) (int*)nullptr, (S*)nullptr, &size_Work); - size_t bytes_Work = sizeof(T) * size_Work; + size_t bytes_Work = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_Work + : sizeof(T) * size_Work; if(argus.mem_query) { diff --git a/clients/include/testing_syevj_heevj.hpp b/clients/include/testing_syevj_heevj.hpp index 61ae9e79..f67c67a1 100644 --- a/clients/include/testing_syevj_heevj.hpp +++ b/clients/include/testing_syevj_heevj.hpp @@ -181,7 +181,7 @@ void testing_syevj_heevj_bad_arg() // int size_W; // hipsolver_syevj_heevj_bufferSize( // API, handle, evect, uplo, n, dA.data(), lda, dD.data(), &size_W, params, bc); - // size_t bytes_W = sizeof(T) * size_W; + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -215,7 +215,9 @@ void testing_syevj_heevj_bad_arg() int size_W; hipsolver_syevj_heevj_bufferSize( API, STRIDED, handle, evect, uplo, n, dA.data(), lda, dD.data(), &size_W, params, bc); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -696,7 +698,8 @@ void testing_syevj_heevj(Arguments& argus) int size_W; hipsolver_syevj_heevj_bufferSize( API, STRIDED, handle, evect, uplo, n, (T*)nullptr, lda, (S*)nullptr, &size_W, params, bc); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { diff --git a/clients/include/testing_sygvd_hegvd.hpp b/clients/include/testing_sygvd_hegvd.hpp index 07bec8f1..785fcd34 100644 --- a/clients/include/testing_sygvd_hegvd.hpp +++ b/clients/include/testing_sygvd_hegvd.hpp @@ -248,7 +248,7 @@ void testing_sygvd_hegvd_bad_arg() // ldb, // dD.data(), // &size_W); - // size_t bytes_W = sizeof(T) * size_W; + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -287,7 +287,9 @@ void testing_sygvd_hegvd_bad_arg() int size_W; hipsolver_sygvd_hegvd_bufferSize( API, handle, itype, evect, uplo, n, dA.data(), lda, dB.data(), ldb, dD.data(), &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -860,7 +862,8 @@ void testing_sygvd_hegvd(Arguments& argus) ldb, (S*)nullptr, &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { diff --git a/clients/include/testing_sygvdx_hegvdx.hpp b/clients/include/testing_sygvdx_hegvdx.hpp index b73be597..0a034f90 100644 --- a/clients/include/testing_sygvdx_hegvdx.hpp +++ b/clients/include/testing_sygvdx_hegvdx.hpp @@ -372,7 +372,7 @@ void testing_sygvdx_hegvdx_bad_arg() // hNev.data(), // dW.data(), // &size_W); - // size_t bytes_W = sizeof(T) * size_W; + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -438,7 +438,9 @@ void testing_sygvdx_hegvdx_bad_arg() hNev.data(), dW.data(), &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -1172,7 +1174,9 @@ void testing_sygvdx_hegvdx(Arguments& argus) (int*)nullptr, (S*)nullptr, &size_Work); - size_t bytes_Work = sizeof(T) * size_Work; + size_t bytes_Work = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_Work + : sizeof(T) * size_Work; if(argus.mem_query) { diff --git a/clients/include/testing_sygvj_hegvj.hpp b/clients/include/testing_sygvj_hegvj.hpp index 32f9f637..e25af611 100644 --- a/clients/include/testing_sygvj_hegvj.hpp +++ b/clients/include/testing_sygvj_hegvj.hpp @@ -259,7 +259,7 @@ void testing_sygvj_hegvj_bad_arg() // dD.data(), // &size_W, // params); - // size_t bytes_W = sizeof(T) * size_W; + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -310,7 +310,9 @@ void testing_sygvj_hegvj_bad_arg() dD.data(), &size_W, params); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -921,7 +923,8 @@ void testing_sygvj_hegvj(Arguments& argus) (S*)nullptr, &size_W, params); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { diff --git a/clients/include/testing_sytrd_hetrd.hpp b/clients/include/testing_sytrd_hetrd.hpp index 83b5ed95..98885c5f 100644 --- a/clients/include/testing_sytrd_hetrd.hpp +++ b/clients/include/testing_sytrd_hetrd.hpp @@ -211,7 +211,7 @@ void testing_sytrd_hetrd_bad_arg() // int size_W; // hipsolver_sytrd_hetrd_bufferSize( // API, handle, uplo, n, dA.data(), lda, dD.data(), dE.data(), dTau.data(), &size_W); - // size_t bytes_W = sizeof(T) * size_W; + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -251,7 +251,9 @@ void testing_sytrd_hetrd_bad_arg() int size_W; hipsolver_sytrd_hetrd_bufferSize( API, handle, uplo, n, dA.data(), lda, dD.data(), dE.data(), dTau.data(), &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -770,7 +772,8 @@ void testing_sytrd_hetrd(Arguments& argus) int size_W; hipsolver_sytrd_hetrd_bufferSize( API, handle, uplo, n, (T*)nullptr, lda, (S*)nullptr, (S*)nullptr, (T*)nullptr, &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { diff --git a/clients/include/testing_sytrf.hpp b/clients/include/testing_sytrf.hpp index 19cf6f84..2901b0fa 100644 --- a/clients/include/testing_sytrf.hpp +++ b/clients/include/testing_sytrf.hpp @@ -101,7 +101,7 @@ void testing_sytrf_bad_arg() // int size_W; // hipsolver_sytrf_bufferSize(API, handle, n, dA.data(), lda, &size_W); - // size_t bytes_W = sizeof(T) * size_W; + // size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; // device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); // if(size_W) // CHECK_HIP_ERROR(dWork.memcheck()); @@ -132,7 +132,9 @@ void testing_sytrf_bad_arg() int size_W; hipsolver_sytrf_bufferSize(API, handle, n, dA.data(), lda, &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? size_W + : sizeof(T) * size_W; device_strided_batch_vector dWork(bytes_W, 1, bytes_W, 1); if(size_W) CHECK_HIP_ERROR(dWork.memcheck()); @@ -511,7 +513,8 @@ void testing_sytrf(Arguments& argus) // memory size query is necessary int size_W; hipsolver_sytrf_bufferSize(API, handle, n, (T*)nullptr, lda, &size_W); - size_t bytes_W = sizeof(T) * size_W; + size_t bytes_W + = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr ? size_W : sizeof(T) * size_W; if(argus.mem_query) { diff --git a/library/src/amd_detail/hipsolver.cpp b/library/src/amd_detail/hipsolver.cpp index 8d75f8f0..445695fd 100644 --- a/library/src/amd_detail/hipsolver.cpp +++ b/library/src/amd_detail/hipsolver.cpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -1098,7 +1098,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -1144,7 +1145,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -1190,7 +1192,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -1236,7 +1239,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -1266,14 +1270,18 @@ try { if(work && lwork) { - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverSorgbr_bufferSize((rocblas_handle)handle, side, m, n, k, A, lda, tau, &lwork)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -1302,14 +1310,18 @@ try { if(work && lwork) { - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverDorgbr_bufferSize((rocblas_handle)handle, side, m, n, k, A, lda, tau, &lwork)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -1338,14 +1350,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverCungbr_bufferSize((rocblas_handle)handle, side, m, n, k, A, lda, tau, &lwork)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -1380,14 +1396,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverZungbr_bufferSize((rocblas_handle)handle, side, m, n, k, A, lda, tau, &lwork)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -1425,7 +1445,8 @@ try rocsolver_sorgqr((rocblas_handle)handle, m, n, k, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -1457,7 +1478,8 @@ try rocsolver_dorgqr((rocblas_handle)handle, m, n, k, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -1495,7 +1517,8 @@ try rocsolver_cungqr((rocblas_handle)handle, m, n, k, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -1533,7 +1556,8 @@ try rocsolver_zungqr((rocblas_handle)handle, m, n, k, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -1562,14 +1586,18 @@ try { if(work && lwork) { - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverSorgqr_bufferSize((rocblas_handle)handle, m, n, k, A, lda, tau, &lwork)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -1597,14 +1625,18 @@ try { if(work && lwork) { - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverDorgqr_bufferSize((rocblas_handle)handle, m, n, k, A, lda, tau, &lwork)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -1632,14 +1664,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverCungqr_bufferSize((rocblas_handle)handle, m, n, k, A, lda, tau, &lwork)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -1672,14 +1708,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverZungqr_bufferSize((rocblas_handle)handle, m, n, k, A, lda, tau, &lwork)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -1721,7 +1761,8 @@ try (rocblas_handle)handle, hipsolver::hip2rocblas_fill(uplo), n, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -1758,7 +1799,8 @@ try (rocblas_handle)handle, hipsolver::hip2rocblas_fill(uplo), n, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -1795,7 +1837,8 @@ try (rocblas_handle)handle, hipsolver::hip2rocblas_fill(uplo), n, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -1832,7 +1875,8 @@ try (rocblas_handle)handle, hipsolver::hip2rocblas_fill(uplo), n, nullptr, lda, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -1860,14 +1904,18 @@ try { if(work && lwork) { - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverSorgtr_bufferSize((rocblas_handle)handle, uplo, n, A, lda, tau, &lwork)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -1894,14 +1942,18 @@ try { if(work && lwork) { - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverDorgtr_bufferSize((rocblas_handle)handle, uplo, n, A, lda, tau, &lwork)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -1928,14 +1980,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverCungtr_bufferSize((rocblas_handle)handle, uplo, n, A, lda, tau, &lwork)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -1966,14 +2022,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverZungtr_bufferSize((rocblas_handle)handle, uplo, n, A, lda, tau, &lwork)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -2029,7 +2089,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -2081,7 +2142,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -2133,7 +2195,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -2185,7 +2248,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -2218,14 +2282,18 @@ try { if(work && lwork) { - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverSormqr_bufferSize( (rocblas_handle)handle, side, trans, m, n, k, A, lda, tau, C, ldc, &lwork)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -2266,14 +2334,18 @@ try { if(work && lwork) { - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverDormqr_bufferSize( (rocblas_handle)handle, side, trans, m, n, k, A, lda, tau, C, ldc, &lwork)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -2314,14 +2386,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverCunmqr_bufferSize( (rocblas_handle)handle, side, trans, m, n, k, A, lda, tau, C, ldc, &lwork)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -2362,14 +2438,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverZunmqr_bufferSize( (rocblas_handle)handle, side, trans, m, n, k, A, lda, tau, C, ldc, &lwork)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -2430,7 +2510,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -2482,7 +2563,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -2534,7 +2616,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -2586,7 +2669,8 @@ try ldc)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -2619,14 +2703,18 @@ try { if(work && lwork) { - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverSormtr_bufferSize( (rocblas_handle)handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, &lwork)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -2667,14 +2755,18 @@ try { if(work && lwork) { - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverDormtr_bufferSize( (rocblas_handle)handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, &lwork)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -2715,14 +2807,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverCunmtr_bufferSize( (rocblas_handle)handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, &lwork)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -2763,14 +2859,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverZunmtr_bufferSize( (rocblas_handle)handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, &lwork)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -4139,7 +4239,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -4193,7 +4294,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -4247,7 +4349,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -4301,7 +4404,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -4344,14 +4448,18 @@ try work = rwork + std::min(m, n); } - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverSgesvd_bufferSize((rocblas_handle)handle, jobu, jobv, m, n, &lwork)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); if(!rwork && std::min(m, n) > 1) @@ -4412,14 +4520,18 @@ try work = rwork + std::min(m, n); } - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverDgesvd_bufferSize((rocblas_handle)handle, jobu, jobv, m, n, &lwork)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); if(!rwork && std::min(m, n) > 1) @@ -4480,14 +4592,18 @@ try work = (hipFloatComplex*)(rwork + std::min(m, n)); } - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverCgesvd_bufferSize((rocblas_handle)handle, jobu, jobv, m, n, &lwork)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); if(!rwork && std::min(m, n) > 1) @@ -4548,14 +4664,18 @@ try work = (hipDoubleComplex*)(rwork + std::min(m, n)); } - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverZgesvd_bufferSize((rocblas_handle)handle, jobu, jobv, m, n, &lwork)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); if(!rwork && std::min(m, n) > 1) @@ -4642,7 +4762,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -4710,7 +4831,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -4778,7 +4900,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -4846,7 +4969,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -4887,14 +5011,18 @@ try // prepare workspace if(work && lwork) { - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverSgesvdj_bufferSize( (rocblas_handle)handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, info)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -4954,14 +5082,18 @@ try // prepare workspace if(work && lwork) { - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverDgesvdj_bufferSize( (rocblas_handle)handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, info)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -5021,14 +5153,18 @@ try // prepare workspace if(work && lwork) { - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverCgesvdj_bufferSize( (rocblas_handle)handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, info)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -5088,14 +5224,18 @@ try // prepare workspace if(work && lwork) { - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverZgesvdj_bufferSize( (rocblas_handle)handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, info)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -5188,7 +5328,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -5261,7 +5402,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -5334,7 +5476,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -5407,7 +5550,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -5448,7 +5592,9 @@ try // prepare workspace if(work && lwork) { - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else @@ -5467,7 +5613,9 @@ try &lwork, info, batch_count)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -5532,7 +5680,9 @@ try // prepare workspace if(work && lwork) { - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else @@ -5551,7 +5701,9 @@ try &lwork, info, batch_count)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -5616,7 +5768,9 @@ try // prepare workspace if(work && lwork) { - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else @@ -5635,7 +5789,9 @@ try &lwork, info, batch_count)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -5700,7 +5856,9 @@ try // prepare workspace if(work && lwork) { - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else @@ -5719,7 +5877,9 @@ try &lwork, info, batch_count)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -5828,7 +5988,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_nsv, size_ifail); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -5912,7 +6073,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_nsv, size_ifail); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -5996,7 +6158,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_nsv, size_ifail); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -6079,7 +6242,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_nsv, size_ifail); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -6135,7 +6299,9 @@ try if(std::min(m, n) * batch_count > 0) work = (float*)(ifail + std::min(m, n) * batch_count); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else @@ -6158,7 +6324,9 @@ try strideV, &lwork, batch_count)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, @@ -6245,7 +6413,9 @@ try if(std::min(m, n) * batch_count > 0) work = (double*)(ifail + std::min(m, n) * batch_count); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else @@ -6268,7 +6438,9 @@ try strideV, &lwork, batch_count)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, @@ -6355,7 +6527,9 @@ try if(std::min(m, n) * batch_count > 0) work = (hipFloatComplex*)(ifail + std::min(m, n) * batch_count); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else @@ -6378,7 +6552,9 @@ try strideV, &lwork, batch_count)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, @@ -6465,7 +6641,9 @@ try if(std::min(m, n) * batch_count > 0) work = (hipDoubleComplex*)(ifail + std::min(m, n) * batch_count); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else @@ -6488,7 +6666,9 @@ try strideV, &lwork, batch_count)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, @@ -8621,7 +8801,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -8675,7 +8856,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -8729,7 +8911,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -8783,7 +8966,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -8819,14 +9003,18 @@ try if(n > 0) work = E + n; - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverSsyevd_bufferSize((rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(float) * n); @@ -8871,14 +9059,18 @@ try if(n > 0) work = E + n; - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverDsyevd_bufferSize((rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(double) * n); @@ -8923,14 +9115,18 @@ try if(n > 0) work = (hipFloatComplex*)(E + n); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverCheevd_bufferSize((rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(float) * n); @@ -8975,14 +9171,18 @@ try if(n > 0) work = (hipDoubleComplex*)(E + n); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverZheevd_bufferSize((rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(double) * n); @@ -9050,7 +9250,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -9108,7 +9309,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -9166,7 +9368,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -9224,7 +9427,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -9259,14 +9463,18 @@ try { if(work && lwork) { - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverSsyevdx_bufferSize( (rocblas_handle)handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, nev, W, &lwork)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -9312,14 +9520,18 @@ try { if(work && lwork) { - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverDsyevdx_bufferSize( (rocblas_handle)handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, nev, W, &lwork)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -9365,14 +9577,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverCheevdx_bufferSize( (rocblas_handle)handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, nev, W, &lwork)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -9418,14 +9634,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverZheevdx_bufferSize( (rocblas_handle)handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, nev, W, &lwork)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -9491,7 +9711,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -9545,7 +9766,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -9599,7 +9821,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -9653,7 +9876,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -9688,14 +9912,18 @@ try if(work && lwork) { - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverSsyevj_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -9743,14 +9971,18 @@ try if(work && lwork) { - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverDsyevj_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -9798,14 +10030,18 @@ try if(work && lwork) { - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverCheevj_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -9853,14 +10089,18 @@ try if(work && lwork) { - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverZheevj_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -9932,7 +10172,8 @@ try batch_count)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -9990,7 +10231,8 @@ try batch_count)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -10048,7 +10290,8 @@ try batch_count)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -10106,7 +10349,8 @@ try batch_count)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -10142,14 +10386,18 @@ try if(work && lwork) { - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverSsyevjBatched_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info, batch_count)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -10202,14 +10450,18 @@ try if(work && lwork) { - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverDsyevjBatched_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info, batch_count)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -10262,14 +10514,18 @@ try if(work && lwork) { - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverCheevjBatched_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info, batch_count)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -10322,14 +10578,18 @@ try if(work && lwork) { - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverZheevjBatched_bufferSize( (rocblas_handle)handle, jobz, uplo, n, A, lda, W, &lwork, info, batch_count)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -10407,7 +10667,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -10467,7 +10728,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -10527,7 +10789,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -10587,7 +10850,8 @@ try rocblas_set_optimal_device_memory_size((rocblas_handle)handle, sz, size_E); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -10626,14 +10890,18 @@ try if(n > 0) work = E + n; - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverSsygvd_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(float) * n); @@ -10684,14 +10952,18 @@ try if(n > 0) work = E + n; - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverDsygvd_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(double) * n); @@ -10742,14 +11014,18 @@ try if(n > 0) work = (hipFloatComplex*)(E + n); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverChegvd_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(float) * n); @@ -10800,14 +11076,18 @@ try if(n > 0) work = (hipDoubleComplex*)(E + n); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverZhegvd_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); mem = rocblas_device_malloc((rocblas_handle)handle, sizeof(double) * n); @@ -10884,7 +11164,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -10948,7 +11229,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -11012,7 +11294,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -11076,7 +11359,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -11114,7 +11398,9 @@ try { if(work && lwork) { - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else @@ -11136,7 +11422,9 @@ try nev, W, &lwork)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -11188,7 +11476,9 @@ try { if(work && lwork) { - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else @@ -11210,7 +11500,9 @@ try nev, W, &lwork)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -11262,7 +11554,9 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else @@ -11284,7 +11578,9 @@ try nev, W, &lwork)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -11336,7 +11632,9 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else @@ -11358,7 +11656,9 @@ try nev, W, &lwork)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -11432,7 +11732,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -11491,7 +11792,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -11550,7 +11852,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -11609,7 +11912,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -11647,14 +11951,18 @@ try if(work && lwork) { - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverSsygvj_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork, info)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -11707,14 +12015,18 @@ try if(work && lwork) { - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverDsygvj_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork, info)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -11767,14 +12079,18 @@ try if(work && lwork) { - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverChegvj_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork, info)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -11827,14 +12143,18 @@ try if(work && lwork) { - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR(hipsolverZhegvj_bufferSize( (rocblas_handle)handle, itype, jobz, uplo, n, A, lda, B, ldb, W, &lwork, info)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -11896,7 +12216,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -11942,7 +12263,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -11988,7 +12310,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -12034,7 +12357,8 @@ try nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -12064,14 +12388,18 @@ try { if(work && lwork) { - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverSsytrd_bufferSize((rocblas_handle)handle, uplo, n, A, lda, D, E, tau, &lwork)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -12100,14 +12428,18 @@ try { if(work && lwork) { - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverDsytrd_bufferSize((rocblas_handle)handle, uplo, n, A, lda, D, E, tau, &lwork)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -12136,14 +12468,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverChetrd_bufferSize((rocblas_handle)handle, uplo, n, A, lda, D, E, tau, &lwork)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -12178,14 +12514,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverZhetrd_bufferSize((rocblas_handle)handle, uplo, n, A, lda, D, E, tau, &lwork)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -12223,7 +12563,8 @@ try (rocblas_handle)handle, rocblas_fill_upper, n, nullptr, lda, nullptr, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(float); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(float); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -12255,7 +12596,8 @@ try (rocblas_handle)handle, rocblas_fill_upper, n, nullptr, lda, nullptr, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(double); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(double); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -12287,7 +12629,8 @@ try (rocblas_handle)handle, rocblas_fill_upper, n, nullptr, lda, nullptr, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_float_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_float_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -12319,7 +12662,8 @@ try (rocblas_handle)handle, rocblas_fill_upper, n, nullptr, lda, nullptr, nullptr)); rocblas_stop_device_memory_size_query((rocblas_handle)handle, &sz); - sz /= sizeof(rocblas_double_complex); + if(std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") == nullptr) + sz /= sizeof(rocblas_double_complex); if(status != HIPSOLVER_STATUS_SUCCESS) return status; @@ -12347,14 +12691,18 @@ try { if(work && lwork) { - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverSsytrf_bufferSize((rocblas_handle)handle, n, A, lda, &lwork)); - size_t sz = sizeof(float) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(float) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -12379,14 +12727,18 @@ try { if(work && lwork) { - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverDsytrf_bufferSize((rocblas_handle)handle, n, A, lda, &lwork)); - size_t sz = sizeof(double) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(double) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -12411,14 +12763,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverCsytrf_bufferSize((rocblas_handle)handle, n, A, lda, &lwork)); - size_t sz = sizeof(rocblas_float_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_float_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); } @@ -12448,14 +12804,18 @@ try { if(work && lwork) { - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(rocblas_set_workspace((rocblas_handle)handle, work, sz)); } else { CHECK_HIPSOLVER_ERROR( hipsolverZsytrf_bufferSize((rocblas_handle)handle, n, A, lda, &lwork)); - size_t sz = sizeof(rocblas_double_complex) * lwork; + size_t sz = std::getenv("HIPSOLVER_BUFFERSIZE_RETURN_BYTES") != nullptr + ? lwork + : sizeof(rocblas_double_complex) * lwork; CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, sz)); }