Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ export(device)
export(devices)
export(dtype)
export(elt_type)
export(format_array)
export(format_buffer)
export(is_ready)
export(pjrt_buffer)
Expand Down
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ impl_buffer_print <- function(buffer, max_rows, max_width, max_rows_slice) {
invisible(.Call(`_pjrt_impl_buffer_print`, buffer, max_rows, max_width, max_rows_slice))
}

impl_format_array <- function(data, max_rows, max_width, max_rows_slice) {
.Call(`_pjrt_impl_format_array`, data, max_rows, max_width, max_rows_slice)
}

impl_buffer_is_ready <- function(buffer) {
.Call(`_pjrt_impl_buffer_is_ready`, buffer)
}
Expand Down
24 changes: 24 additions & 0 deletions R/format.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,30 @@
#' buf <- pjrt_buffer(c(1.5, 2.5, 3.5))
#' format_buffer(buf)
#' @export
#' @title Format Array Lines
#'
#' @description
#' Formats an R array (or vector/scalar) into display lines using the same
#' printer as [print.PJRTBuffer()].
#'
#' @param data An R numeric, integer, or logical array/vector/scalar.
#' @param max_rows Maximum total rows to print (`-1` for unlimited).
#' @param max_width Maximum line width in characters.
#' @param max_rows_slice Maximum rows per 2D slice.
#'
#' @return `character()` Vector of formatted lines.
#' @examples
#' format_array(matrix(1:6, nrow = 2))
#' @export
format_array <- function(
data,
max_rows = getOption("pjrt.print_max_rows", 30L),
max_width = getOption("pjrt.print_max_width", 85L),
max_rows_slice = getOption("pjrt.print_max_rows_slice", max_rows)
) {
impl_format_array(data, max_rows, max_width, max_rows_slice)
}

