Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions examples/BuddyDeepSeekR1/buddy-deepseek-r1-main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,104 @@
// limitations under the License.
//
//===----------------------------------------------------------------------===//
#include <algorithm>
#include <buddy/Core/Container.h>
#include <chrono>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <numeric>
#include <random>
#include <string>
#include <vector>

// ===== Operator Timing Infrastructure =====

// Accumulated timing samples for a single named operator.
struct TimingRecord {
  std::string op_name;          // Operator label.
  std::vector<double> times_ms; // One entry per sample, in milliseconds.

  // Appends one sample. The argument is in seconds; it is stored in
  // milliseconds.
  void add_time(double time_sec) {
    const double sample_ms = time_sec * 1000.0;
    times_ms.push_back(sample_ms);
  }

  // Returns the sum of all stored samples in milliseconds (0.0 when there are
  // no samples).
  double get_total() const {
    double sum_ms = 0.0;
    for (const double sample_ms : times_ms) {
      sum_ms += sample_ms;
    }
    return sum_ms;
  }
};

// Global timing data storage: maps a name string to its TimingRecord.
// record_timing() appends samples to it, print_timing_report() reads the
// per-entry totals, and clear_timing_data() empties it. Declared `static`, so
// it has internal linkage. No synchronization is performed on it in this file.
static std::map<std::string, TimingRecord> g_timing_data;

// Timing functions called from MLIR
extern "C" {
// Returns the current std::chrono::high_resolution_clock reading, expressed
// as fractional seconds since that clock's epoch.
double rtclock() {
  using clock_type = std::chrono::high_resolution_clock;
  using seconds_f64 = std::chrono::duration<double>;

  const seconds_f64 since_epoch = clock_type::now().time_since_epoch();
  return since_epoch.count();
}

// Wrapper exported under the `_mlir_ciface_` name: forwards to rtclock() and
// returns its result unchanged.
double _mlir_ciface_rtclock() {
  const double now_sec = rtclock();
  return now_sec;
}

// Record timing for an operator
void record_timing(const char *op_name, double duration_sec) {
std::string name(op_name);
g_timing_data[name].op_name = name;
g_timing_data[name].add_time(duration_sec);
}

// Wrapper exported under the `_mlir_ciface_` name: converts the untyped
// pointer to `const char *` and forwards it, together with the duration in
// seconds, to record_timing().
void _mlir_ciface_record_timing(void *op_name_ptr, double duration_sec) {
  record_timing(static_cast<const char *>(op_name_ptr), duration_sec);
}
}

// Prints a table of the timing data held in g_timing_data to std::cout: one
// row per entry with its total time in milliseconds and its share of the
// grand total, followed by a TOTAL row.
//
// The stream's format flags and precision are saved on entry and restored
// before returning, so the fixed/4-digit formatting used for the table does
// not leak into output printed by the caller afterwards.
void print_timing_report() {
  const std::ios_base::fmtflags saved_flags = std::cout.flags();
  const std::streamsize saved_precision = std::cout.precision();

  std::cout << "\n";
  std::cout << "========================================\n";
  std::cout << " Operator Timing Report\n";
  std::cout << "========================================\n";
  std::cout << std::fixed << std::setprecision(4);

  // Grand total over all entries; used as the denominator for percentages.
  double total_time = 0.0;
  for (const auto &entry : g_timing_data) {
    total_time += entry.second.get_total();
  }

  // Table header. The percentage column is 12 wide: 11 for the number plus
  // one for the trailing '%' printed in each row.
  std::cout << std::left << std::setw(30) << "Operator" << std::right
            << std::setw(16) << "Total (ms)" << std::setw(12) << "% Total"
            << "\n";
  std::cout << "----------------------------------------"
            << "------------------------------\n";

  // One row per entry.
  for (const auto &[name, record] : g_timing_data) {
    const double total = record.get_total();
    // Guard against division by zero when nothing has been recorded.
    const double percentage =
        (total_time > 0) ? (total / total_time * 100.0) : 0.0;

    std::cout << std::left << std::setw(30) << name << std::right
              << std::setw(16) << total << std::setw(11) << percentage
              << "%\n";
  }

  // TOTAL row. The percentage is streamed as a number so that setw(11)
  // applies to the digits only (not to a literal that includes the newline),
  // keeping the column aligned and at the same precision as the rows above.
  const double total_percentage = (total_time > 0) ? 100.0 : 0.0;
  std::cout << "----------------------------------------"
            << "------------------------------\n";
  std::cout << std::left << std::setw(30) << "TOTAL" << std::right
            << std::setw(16) << total_time << std::setw(11) << total_percentage
            << "%\n";
  std::cout << "========================================\n\n";

  std::cout.flags(saved_flags);
  std::cout.precision(saved_precision);
}

// Removes every entry from g_timing_data (for example, to drop samples
// collected during warmup).
void clear_timing_data() {
  g_timing_data.clear();
}

// ===== End of Timing Infrastructure =====

#include <array>
#include <buddy/Core/Container.h>
Expand Down Expand Up @@ -466,6 +564,9 @@ int main() {
// Print the generated token and inference time.
printIterInfo(i, tok, inferenceTime.count() / 1000);

print_timing_report();
clear_timing_data();

// Stop if a <|end▁of▁sentence|> token is generated.
if (maxIndex == 151643) {
break;
Expand Down
Loading
Loading