Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .github/workflows/check-code-style.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
push:
branches: [main, master]
pull_request:

name: Check Code Style

jobs:
check-code-style:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: R code
uses: rstudio/shiny-workflows/format-r-code@v1
with:
check: true
- name: C++ code
uses: jidicula/clang-format-action@v4.13.0
with:
clang-format-version: '20'
check-path: ./src
exclude-regex: RcppExports
130 changes: 65 additions & 65 deletions src/pjrt_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
namespace rpjrt {

class PJRTElementType {
public:
public:
explicit PJRTElementType(PJRT_Buffer_Type type) : element_type_(type) {}

PJRT_Buffer_Type get_type() const { return element_type_; }
Expand All @@ -18,73 +18,73 @@ class PJRTElementType {

std::string as_string() const {
switch (element_type_) {
case PJRT_Buffer_Type_INVALID:
return "INVALID";
case PJRT_Buffer_Type_PRED:
return "PRED";
case PJRT_Buffer_Type_S8:
return "S8";
case PJRT_Buffer_Type_S16:
return "S16";
case PJRT_Buffer_Type_S32:
return "S32";
case PJRT_Buffer_Type_S64:
return "S64";
case PJRT_Buffer_Type_U8:
return "U8";
case PJRT_Buffer_Type_U16:
return "U16";
case PJRT_Buffer_Type_U32:
return "U32";
case PJRT_Buffer_Type_U64:
return "U64";
case PJRT_Buffer_Type_F16:
return "F16";
case PJRT_Buffer_Type_F32:
return "F32";
case PJRT_Buffer_Type_F64:
return "F64";
case PJRT_Buffer_Type_BF16:
return "BF16";
case PJRT_Buffer_Type_C64:
return "C64";
case PJRT_Buffer_Type_C128:
return "C128";
case PJRT_Buffer_Type_F8E5M2:
return "F8E5M2";
case PJRT_Buffer_Type_F8E4M3FN:
return "F8E4M3FN";
case PJRT_Buffer_Type_F8E4M3B11FNUZ:
return "F8E4M3B11FNUZ";
case PJRT_Buffer_Type_F8E5M2FNUZ:
return "F8E5M2FNUZ";
case PJRT_Buffer_Type_F8E4M3FNUZ:
return "F8E4M3FNUZ";
case PJRT_Buffer_Type_S4:
return "S4";
case PJRT_Buffer_Type_U4:
return "U4";
case PJRT_Buffer_Type_TOKEN:
return "TOKEN";
case PJRT_Buffer_Type_S2:
return "S2";
case PJRT_Buffer_Type_U2:
return "U2";
case PJRT_Buffer_Type_F8E4M3:
return "F8E4M3";
case PJRT_Buffer_Type_F8E3M4:
return "F8E3M4";
case PJRT_Buffer_Type_F8E8M0FNU:
return "F8E8M0FNU";
case PJRT_Buffer_Type_F4E2M1FN:
return "F4E2M1FN";
default:
return "UNKNOWN(" + std::to_string(as_integer()) + ")";
case PJRT_Buffer_Type_INVALID:
return "INVALID";
case PJRT_Buffer_Type_PRED:
return "PRED";
case PJRT_Buffer_Type_S8:
return "S8";
case PJRT_Buffer_Type_S16:
return "S16";
case PJRT_Buffer_Type_S32:
return "S32";
case PJRT_Buffer_Type_S64:
return "S64";
case PJRT_Buffer_Type_U8:
return "U8";
case PJRT_Buffer_Type_U16:
return "U16";
case PJRT_Buffer_Type_U32:
return "U32";
case PJRT_Buffer_Type_U64:
return "U64";
case PJRT_Buffer_Type_F16:
return "F16";
case PJRT_Buffer_Type_F32:
return "F32";
case PJRT_Buffer_Type_F64:
return "F64";
case PJRT_Buffer_Type_BF16:
return "BF16";
case PJRT_Buffer_Type_C64:
return "C64";
case PJRT_Buffer_Type_C128:
return "C128";
case PJRT_Buffer_Type_F8E5M2:
return "F8E5M2";
case PJRT_Buffer_Type_F8E4M3FN:
return "F8E4M3FN";
case PJRT_Buffer_Type_F8E4M3B11FNUZ:
return "F8E4M3B11FNUZ";
case PJRT_Buffer_Type_F8E5M2FNUZ:
return "F8E5M2FNUZ";
case PJRT_Buffer_Type_F8E4M3FNUZ:
return "F8E4M3FNUZ";
case PJRT_Buffer_Type_S4:
return "S4";
case PJRT_Buffer_Type_U4:
return "U4";
case PJRT_Buffer_Type_TOKEN:
return "TOKEN";
case PJRT_Buffer_Type_S2:
return "S2";
case PJRT_Buffer_Type_U2:
return "U2";
case PJRT_Buffer_Type_F8E4M3:
return "F8E4M3";
case PJRT_Buffer_Type_F8E3M4:
return "F8E3M4";
case PJRT_Buffer_Type_F8E8M0FNU:
return "F8E8M0FNU";
case PJRT_Buffer_Type_F4E2M1FN:
return "F4E2M1FN";
default:
return "UNKNOWN(" + std::to_string(as_integer()) + ")";
}
}

private:
private:
PJRT_Buffer_Type element_type_;
};

} // namespace rpjrt
} // namespace rpjrt
61 changes: 30 additions & 31 deletions src/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,21 @@ PJRTProgramFormat PJRTProgram::format() const {
}
}

