diff --git a/.github/workflows/check-code-style.yaml b/.github/workflows/check-code-style.yaml new file mode 100644 index 00000000..ae99dfb4 --- /dev/null +++ b/.github/workflows/check-code-style.yaml @@ -0,0 +1,25 @@ +# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples +# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help +on: + push: + branches: [main, master] + pull_request: + +name: Check Code Style + +jobs: + check-code-style: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: R code + uses: rstudio/shiny-workflows/format-r-code@v1 + with: + check: true + - name: C++ code + uses: jidicula/clang-format-action@v4.13.0 + with: + clang-format-version: '20' + check-path: ./src + exclude-regex: RcppExports diff --git a/src/pjrt_types.h b/src/pjrt_types.h index 7f26df09..bdc49e64 100644 --- a/src/pjrt_types.h +++ b/src/pjrt_types.h @@ -9,7 +9,7 @@ namespace rpjrt { class PJRTElementType { -public: + public: explicit PJRTElementType(PJRT_Buffer_Type type) : element_type_(type) {} PJRT_Buffer_Type get_type() const { return element_type_; } @@ -18,73 +18,73 @@ class PJRTElementType { std::string as_string() const { switch (element_type_) { - case PJRT_Buffer_Type_INVALID: - return "INVALID"; - case PJRT_Buffer_Type_PRED: - return "PRED"; - case PJRT_Buffer_Type_S8: - return "S8"; - case PJRT_Buffer_Type_S16: - return "S16"; - case PJRT_Buffer_Type_S32: - return "S32"; - case PJRT_Buffer_Type_S64: - return "S64"; - case PJRT_Buffer_Type_U8: - return "U8"; - case PJRT_Buffer_Type_U16: - return "U16"; - case PJRT_Buffer_Type_U32: - return "U32"; - case PJRT_Buffer_Type_U64: - return "U64"; - case PJRT_Buffer_Type_F16: - return "F16"; - case PJRT_Buffer_Type_F32: - return "F32"; - case PJRT_Buffer_Type_F64: - return "F64"; - case PJRT_Buffer_Type_BF16: - return "BF16"; - case PJRT_Buffer_Type_C64: - return "C64"; - case PJRT_Buffer_Type_C128: - return "C128"; - case PJRT_Buffer_Type_F8E5M2: - return "F8E5M2"; - case PJRT_Buffer_Type_F8E4M3FN: - return "F8E4M3FN"; - case PJRT_Buffer_Type_F8E4M3B11FNUZ: - return "F8E4M3B11FNUZ"; - case PJRT_Buffer_Type_F8E5M2FNUZ: - return "F8E5M2FNUZ"; - case PJRT_Buffer_Type_F8E4M3FNUZ: - return "F8E4M3FNUZ"; - case PJRT_Buffer_Type_S4: - return "S4"; - case PJRT_Buffer_Type_U4: - return "U4"; - case PJRT_Buffer_Type_TOKEN: - return "TOKEN"; - case PJRT_Buffer_Type_S2: - return "S2"; - case PJRT_Buffer_Type_U2: - return "U2"; - case PJRT_Buffer_Type_F8E4M3: - return "F8E4M3"; - case PJRT_Buffer_Type_F8E3M4: - return "F8E3M4"; - case PJRT_Buffer_Type_F8E8M0FNU: - return "F8E8M0FNU"; - case PJRT_Buffer_Type_F4E2M1FN: - return "F4E2M1FN"; - default: - return "UNKNOWN(" + std::to_string(as_integer()) + ")"; + case PJRT_Buffer_Type_INVALID: + return "INVALID"; + case PJRT_Buffer_Type_PRED: + return "PRED"; + case PJRT_Buffer_Type_S8: + return "S8"; + case PJRT_Buffer_Type_S16: + return "S16"; + case PJRT_Buffer_Type_S32: + return "S32"; + case PJRT_Buffer_Type_S64: + return "S64"; + case PJRT_Buffer_Type_U8: + return "U8"; + case PJRT_Buffer_Type_U16: + return "U16"; + case PJRT_Buffer_Type_U32: + return "U32"; + case PJRT_Buffer_Type_U64: + return "U64"; + case PJRT_Buffer_Type_F16: + return "F16"; + case PJRT_Buffer_Type_F32: + return "F32"; + case PJRT_Buffer_Type_F64: + return "F64"; + case PJRT_Buffer_Type_BF16: + return "BF16"; + case PJRT_Buffer_Type_C64: + return "C64"; + case PJRT_Buffer_Type_C128: + return "C128"; + case PJRT_Buffer_Type_F8E5M2: + return "F8E5M2"; + case PJRT_Buffer_Type_F8E4M3FN: + return "F8E4M3FN"; + case PJRT_Buffer_Type_F8E4M3B11FNUZ: + return "F8E4M3B11FNUZ"; + case PJRT_Buffer_Type_F8E5M2FNUZ: + return "F8E5M2FNUZ"; + case PJRT_Buffer_Type_F8E4M3FNUZ: + return "F8E4M3FNUZ"; + case PJRT_Buffer_Type_S4: + return "S4"; + case PJRT_Buffer_Type_U4: + return "U4"; + case PJRT_Buffer_Type_TOKEN: + return "TOKEN"; + case PJRT_Buffer_Type_S2: + return "S2"; + case PJRT_Buffer_Type_U2: + return "U2"; + case PJRT_Buffer_Type_F8E4M3: + return "F8E4M3"; + case PJRT_Buffer_Type_F8E3M4: + return "F8E3M4"; + case PJRT_Buffer_Type_F8E8M0FNU: + return "F8E8M0FNU"; + case PJRT_Buffer_Type_F4E2M1FN: + return "F4E2M1FN"; + default: + return "UNKNOWN(" + std::to_string(as_integer()) + ")"; } } -private: + private: PJRT_Buffer_Type element_type_; }; -} // namespace rpjrt +} // namespace rpjrt diff --git a/src/program.cpp b/src/program.cpp index 919b057d..e3901ee6 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -22,22 +22,21 @@ PJRTProgramFormat PJRTProgram::format() const { } } -std::string -PJRTProgram::load_program_from_file(const std::string &fname, - const PJRTProgramFormat &format) { +std::string PJRTProgram::load_program_from_file( + const std::string &fname, const PJRTProgramFormat &format) { std::ifstream input(fname, std::ios::binary | std::ios::ate); std::streamsize size = input.tellg(); - input.seekg(0, std::ios::beg); // rewind + input.seekg(0, std::ios::beg); // rewind std::vector buffer(size); input.read(buffer.data(), size); switch (format) { - case HLO: - return parse_hlo_program(buffer); - break; - case MLIR: - return std::string(buffer.data(), buffer.size()); - break; + case HLO: + return parse_hlo_program(buffer); + break; + case MLIR: + return std::string(buffer.data(), buffer.size()); + break; }; throw std::runtime_error("Unknown program format"); @@ -52,16 +51,16 @@ PJRT_Program PJRTProgram::create_program(std::string &code, program.code_size = code.size(); switch (format) { - case HLO: - program.format = "hlo"; - program.format_size = strlen("hlo"); - break; - case MLIR: - program.format = "mlir"; - program.format_size = strlen("mlir"); - break; - default: - throw std::runtime_error("Unknown program format"); + case HLO: + program.format = "hlo"; + program.format_size = strlen("hlo"); + break; + case MLIR: + program.format = "mlir"; + program.format_size = strlen("mlir"); + break; + default: + throw std::runtime_error("Unknown program format"); } return program; @@ -80,14 +79,14 @@ std::string PJRTProgram::repr(int n) const { auto format = this->format(); std::string debug(""); switch (format) { - case HLO: { - xla::HloModuleProto hlo_proto{}; - hlo_proto.ParseFromArray(this->code.data(), this->code.size()); - debug = hlo_proto.DebugString(); - } break; - case MLIR: - debug = this->code; - break; + case HLO: { + xla::HloModuleProto hlo_proto{}; + hlo_proto.ParseFromArray(this->code.data(), this->code.size()); + debug = hlo_proto.DebugString(); + } break; + case MLIR: + debug = this->code; + break; } // debug must not be larger than n lines @@ -99,7 +98,7 @@ std::string PJRTProgram::repr(int n) const { pos = debug.find('\n', pos); if (pos != std::string::npos) { lines_found++; - pos++; // Move past the newline + pos++; // Move past the newline } } @@ -108,11 +107,11 @@ std::string PJRTProgram::repr(int n) const { // Check if there's more content after the nth line if (pos < debug.length()) { debug = debug.substr(0, pos - 1) + - "\n..."; // pos-1 to include the nth newline + "\n..."; // pos-1 to include the nth newline } } return repr + "\n" + debug; } -} // namespace rpjrt +} // namespace rpjrt diff --git a/src/program.h b/src/program.h index 295753cb..77f1e6dd 100644 --- a/src/program.h +++ b/src/program.h @@ -9,18 +9,18 @@ namespace rpjrt { enum PJRTProgramFormat { HLO, MLIR }; class PJRTProgram { -public: + public: PJRTProgram(const std::string &fname, const PJRTProgramFormat &format); PJRTProgramFormat format() const; std::string repr(int n) const; std::string code; PJRT_Program program; -private: + private: static std::string load_program_from_file(const std::string &fname, const PJRTProgramFormat &format); static std::string parse_hlo_program(const std::vector &buffer); static PJRT_Program create_program(std::string &code, const PJRTProgramFormat &format); }; -} // namespace rpjrt +} // namespace rpjrt