// Patch overview (review annotations; this text sits ahead of the first diff header and is not part of any hunk).
// - Adds an exported R function format_array: the NAMESPACE export, the R wrapper in R/format.R,
//   the Rcpp glue for impl_format_array (R/RcppExports.R, src/RcppExports.cpp, CallEntries), and the
//   C++ entry point in src/pjrt.cpp.
// - src/buffer_printer.cpp: make_last2_contiguous_span (returned a span into the flat data) is replaced by
//   extract_last2_slice, which returns a copied vector holding the last-two-dims slice indexed as
//   r * ncols + c. With row_major the slice is copied from one contiguous range; otherwise elements are
//   gathered using the strides returned by dims2strides(pseudo_dims, false).
// - print_with_formatter_fn and buffer_to_string_lines gain a row_major parameter; the header declares it
//   with default true, and impl_format_array passes false.
// - impl_format_array maps REAL input to F64, INTEGER input to S32, and logical input to PRED through a
//   temporary vector filled with (lgl[i] != 0); any other SEXP type stops with an error. Dimensions come
//   from the dim attribute when present, otherwise from the length (a length-1 input without dims gets
//   an empty dimension list).
// NOTE(review): in the R/format.R hunk the new roxygen block and the format_array definition are inserted
//   after the existing @export line that follows the format_buffer example and before the format_buffer
//   definition -- confirm this placement is intended.
// NOTE(review): impl_format_array has no separate NA handling for logical or integer input
//   -- presumably NA should be shown as NA; TODO confirm against the printer.
// NOTE(review): lengths are read with Rf_length into int -- presumably long vectors are out of scope; verify.
// NOTE(review): this copy looks like it lost its line breaks and the contents of angle brackets
//   (for example: std::vector extract_last2_slice, static_cast(base)) -- restore from the original patch
//   before applying.
diff --git a/NAMESPACE b/NAMESPACE index 8692482c..ed4d83e7 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -56,6 +56,7 @@ export(device) export(devices) export(dtype) export(elt_type) +export(format_array) export(format_buffer) export(is_ready) export(pjrt_buffer) diff --git a/R/RcppExports.R b/R/RcppExports.R index 8389edf7..0fad15ac 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -121,6 +121,10 @@ impl_buffer_print <- function(buffer, max_rows, max_width, max_rows_slice) { invisible(.Call(`_pjrt_impl_buffer_print`, buffer, max_rows, max_width, max_rows_slice)) } +impl_format_array <- function(data, max_rows, max_width, max_rows_slice) { + .Call(`_pjrt_impl_format_array`, data, max_rows, max_width, max_rows_slice) +} + impl_buffer_is_ready <- function(buffer) { .Call(`_pjrt_impl_buffer_is_ready`, buffer) } diff --git a/R/format.R b/R/format.R index 673d2e0d..5f418fe5 100644 --- a/R/format.R +++ b/R/format.R @@ -10,6 +10,30 @@ #' buf <- pjrt_buffer(c(1.5, 2.5, 3.5)) #' format_buffer(buf) #' @export +#' @title Format Array Lines +#' +#' @description +#' Formats an R array (or vector/scalar) into display lines using the same +#' printer as [print.PJRTBuffer()]. +#' +#' @param data An R numeric, integer, or logical array/vector/scalar. +#' @param max_rows Maximum total rows to print (`-1` for unlimited). +#' @param max_width Maximum line width in characters. +#' @param max_rows_slice Maximum rows per 2D slice. +#' +#' @return `character()` Vector of formatted lines. 
+#' @examples +#' format_array(matrix(1:6, nrow = 2)) +#' @export +format_array <- function( + data, + max_rows = getOption("pjrt.print_max_rows", 30L), + max_width = getOption("pjrt.print_max_width", 85L), + max_rows_slice = getOption("pjrt.print_max_rows_slice", max_rows) +) { + impl_format_array(data, max_rows, max_width, max_rows_slice) +} + format_buffer <- function(buffer) { if (!is_buffer(buffer)) { cli_abort("`buffer` must be a `PJRTBuffer`") diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 9d98cd87..91c51f9b 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -362,6 +362,20 @@ BEGIN_RCPP return R_NilValue; END_RCPP } +// impl_format_array +Rcpp::CharacterVector impl_format_array(SEXP data, int max_rows, int max_width, int max_rows_slice); +RcppExport SEXP _pjrt_impl_format_array(SEXP dataSEXP, SEXP max_rowsSEXP, SEXP max_widthSEXP, SEXP max_rows_sliceSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< SEXP >::type data(dataSEXP); + Rcpp::traits::input_parameter< int >::type max_rows(max_rowsSEXP); + Rcpp::traits::input_parameter< int >::type max_width(max_widthSEXP); + Rcpp::traits::input_parameter< int >::type max_rows_slice(max_rows_sliceSEXP); + rcpp_result_gen = Rcpp::wrap(impl_format_array(data, max_rows, max_width, max_rows_slice)); + return rcpp_result_gen; +END_RCPP +} // impl_buffer_is_ready bool impl_buffer_is_ready(Rcpp::XPtr buffer); RcppExport SEXP _pjrt_impl_buffer_is_ready(SEXP bufferSEXP) { @@ -527,6 +541,7 @@ static const R_CallMethodDef CallEntries[] = { {"_pjrt_impl_device_to_string", (DL_FUNC) &_pjrt_impl_device_to_string, 1}, {"_pjrt_impl_device_platform", (DL_FUNC) &_pjrt_impl_device_platform, 1}, {"_pjrt_impl_buffer_print", (DL_FUNC) &_pjrt_impl_buffer_print, 4}, + {"_pjrt_impl_format_array", (DL_FUNC) &_pjrt_impl_format_array, 4}, {"_pjrt_impl_buffer_is_ready", (DL_FUNC) &_pjrt_impl_buffer_is_ready, 1}, {"_pjrt_impl_buffer_await", (DL_FUNC) 
&_pjrt_impl_buffer_await, 1}, {"_pjrt_impl_host_data_is_ready", (DL_FUNC) &_pjrt_impl_host_data_is_ready, 1}, diff --git a/src/buffer_printer.cpp b/src/buffer_printer.cpp index 55e967d7..868ec60c 100644 --- a/src/buffer_printer.cpp +++ b/src/buffer_printer.cpp @@ -204,23 +204,45 @@ static std::pair build_buffer_lines_subset( return {r, c_end}; } -// Create a contiguous span over the last two dimensions for a given leading -// index +// Extract a 2D slice (last two dimensions) into row-major order. +// For row-major input, the slice is contiguous and simply copied. +// For column-major input, elements are gathered using column-major strides. template -static std::span make_last2_contiguous_span( +static std::vector extract_last2_slice( const std::span &flat, const std::vector &pseudo_dims, - const std::vector &lead_index) { + const std::vector &lead_index, bool row_major) { const int nprint = static_cast(pseudo_dims.size()); - std::vector stride(nprint, 1); - for (int i = nprint - 2; i >= 0; --i) - stride[i] = stride[i + 1] * pseudo_dims[i + 1]; - int64_t base = 0; - for (size_t k = 0; k < lead_index.size(); ++k) - base += lead_index[k] * stride[k]; const int64_t nrows = pseudo_dims[nprint - 2]; const int64_t ncols = pseudo_dims[nprint - 1]; - return std::span(flat.data() + static_cast(base), - static_cast(nrows * ncols)); + const size_t slice_size = static_cast(nrows * ncols); + + if (row_major) { + std::vector stride(nprint, 1); + for (int i = nprint - 2; i >= 0; --i) + stride[i] = stride[i + 1] * pseudo_dims[i + 1]; + int64_t base = 0; + for (size_t k = 0; k < lead_index.size(); ++k) + base += lead_index[k] * stride[k]; + const T *start = flat.data() + static_cast(base); + return std::vector(start, start + slice_size); + } + + // Column-major: gather elements using column-major strides + std::vector strides = dims2strides(pseudo_dims, false); + int64_t base = 0; + for (size_t k = 0; k < lead_index.size(); ++k) + base += lead_index[k] * strides[k]; + int64_t 
row_stride = strides[nprint - 2]; + int64_t col_stride = strides[nprint - 1]; + + std::vector result(slice_size); + for (int64_t r = 0; r < nrows; ++r) { + for (int64_t c = 0; c < ncols; ++c) { + result[static_cast(r * ncols + c)] = + flat[static_cast(base + r * row_stride + c * col_stride)]; + } + } + return result; } // core printer @@ -231,7 +253,8 @@ static void print_with_formatter_fn(const std::vector &dimensions, int max_width, int max_rows_slice, int rows_left, std::vector &cont, - std::span temp_vec) { + std::span temp_vec, + bool row_major) { const int ndim = dimensions.size(); // pseudo_dims are used so we don't have to treat the 0d and 1d cases special @@ -281,10 +304,10 @@ static void print_with_formatter_fn(const std::vector &dimensions, cont.push_back(hdr.str()); } - // Extract this slice as a span (because of row-major ordering, - // this data is contigous) - std::span slice = - make_last2_contiguous_span(temp_vec, pseudo_dims, lead_index); + // Extract this slice as a row-major ordered vector + std::vector slice_vec = + extract_last2_slice(temp_vec, pseudo_dims, lead_index, row_major); + std::span slice(slice_vec); // now do the actual data printing int64_t rows_to_print = nrows; @@ -407,7 +430,7 @@ static void print_with_formatter_fn(const std::vector &dimensions, std::vector buffer_to_string_lines( const void *data, const std::vector &dimensions, PJRT_Buffer_Type element_type, int max_rows, int max_width, - int max_rows_slice) { + int max_rows_slice, bool row_major) { int64_t numel = dimensions.empty() ? 
1 : number_of_elements(dimensions); if (numel == 0) { @@ -425,7 +448,7 @@ std::vector buffer_to_string_lines( std::span temp_span(static_cast(data), static_cast(numel)); print_with_formatter_fn(dimensions, max_width, max_rows_slice, rows_left, - cont, temp_span); + cont, temp_span, row_major); }; auto handle_integer = [&](auto int_tag) { @@ -433,7 +456,7 @@ std::vector buffer_to_string_lines( std::span temp_span(static_cast(data), static_cast(numel)); print_with_formatter_fn(dimensions, max_width, max_rows_slice, rows_left, - cont, temp_span); + cont, temp_span, row_major); }; auto handle_logical = [&]() { @@ -441,7 +464,7 @@ std::vector buffer_to_string_lines( std::span temp_span(static_cast(data), static_cast(numel)); print_with_formatter_fn(dimensions, max_width, max_rows_slice, rows_left, - cont, temp_span); + cont, temp_span, row_major); }; switch (element_type) { diff --git a/src/buffer_printer.h b/src/buffer_printer.h index 232ba948..7d95f1c2 100644 --- a/src/buffer_printer.h +++ b/src/buffer_printer.h @@ -13,4 +13,4 @@ void buffer_print(Rcpp::XPtr buffer, int n = 30, std::vector buffer_to_string_lines( const void* data, const std::vector& dimensions, PJRT_Buffer_Type element_type, int max_rows = 30, int max_width = 85, - int max_rows_slice = 30); + int max_rows_slice = 30, bool row_major = true); diff --git a/src/pjrt.cpp b/src/pjrt.cpp index f77ed9de..06ad0c61 100644 --- a/src/pjrt.cpp +++ b/src/pjrt.cpp @@ -581,6 +581,53 @@ void impl_buffer_print(Rcpp::XPtr buffer, int max_rows, buffer_print(buffer, max_rows, max_width, max_rows_slice); } +// [[Rcpp::export()]] +Rcpp::CharacterVector impl_format_array(SEXP data, int max_rows, int max_width, + int max_rows_slice) { + PJRT_Buffer_Type element_type; + const void *data_ptr; + std::vector logical_data; + + if (Rf_isReal(data)) { + element_type = PJRT_Buffer_Type_F64; + data_ptr = REAL(data); + } else if (Rf_isInteger(data)) { + element_type = PJRT_Buffer_Type_S32; + data_ptr = INTEGER(data); + } else if 
(Rf_isLogical(data)) { + int n = Rf_length(data); + logical_data.resize(n); + int *lgl = LOGICAL(data); + for (int i = 0; i < n; i++) { + logical_data[i] = static_cast(lgl[i] != 0); + } + element_type = PJRT_Buffer_Type_PRED; + data_ptr = logical_data.data(); + } else { + Rcpp::stop("Unsupported R type for formatting."); + } + + SEXP dim_attr = Rf_getAttrib(data, R_DimSymbol); + std::vector dimensions; + if (!Rf_isNull(dim_attr)) { + int ndim = Rf_length(dim_attr); + int *dim_ptr = INTEGER(dim_attr); + for (int i = 0; i < ndim; i++) { + dimensions.push_back(static_cast(dim_ptr[i])); + } + } else { + int n = Rf_length(data); + if (n != 1) { + dimensions.push_back(static_cast(n)); + } + } + + auto lines = buffer_to_string_lines(data_ptr, dimensions, element_type, + max_rows, max_width, max_rows_slice, + /*row_major=*/false); + return Rcpp::wrap(lines); +} + // Async status functions for buffers and host data // [[Rcpp::export()]]