std::string
PJRTProgram::load_program_from_file(const std::string &fname,
const PJRTProgramFormat &format) {
std::string PJRTProgram::load_program_from_file(
const std::string &fname, const PJRTProgramFormat &format) {
std::ifstream input(fname, std::ios::binary | std::ios::ate);
std::streamsize size = input.tellg();
input.seekg(0, std::ios::beg); // rewind
input.seekg(0, std::ios::beg); // rewind
std::vector<char> buffer(size);
input.read(buffer.data(), size);

switch (format) {
case HLO:
return parse_hlo_program(buffer);
break;
case MLIR:
return std::string(buffer.data(), buffer.size());
break;
case HLO:
return parse_hlo_program(buffer);
break;
case MLIR:
return std::string(buffer.data(), buffer.size());
break;
};

throw std::runtime_error("Unknown program format");
Expand All @@ -52,16 +51,16 @@ PJRT_Program PJRTProgram::create_program(std::string &code,
program.code_size = code.size();

switch (format) {
case HLO:
program.format = "hlo";
program.format_size = strlen("hlo");
break;
case MLIR:
program.format = "mlir";
program.format_size = strlen("mlir");
break;
default:
throw std::runtime_error("Unknown program format");
case HLO:
program.format = "hlo";
program.format_size = strlen("hlo");
break;
case MLIR:
program.format = "mlir";
program.format_size = strlen("mlir");
break;
default:
throw std::runtime_error("Unknown program format");
}

return program;
Expand All @@ -80,14 +79,14 @@ std::string PJRTProgram::repr(int n) const {
auto format = this->format();
std::string debug("");
switch (format) {
case HLO: {
xla::HloModuleProto hlo_proto{};
hlo_proto.ParseFromArray(this->code.data(), this->code.size());
debug = hlo_proto.DebugString();
} break;
case MLIR:
debug = this->code;
break;
case HLO: {
xla::HloModuleProto hlo_proto{};
hlo_proto.ParseFromArray(this->code.data(), this->code.size());
debug = hlo_proto.DebugString();
} break;
case MLIR:
debug = this->code;
break;
}

// debug must not be larger than n lines
Expand All @@ -99,7 +98,7 @@ std::string PJRTProgram::repr(int n) const {
pos = debug.find('\n', pos);
if (pos != std::string::npos) {
lines_found++;
pos++; // Move past the newline
pos++; // Move past the newline
}
}

Expand All @@ -108,11 +107,11 @@ std::string PJRTProgram::repr(int n) const {
// Check if there's more content after the nth line
if (pos < debug.length()) {
debug = debug.substr(0, pos - 1) +
"\n..."; // pos-1 to include the nth newline
"\n..."; // pos-1 to include the nth newline
}
}

return repr + "\n" + debug;
}

} // namespace rpjrt
} // namespace rpjrt
6 changes: 3 additions & 3 deletions src/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@ namespace rpjrt {
enum PJRTProgramFormat { HLO, MLIR };

class PJRTProgram {
public:
public:
PJRTProgram(const std::string &fname, const PJRTProgramFormat &format);
PJRTProgramFormat format() const;
std::string repr(int n) const;
std::string code;
PJRT_Program program;

private:
private:
static std::string load_program_from_file(const std::string &fname,
const PJRTProgramFormat &format);
static std::string parse_hlo_program(const std::vector<char> &buffer);
static PJRT_Program create_program(std::string &code,
const PJRTProgramFormat &format);
};
} // namespace rpjrt
} // namespace rpjrt
Loading