format_buffer <- function(buffer) {
if (!is_buffer(buffer)) {
cli_abort("`buffer` must be a `PJRTBuffer`")
Expand Down
15 changes: 15 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,20 @@ BEGIN_RCPP
return R_NilValue;
END_RCPP
}
// impl_format_array
Rcpp::CharacterVector impl_format_array(SEXP data, int max_rows, int max_width, int max_rows_slice);
RcppExport SEXP _pjrt_impl_format_array(SEXP dataSEXP, SEXP max_rowsSEXP, SEXP max_widthSEXP, SEXP max_rows_sliceSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< SEXP >::type data(dataSEXP);
Rcpp::traits::input_parameter< int >::type max_rows(max_rowsSEXP);
Rcpp::traits::input_parameter< int >::type max_width(max_widthSEXP);
Rcpp::traits::input_parameter< int >::type max_rows_slice(max_rows_sliceSEXP);
rcpp_result_gen = Rcpp::wrap(impl_format_array(data, max_rows, max_width, max_rows_slice));
return rcpp_result_gen;
END_RCPP
}
// impl_buffer_is_ready
bool impl_buffer_is_ready(Rcpp::XPtr<rpjrt::PJRTBuffer> buffer);
RcppExport SEXP _pjrt_impl_buffer_is_ready(SEXP bufferSEXP) {
Expand Down Expand Up @@ -527,6 +541,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_pjrt_impl_device_to_string", (DL_FUNC) &_pjrt_impl_device_to_string, 1},
{"_pjrt_impl_device_platform", (DL_FUNC) &_pjrt_impl_device_platform, 1},
{"_pjrt_impl_buffer_print", (DL_FUNC) &_pjrt_impl_buffer_print, 4},
{"_pjrt_impl_format_array", (DL_FUNC) &_pjrt_impl_format_array, 4},
{"_pjrt_impl_buffer_is_ready", (DL_FUNC) &_pjrt_impl_buffer_is_ready, 1},
{"_pjrt_impl_buffer_await", (DL_FUNC) &_pjrt_impl_buffer_await, 1},
{"_pjrt_impl_host_data_is_ready", (DL_FUNC) &_pjrt_impl_host_data_is_ready, 1},
Expand Down
65 changes: 44 additions & 21 deletions src/buffer_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,23 +204,45 @@ static std::pair<int64_t, int64_t> build_buffer_lines_subset(
return {r, c_end};
}

// Create a contiguous span over the last two dimensions for a given leading
// index
// Extract a 2D slice (last two dimensions) into row-major order.
// For row-major input, the slice is contiguous and simply copied.
// For column-major input, elements are gathered using column-major strides.
template <typename T>
static std::span<const T> make_last2_contiguous_span(
static std::vector<T> extract_last2_slice(
const std::span<const T> &flat, const std::vector<int64_t> &pseudo_dims,
const std::vector<int64_t> &lead_index) {
const std::vector<int64_t> &lead_index, bool row_major) {
const int nprint = static_cast<int>(pseudo_dims.size());
std::vector<int64_t> stride(nprint, 1);
for (int i = nprint - 2; i >= 0; --i)
stride[i] = stride[i + 1] * pseudo_dims[i + 1];
int64_t base = 0;
for (size_t k = 0; k < lead_index.size(); ++k)
base += lead_index[k] * stride[k];
const int64_t nrows = pseudo_dims[nprint - 2];
const int64_t ncols = pseudo_dims[nprint - 1];
return std::span<const T>(flat.data() + static_cast<size_t>(base),
static_cast<size_t>(nrows * ncols));
const size_t slice_size = static_cast<size_t>(nrows * ncols);

if (row_major) {
std::vector<int64_t> stride(nprint, 1);
for (int i = nprint - 2; i >= 0; --i)
stride[i] = stride[i + 1] * pseudo_dims[i + 1];
int64_t base = 0;
for (size_t k = 0; k < lead_index.size(); ++k)
base += lead_index[k] * stride[k];
const T *start = flat.data() + static_cast<size_t>(base);
return std::vector<T>(start, start + slice_size);
}

// Column-major: gather elements using column-major strides
std::vector<int64_t> strides = dims2strides(pseudo_dims, false);
int64_t base = 0;
for (size_t k = 0; k < lead_index.size(); ++k)
base += lead_index[k] * strides[k];
int64_t row_stride = strides[nprint - 2];
int64_t col_stride = strides[nprint - 1];

std::vector<T> result(slice_size);
for (int64_t r = 0; r < nrows; ++r) {
for (int64_t c = 0; c < ncols; ++c) {
result[static_cast<size_t>(r * ncols + c)] =
flat[static_cast<size_t>(base + r * row_stride + c * col_stride)];
}
}
return result;
}

// core printer
Expand All @@ -231,7 +253,8 @@ static void print_with_formatter_fn(const std::vector<int64_t> &dimensions,
int max_width, int max_rows_slice,
int rows_left,
std::vector<std::string> &cont,
std::span<const CopyT> temp_vec) {
std::span<const CopyT> temp_vec,
bool row_major) {
const int ndim = dimensions.size();

// pseudo_dims are used so we don't have to treat the 0d and 1d cases special
Expand Down Expand Up @@ -281,10 +304,10 @@ static void print_with_formatter_fn(const std::vector<int64_t> &dimensions,
cont.push_back(hdr.str());
}

// Extract this slice as a span (because of row-major ordering,
// this data is contigous)
std::span<const CopyT> slice =
make_last2_contiguous_span<CopyT>(temp_vec, pseudo_dims, lead_index);
// Extract this slice as a row-major ordered vector
std::vector<CopyT> slice_vec =
extract_last2_slice<CopyT>(temp_vec, pseudo_dims, lead_index, row_major);
std::span<const CopyT> slice(slice_vec);

// now do the actual data printing
int64_t rows_to_print = nrows;
Expand Down Expand Up @@ -407,7 +430,7 @@ static void print_with_formatter_fn(const std::vector<int64_t> &dimensions,
std::vector<std::string> buffer_to_string_lines(
const void *data, const std::vector<int64_t> &dimensions,
PJRT_Buffer_Type element_type, int max_rows, int max_width,
int max_rows_slice) {
int max_rows_slice, bool row_major) {
int64_t numel = dimensions.empty() ? 1 : number_of_elements(dimensions);

if (numel == 0) {
Expand All @@ -425,23 +448,23 @@ std::vector<std::string> buffer_to_string_lines(
std::span<const FP> temp_span(static_cast<const FP *>(data),
static_cast<size_t>(numel));
print_with_formatter_fn(dimensions, max_width, max_rows_slice, rows_left,
cont, temp_span);
cont, temp_span, row_major);
};

auto handle_integer = [&](auto int_tag) {
using IT = decltype(int_tag);
std::span<const IT> temp_span(static_cast<const IT *>(data),
static_cast<size_t>(numel));
print_with_formatter_fn(dimensions, max_width, max_rows_slice, rows_left,
cont, temp_span);
cont, temp_span, row_major);
};

auto handle_logical = [&]() {
using BT = uint8_t;
std::span<const BT> temp_span(static_cast<const BT *>(data),
static_cast<size_t>(numel));
print_with_formatter_fn(dimensions, max_width, max_rows_slice, rows_left,
cont, temp_span);
cont, temp_span, row_major);
};

switch (element_type) {
Expand Down
2 changes: 1 addition & 1 deletion src/buffer_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ void buffer_print(Rcpp::XPtr<rpjrt::PJRTBuffer> buffer, int n = 30,
std::vector<std::string> buffer_to_string_lines(
const void* data, const std::vector<int64_t>& dimensions,
PJRT_Buffer_Type element_type, int max_rows = 30, int max_width = 85,
int max_rows_slice = 30);
int max_rows_slice = 30, bool row_major = true);
47 changes: 47 additions & 0 deletions src/pjrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,53 @@ void impl_buffer_print(Rcpp::XPtr<rpjrt::PJRTBuffer> buffer, int max_rows,
buffer_print(buffer, max_rows, max_width, max_rows_slice);
}

// [[Rcpp::export()]]
Rcpp::CharacterVector impl_format_array(SEXP data, int max_rows, int max_width,
int max_rows_slice) {
PJRT_Buffer_Type element_type;
const void *data_ptr;
std::vector<uint8_t> logical_data;

if (Rf_isReal(data)) {
element_type = PJRT_Buffer_Type_F64;
data_ptr = REAL(data);
} else if (Rf_isInteger(data)) {
element_type = PJRT_Buffer_Type_S32;
data_ptr = INTEGER(data);
} else if (Rf_isLogical(data)) {
int n = Rf_length(data);
logical_data.resize(n);
int *lgl = LOGICAL(data);
for (int i = 0; i < n; i++) {
logical_data[i] = static_cast<uint8_t>(lgl[i] != 0);
}
element_type = PJRT_Buffer_Type_PRED;
data_ptr = logical_data.data();
} else {
Rcpp::stop("Unsupported R type for formatting.");
}

SEXP dim_attr = Rf_getAttrib(data, R_DimSymbol);
std::vector<int64_t> dimensions;
if (!Rf_isNull(dim_attr)) {
int ndim = Rf_length(dim_attr);
int *dim_ptr = INTEGER(dim_attr);
for (int i = 0; i < ndim; i++) {
dimensions.push_back(static_cast<int64_t>(dim_ptr[i]));
}
} else {
int n = Rf_length(data);
if (n != 1) {
dimensions.push_back(static_cast<int64_t>(n));
}
}

auto lines = buffer_to_string_lines(data_ptr, dimensions, element_type,
max_rows, max_width, max_rows_slice,
/*row_major=*/false);
return Rcpp::wrap(lines);
}

// Async status functions for buffers and host data

// [[Rcpp::export()]]
Expand Down
Loading