diff --git a/examples/aiscp/aiscp.cpp b/examples/aiscp/aiscp.cpp index 6132be06..1ab300d8 100644 --- a/examples/aiscp/aiscp.cpp +++ b/examples/aiscp/aiscp.cpp @@ -10,9 +10,14 @@ * This _very_basic_ program copies SOURCE to DEST via GPU memory. */ +#ifndef AISCP_CHUNK_SIZE +#define AISCP_CHUNK_SIZE 0x7ffff000lu +#endif + #include #include +#include #include #include #include @@ -70,6 +75,16 @@ close_file(const char *path, int fd, hipFileHandle_t handle) return 0; } +/// @brief Round value to the next multiple of align. Align _must_ be a power of 2. +/// @param value The value to round up. +/// @param align Value will be round up to a multiple of align +/// @return Value rounded up to a multiple of align. +static inline size_t +alignUp(size_t value, size_t align) +{ + return (value + align - 1) & ~(align - 1); +} + int main(int argc, char *argv[]) { @@ -80,7 +95,7 @@ main(int argc, char *argv[]) hipError_t hip_err; int exit_status = EXIT_FAILURE; size_t buffer_size, file_size, block_size; - ssize_t nbytes; + ssize_t ncopy{}, nwrite{}, nread{}, nbytes{}; if (argc != 3) { fprintf(stderr, "Usage: %s SOURCE DEST\n", argv[0]); @@ -106,6 +121,7 @@ main(int argc, char *argv[]) } if (0 == file_size) { + exit_status = EXIT_SUCCESS; goto close_dst; } @@ -113,35 +129,38 @@ main(int argc, char *argv[]) goto close_dst; } - buffer_size = file_size; - // If needed, round buffer_size up to the next multiple of block_size - if (buffer_size & (block_size - 1)) { - buffer_size = (buffer_size + block_size) & ~(block_size - 1); - } - hip_err = hipMalloc(&devbuf, buffer_size); + buffer_size = alignUp(std::min(file_size, AISCP_CHUNK_SIZE), block_size); + hip_err = hipMalloc(&devbuf, buffer_size); if (hipSuccess != hip_err) { fprintf(stderr, "Could not allocate device buffer (%d)", hip_err); goto close_src; } - nbytes = hipFileRead(src_handle, devbuf, buffer_size, 0, 0); - if (nbytes < 0 || file_size != static_cast(nbytes)) { - fprintf(stderr, "Could not read from %s (%zd) (%s)\n", src_path, nbytes, - IS_HIPFILE_ERR(nbytes) ? HIPFILE_ERRSTR(nbytes) : strerror(errno)); - goto free_devbuf; - } + while (static_cast(ncopy) < file_size) { + nread = hipFileRead(src_handle, devbuf, buffer_size, ncopy, 0); + if (nread < 0) { + fprintf(stderr, "Could not read from %s (%zd) (%s)\n", src_path, nread, + IS_HIPFILE_ERR(nread) ? HIPFILE_ERRSTR(nread) : strerror(errno)); + goto free_devbuf; + } - nbytes = hipFileWrite(dst_handle, devbuf, buffer_size, 0, 0); - if (nbytes < 0 || buffer_size != static_cast(nbytes)) { - fprintf(stderr, "Could not write to %s (%zd) (%s)\n", src_path, nbytes, - IS_HIPFILE_ERR(nbytes) ? HIPFILE_ERRSTR(nbytes) : strerror(errno)); - goto free_devbuf; + nwrite = 0; + while (nwrite < nread) { + nbytes = + hipFileWrite(dst_handle, devbuf, alignUp(static_cast(nread - nwrite), block_size), + static_cast(ncopy + nwrite), static_cast(nwrite)); + if (nbytes < 0) { + fprintf(stderr, "Could not write to %s (%zd) (%s)\n", src_path, nbytes, + IS_HIPFILE_ERR(nbytes) ? HIPFILE_ERRSTR(nbytes) : strerror(errno)); + goto free_devbuf; + } + nwrite += nbytes; + } + ncopy += nread; } - if (file_size < buffer_size) { - if (-1 == ftruncate(dst_fd, static_cast(file_size))) { - fprintf(stderr, "Could not truncate %s (%zu) (%s)\n", dst_path, file_size, strerror(errno)); - } + if (-1 == ftruncate(dst_fd, static_cast(file_size))) { + fprintf(stderr, "Could not truncate %s (%zu) (%s)\n", dst_path, file_size, strerror(errno)); } exit_status = EXIT_SUCCESS;