diff --git a/examples/SplitLlama/BaseDisModel.h b/examples/SplitLlama/BaseDisModel.h index 59304b66..1c74b6ec 100644 --- a/examples/SplitLlama/BaseDisModel.h +++ b/examples/SplitLlama/BaseDisModel.h @@ -12,6 +12,9 @@ #include #include #include +#include +#include +#include using namespace buddy; class BaseDisModel { @@ -21,7 +24,7 @@ class BaseDisModel { /// Load parameters into data container. static void loadParameters(const std::string ¶mFilePath, - MemRef ¶ms) { + MemRef ¶ms) { const auto loadStart = std::chrono::high_resolution_clock::now(); std::ifstream paramFile(paramFilePath, std::ios::in | std::ios::binary); if (!paramFile.is_open()) { @@ -48,9 +51,9 @@ class BaseDisModel { << std::endl; } - static void getParameters(const size_t *paramSize_group, size_t group_len, int size, - const std::string &splitNum, - std::vector> ¶msContainers) { + static void getParameters(const size_t *paramSize_group, size_t group_len, + int size, const std::string &splitNum, + std::vector> ¶msContainers) { std::string llamaBuildDir = LLAMA_EXAMPLE_BUILD_PATH; @@ -65,6 +68,130 @@ class BaseDisModel { } } } + + // Tokenize input data in the container. + static void tokenizeInput(const std::string &vocabFile, + Text &inputContainer, + const size_t MaxTokenLength) { + printLogLabel(); + std::cout << "Vocab file: " << std::filesystem::canonical(vocabFile) + << std::endl; + const auto buddyTokenizeStart = std::chrono::high_resolution_clock::now(); + inputContainer.tokenizeLlama(vocabFile, MaxTokenLength); + const auto buddyTokenizeEnd = std::chrono::high_resolution_clock::now(); + const std::chrono::duration buddyTokenizeTime = + buddyTokenizeEnd - buddyTokenizeStart; + printLogLabel(); + std::cout << "Tokenize time: " << buddyTokenizeTime.count() << "ms" + << std::endl; + } + + // Add the inference Token to the input sequence , and translate the result of + // that Token inputContainer: Current input container, token will be appended + // msg: Token to be appended + static void appendToken(Text &inputContainer, std::string &msg) { + int maxIndex = std::stoi(msg); + // Determine the generated token. + int tokenIndex = inputContainer.getTokenCnt() - 1; + std::string tok = inputContainer.getStr(maxIndex); + // printIterInfo + std::cout << "\033[32;1m[Iteration " << tokenIndex << "] \033[0m"; + std::cout << "Token: " << tok << std::endl; + + // Append the generated token into the input and output container. + inputContainer.appendTokenIdx(maxIndex); + } + + // Generate a Token based on the inference result and add the token to the + // output sequence Return value: the index of the token generated; if the + // generation is finished (the terminator is encountered), -1 is returned. + // resultContainer: container for the result of one inference of the model + // outputContainer: store the output sequence of tokens + // currentToken: the number of tokens currently generated + // tokenCnt: the number of tokens in the original input + // MaxVocabSize:The maximum number of tokens allowed in the model's + // vocabulary. separatorTokenIndex: The vocabulary index of the + // end-of-inference token. + static int generatedToken(MemRef &resultContainer, + Text &outputContainer, + uint32_t ¤tToken, uint32_t tokenCnt, + const size_t MaxVocabSize, + const size_t separatorTokenIndex) { + int tokenIndex = currentToken + tokenCnt - 1; + currentToken++; + // Determine the generated token. + const float *startPtr = + resultContainer.getData() + tokenIndex * MaxVocabSize; + const float *endPtr = startPtr + MaxVocabSize; + // int maxIndex = findMaxIndex(startPtr, endPtr); + int maxIndex = std::distance(startPtr, std::max_element(startPtr, endPtr)); + + // Stop if a separator token or line break token (13<0x0A>) is generated. + if (maxIndex == separatorTokenIndex) + return -1; + + outputContainer.appendTokenIdx(maxIndex); + + return maxIndex; + } + + static void send_data(websocketpp::connection_hdl hdl, uint32_t dataId, + const std::vector> &data, + websocketpp::server &server) { + const uint8_t total = data.size(); + + auto con = server.get_con_from_hdl(hdl); + if (!con || con->get_state() != websocketpp::session::state::open) { + std::cerr << "连接已关闭,无法发送数据。" << std::endl; + return; + } + + for (uint8_t i = 0; i < total; ++i) { + const auto &subdata = data[i]; + + // 构造协议头 + std::vector packet(10); // 4+1+1+2=8字节头 + memcpy(packet.data(), &dataId, 4); + packet[4] = total; + packet[5] = i; + uint32_t num = subdata.size(); + memcpy(packet.data() + 6, &num, 4); + + // 添加浮点数据 + const uint8_t *binaryData = + reinterpret_cast(subdata.data()); + packet.insert(packet.end(), binaryData, + binaryData + subdata.size() * sizeof(float)); + + server.send(hdl, packet.data(), packet.size(), + websocketpp::frame::opcode::binary); + } + } + + static void sendToClient(const std::map>> &nameToDataVecs, + const std::map &hdlMap, + uint32_t dataId, + websocketpp::server &server) { + for (const auto &[name, dataVecs] : nameToDataVecs) { + auto it = hdlMap.find(name); + if (it == hdlMap.end()) { + std::cerr << "[Warning] 未找到连接: " << name + << std::endl; + continue; + } + websocketpp::connection_hdl hdl = it->second; + + // // 调用已有的 send_data 封装 + // send_data(hdl, dataId, data, server); + try { + send_data(hdl, dataId++, dataVecs, server); + std::cout << "成功向"<< name <<"发送数据" << std::endl; + + } catch (const websocketpp::exception &e) { + std::cerr << "[Error] 向 " << name << " 发送失败: " << e.what() << std::endl; + } + } + } }; #endif // BASEDISMODEL_H diff --git a/examples/SplitLlama/SharedQueueTemp.h b/examples/SplitLlama/SharedQueueTemp.h index 0375cc42..42fdada5 100644 --- a/examples/SplitLlama/SharedQueueTemp.h +++ b/examples/SplitLlama/SharedQueueTemp.h @@ -1,5 +1,6 @@ -#ifndef SHAREDQUEUE_TEMP_H -#define SHAREDQUEUE_TEMP_H +#ifndef SHAREDQUEUETEMP_H +#define SHAREDQUEUETEMP_H + #include #include #include @@ -8,12 +9,12 @@ #include #include -// using namespace buddy; -/// 通用共享内存类:用于线程间通信 +/// Generic shared memory classes: for inter-thread communication + class SharedQueueTemp { public: - /// 构造函数:传入你希望支持的队列名称,如 {"input", "input0", "input1", - /// "output"} + /// Constructor: pass the name of the queue you wish to support, e.g. + /// {"input", "input0", "input1", "output"} SharedQueueTemp(const std::vector &queueNames) { for (const auto &name : queueNames) { queues[name] = std::queue(); @@ -22,7 +23,6 @@ class SharedQueueTemp { } } - /// 向指定队列 push 数据(任何类型) template void push(const std::string &queueName, const T &data) { checkQueueExists(queueName); { @@ -32,7 +32,6 @@ class SharedQueueTemp { cvs[queueName]->notify_one(); } - /// 从指定队列 pop 数据(阻塞,直到有数据) template T pop(const std::string &queueName) { checkQueueExists(queueName); std::unique_lock lock(*mutexes[queueName]); diff --git a/examples/SplitLlama/llamaAdd.h b/examples/SplitLlama/llamaAdd.h index 096a395e..3b197641 100644 --- a/examples/SplitLlama/llamaAdd.h +++ b/examples/SplitLlama/llamaAdd.h @@ -1,5 +1,7 @@ #ifndef LLAMAAdd_H // 作用:防止llamaAdd.h被重复引用 #define LLAMAAdd_H +#include "SharedQueueTemp.h" +#include "BaseDisModel.h" #include #include #include @@ -22,7 +24,6 @@ #include #include #include -#include "SharedQueueTemp.h" using namespace buddy; using websocketpp::lib::bind; @@ -46,8 +47,7 @@ extern "C" void _mlir_ciface_forward3(MemRef *, MemRef *, //--------------------- AddMess (主线程) --------------------- class AddQueue : public SharedQueueTemp { public: - AddQueue() - : SharedQueueTemp({"input", "input0", "input1", "output"}) {} + AddQueue() : SharedQueueTemp({"input", "input0", "input1", "output"}) {} }; class AddMess { @@ -148,51 +148,36 @@ class AddMess { resultContainer = sharedQueue.pop>("output"); std::lock_guard lock(symbolMutex); // 加锁保护符号表 if (isLast) { - if (tfCount == 31) { - auto it = hdlsSymbol.find("OutputMess"); - if (it != hdlsSymbol.end()) { - send_data(hdlsSymbol["OutputMess"], dataId++, - {resultContainer.getDataVector()}); - tfCount = 0; - std::cout << "一次Token推理完成." << std::endl; + if (tfCount == 31) { + std::map>> sendMap = { + {"OutputMess", {resultContainer.getDataVector()}}}; + BaseDisModel::sendToClient(sendMap, hdlsSymbol, dataId, addServer); + tfCount = 0; + + } else if (tfCount < 31) { + std::map>> sendMap = { + {"FirstRMS", {resultContainer.getDataVector()}}}; + BaseDisModel::sendToClient(sendMap, hdlsSymbol, dataId, addServer); + tfCount++; } else { - std::cout << "OutputMess未连接, 转发失败" << std::endl; + std::cout << "transformer层推理次数过多" << std::endl; } - } else if (tfCount < 31) { - auto it = hdlsSymbol.find("FirstRMS"); - if (it != hdlsSymbol.end()) { - send_data(hdlsSymbol["FirstRMS"], dataId++, - {resultContainer.getDataVector()}); - tfCount++; - std::cout << "第" << tfCount << "次transformer层推理完成." - << std::endl; - } else { - std::cout << "未连接FirstRMS, 转发失败" << std::endl; - } - } else { - std::cout << "transformer层推理次数过多" << std::endl; - } - } else { - auto it = hdlsSymbol.find("RMSMess"); - if (it != hdlsSymbol.end()) { - send_data(hdlsSymbol["RMSMess"], dataId++, - {resultContainer.getDataVector()}); - std::cout << "转发成功 " << std::endl; - } else { - std::cout << "RMSMess未连接, 转发失败" << std::endl; + } else { + std::map>> sendMap = { + {"RMSMess", {resultContainer.getDataVector()}}}; + BaseDisModel::sendToClient(sendMap, hdlsSymbol, dataId, addServer); } } + }); + if (isLast) { + rmsClient0_thread = std::thread([this]() { rmsClient0.run(); }); } - }); - if (isLast) { - rmsClient0_thread = std::thread([this]() { rmsClient0.run(); }); - } - rmsClient_thread.join(); - mhaClient_thread.join(); - mhaClient0_thread.join(); - server_thread.join(); - output_thread.join(); - rmsClient0_thread.join(); + rmsClient_thread.join(); + mhaClient_thread.join(); + mhaClient0_thread.join(); + server_thread.join(); + output_thread.join(); + rmsClient0_thread.join(); } private: @@ -218,148 +203,119 @@ class AddMess { // 是否是最后一个add模块 bool isLast; - void send_data(websocketpp::connection_hdl hdl, uint32_t dataId, - const std::vector> &data) { - const uint8_t total = data.size(); - if (addServer.get_con_from_hdl(hdl)->get_state() != - websocketpp::session::state::open) - return; - - for (uint8_t i = 0; i < total; ++i) { - const auto &subdata = data[i]; - - // 构造协议头 - std::vector packet(10); // 4+1+1+2=8字节头 - memcpy(packet.data(), &dataId, 4); - packet[4] = total; - packet[5] = i; - uint32_t num = subdata.size(); - memcpy(packet.data() + 6, &num, 4); - - // 添加浮点数据 - const uint8_t *binaryData = - reinterpret_cast(subdata.data()); - packet.insert(packet.end(), binaryData, - binaryData + subdata.size() * sizeof(float)); - - addServer.send(hdl, packet.data(), packet.size(), - websocketpp::frame::opcode::binary); - } - } - void on_server_message(websocketpp::connection_hdl hdl, server::message_ptr msg) { - std::string payload = msg->get_payload(); - if (payload.find("RMSMess") != std::string::npos) { - std::lock_guard lock(symbolMutex); // 加锁保护符号表 - hdlsSymbol["RMSMess"] = hdl; - connections[hdl] = payload; - std::cout << payload << " 已连接" << std::endl; - return; - } else if (payload.find("OutputMess") != std::string::npos) { - std::lock_guard lock(symbolMutex); // 加锁保护符号表 - hdlsSymbol[payload] = hdl; - connections[hdl] = payload; - std::cout << payload << " 已连接" << std::endl; - return; - } + std::string payload = msg->get_payload(); + if (payload.find("RMSMess") != std::string::npos) { + std::lock_guard lock(symbolMutex); // 加锁保护符号表 + hdlsSymbol["RMSMess"] = hdl; + connections[hdl] = payload; + std::cout << payload << " 已连接" << std::endl; + return; + } else if (payload.find("OutputMess") != std::string::npos) { + std::lock_guard lock(symbolMutex); // 加锁保护符号表 + hdlsSymbol[payload] = hdl; + connections[hdl] = payload; + std::cout << payload << " 已连接" << std::endl; + return; + } } std::vector getFloatData(client::message_ptr msg) { - if (msg->get_opcode() != websocketpp::frame::opcode::binary) { - std::cout << "忽略非二进制消息" << std::endl; - return {}; - } - - const std::string &payload = msg->get_payload(); - if (payload.size() < 10) { - std::cerr << "错误: 协议头不完整(需要至少10字节)" << std::endl; - return {}; - } - - // 解析协议头 - uint32_t batch_id; - uint8_t totalChunks, seqChunk; - uint32_t num_elements; + if (msg->get_opcode() != websocketpp::frame::opcode::binary) { + std::cout << "忽略非二进制消息" << std::endl; + return {}; + } - memcpy(&batch_id, payload.data(), 4); - totalChunks = payload[4]; - seqChunk = payload[5]; - memcpy(&num_elements, payload.data() + 6, 4); + const std::string &payload = msg->get_payload(); + if (payload.size() < 10) { + std::cerr << "错误: 协议头不完整(需要至少10字节)" << std::endl; + return {}; + } - // 验证分块序号有效性 - if (seqChunk >= totalChunks) { - std::cerr << "错误:非法分块序号 " << (int)seqChunk - << " (总块数=" << (int)totalChunks << ")" << std::endl; - return {}; - } + // 解析协议头 + uint32_t batch_id; + uint8_t totalChunks, seqChunk; + uint32_t num_elements; + + memcpy(&batch_id, payload.data(), 4); + totalChunks = payload[4]; + seqChunk = payload[5]; + memcpy(&num_elements, payload.data() + 6, 4); + + // 验证分块序号有效性 + if (seqChunk >= totalChunks) { + std::cerr << "错误:非法分块序号 " << (int)seqChunk + << " (总块数=" << (int)totalChunks << ")" << std::endl; + return {}; + } - // 验证数据长度 - const size_t expectedSize = 10 + num_elements * sizeof(float); - if (payload.size() != expectedSize) { - std::cerr << "错误:数据长度不匹配(预期=" << expectedSize - << " 实际=" << payload.size() << ")" << std::endl; - return {}; - } + // 验证数据长度 + const size_t expectedSize = 10 + num_elements * sizeof(float); + if (payload.size() != expectedSize) { + std::cerr << "错误:数据长度不匹配(预期=" << expectedSize + << " 实际=" << payload.size() << ")" << std::endl; + return {}; + } - // 提取浮点数据 - const float *float_data = - reinterpret_cast(payload.data() + 10); - std::vector chunk(float_data, float_data + num_elements); - return chunk; + // 提取浮点数据 + const float *float_data = + reinterpret_cast(payload.data() + 10); + std::vector chunk(float_data, float_data + num_elements); + return chunk; } void on_mhaClient_message(websocketpp::connection_hdl hdl, client::message_ptr msg) { - std::lock_guard lock(dataMutex); - auto chunk = getFloatData(msg); - intptr_t sizes[2] = {SubMaxTokenLength, HiddenSize}; - MemRef subResultContainer(chunk.data(), sizes); - sharedQueue.push("input0", subResultContainer); - std::cout << "接收到MHAMess0数据" << std::endl; + std::lock_guard lock(dataMutex); + auto chunk = getFloatData(msg); + intptr_t sizes[2] = {SubMaxTokenLength, HiddenSize}; + MemRef subResultContainer(chunk.data(), sizes); + sharedQueue.push("input0", subResultContainer); + std::cout << "接收到MHAMess0数据" << std::endl; } void on_mhaClient0_message(websocketpp::connection_hdl hdl, client::message_ptr msg) { - std::lock_guard lock(dataMutex); - auto chunk = getFloatData(msg); - intptr_t sizes[2] = {SubMaxTokenLength, HiddenSize}; - MemRef subResultContainer(chunk.data(), sizes); - sharedQueue.push("input1", subResultContainer); - std::cout << "接收到MHAMess1数据" << std::endl; + std::lock_guard lock(dataMutex); + auto chunk = getFloatData(msg); + intptr_t sizes[2] = {SubMaxTokenLength, HiddenSize}; + MemRef subResultContainer(chunk.data(), sizes); + sharedQueue.push("input1", subResultContainer); + std::cout << "接收到MHAMess1数据" << std::endl; } void on_rmsClient_message(websocketpp::connection_hdl hdl, client::message_ptr msg) { - std::lock_guard lock(dataMutex); - auto chunk = getFloatData(msg); - intptr_t sizes[3] = {1, SubMaxTokenLength, HiddenSize}; - MemRef subResultContainer(chunk.data(), sizes); - sharedQueue.push>("input", subResultContainer); - std::cout << "接收到RMSMess数据" << std::endl; + std::lock_guard lock(dataMutex); + auto chunk = getFloatData(msg); + intptr_t sizes[3] = {1, SubMaxTokenLength, HiddenSize}; + MemRef subResultContainer(chunk.data(), sizes); + sharedQueue.push>("input", subResultContainer); + std::cout << "接收到RMSMess数据" << std::endl; } -}; + }; -//--------------------- Comp (子线程) --------------------- -class Comp { -public: - Comp(AddQueue &queue) : sharedQueue(queue) {} + //--------------------- Comp (子线程) --------------------- + class Comp { + public: + Comp(AddQueue &queue) : sharedQueue(queue) {} - void run() { - while (true) { - auto input1 = sharedQueue.pop>("input1"); - auto input2 = sharedQueue.pop>("input"); - auto input0 = sharedQueue.pop>("input0"); - input0.addMemRef(input0, input1); - MemRef resultContainer({1, SubMaxTokenLength, HiddenSize}); - _mlir_ciface_forward3(&resultContainer, &input0, &input2); - std::cout << "forward3 computed." << std::endl; - sharedQueue.push("output", resultContainer); + void run() { + while (true) { + auto input1 = sharedQueue.pop>("input1"); + auto input2 = sharedQueue.pop>("input"); + auto input0 = sharedQueue.pop>("input0"); + input0.addMemRef(input0, input1); + MemRef resultContainer({1, SubMaxTokenLength, HiddenSize}); + _mlir_ciface_forward3(&resultContainer, &input0, &input2); + std::cout << "forward3 computed." << std::endl; + sharedQueue.push("output", resultContainer); + } } - } -private: - AddQueue &sharedQueue; -}; + private: + AddQueue &sharedQueue; + }; #endif // LLAMAAdd_H diff --git a/examples/SplitLlama/llamaInput.h b/examples/SplitLlama/llamaInput.h index 86312b57..f665fc02 100644 --- a/examples/SplitLlama/llamaInput.h +++ b/examples/SplitLlama/llamaInput.h @@ -13,11 +13,12 @@ // limitations under the License. // //===----------------------------------------------------------------------===// +#include "BaseDisModel.h" +#include "SharedQueueTemp.h" #include #include #include #include -#include "BaseDisModel.h" #include #include #include @@ -36,7 +37,6 @@ #include #include #include -#include "SharedQueueTemp.h" using namespace buddy; using websocketpp::lib::bind; @@ -81,27 +81,6 @@ void getUserInput(std::string &inputStr) { /// Print [Log] label in bold blue format. void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } -/// Print information for each iteration. -void printIterInfo(size_t iterIdx, std::string str) { - std::cout << "\033[32;1m[Iteration " << iterIdx << "] \033[0m"; - std::cout << "Token: " << str << std::endl; -} - -/// Tokenize input data in the container. -void tokenizeInput(const std::string &vocabFile, - Text &inputContainer) { - printLogLabel(); - std::cout << "Vocab file: " << std::filesystem::canonical(vocabFile) - << std::endl; - const auto buddyTokenizeStart = std::chrono::high_resolution_clock::now(); - inputContainer.tokenizeLlama(vocabFile, MaxTokenLength); - const auto buddyTokenizeEnd = std::chrono::high_resolution_clock::now(); - const std::chrono::duration buddyTokenizeTime = - buddyTokenizeEnd - buddyTokenizeStart; - printLogLabel(); - std::cout << "Tokenize time: " << buddyTokenizeTime.count() << "ms" - << std::endl; -} //--------------------- InputMess (主线程) --------------------- class InputQueue : public SharedQueueTemp { @@ -142,7 +121,7 @@ class InputMess { // 新增:启动输出监听线程,向RMSMess发送数据 std::thread output_thread([this]() { while (true) { - resultContainerPtr = shared_queue.pop("output"); + resultContainerPtr = shared_queue.pop("output"); memRef3D0 = resultContainerPtr->memRef3D0; memRef2D = resultContainerPtr->memRef2D; memRef3D1 = resultContainerPtr->memRef3D1; @@ -151,27 +130,16 @@ class InputMess { subResultContainer1, 1, 20); std::lock_guard lock(symbolMutex); // 加锁保护符号表 - auto it = hdlsSymbol.find("RMSMess0"); - if (it != hdlsSymbol.end()) { - try { - send_data(hdlsSymbol["RMSMess0"], dataId++, - {subResultContainer0.getDataVector()}); - send_data(hdlsSymbol["RMSMess1"], dataId++, - {subResultContainer1.getDataVector()}); - std::cout << "成功向RMSMess发送数据" << std::endl; - send_data(hdlsSymbol["MHAMess0"], dataId++, - {memRef2D.getDataVector(), memRef3D1.getDataVector(), - memRef3D2.getDataVector()}); - send_data(hdlsSymbol["MHAMess1"], dataId++, - {memRef2D.getDataVector(), memRef3D1.getDataVector(), - memRef3D2.getDataVector()}); - std::cout << "成功向MHAMess发送数据" << std::endl; - } catch (const websocketpp::exception &e) { - std::cout << "转发失败: " << e.what() << std::endl; - } - } else { - std::cout << "RMSMess未连接, 丢弃结果: " << "result" << std::endl; - } + std::map>> sendMap = { + {"RMSMess0", {subResultContainer0.getDataVector()}}, + {"RMSMess1", {subResultContainer1.getDataVector()}}, + {"MHAMess0", + {memRef2D.getDataVector(), memRef3D1.getDataVector(), + memRef3D2.getDataVector()}}, + {"MHAMess1", + {memRef2D.getDataVector(), memRef3D1.getDataVector(), + memRef3D2.getDataVector()}}}; + BaseDisModel::sendToClient(sendMap, hdlsSymbol, dataId, inputServer); } }); @@ -204,36 +172,6 @@ class InputMess { // 确保对dataId的操作是​​原子​​的 std::atomic dataId; - void send_data(websocketpp::connection_hdl hdl, uint32_t dataId, - const std::vector> &data) { - const uint8_t total = data.size(); - - if (inputServer.get_con_from_hdl(hdl)->get_state() != - websocketpp::session::state::open) - return; - - for (uint8_t i = 0; i < total; ++i) { - const auto &subdata = data[i]; - - // 构造协议头 - std::vector packet(10); // 4+1+1+2=8字节头 - memcpy(packet.data(), &dataId, 4); - packet[4] = total; - packet[5] = i; - uint32_t num = subdata.size(); - memcpy(packet.data() + 6, &num, 4); - - // 添加浮点数据 - const uint8_t *binaryData = - reinterpret_cast(subdata.data()); - packet.insert(packet.end(), binaryData, - binaryData + subdata.size() * sizeof(float)); - - inputServer.send(hdl, packet.data(), packet.size(), - websocketpp::frame::opcode::binary); - } - } - void on_server_message(websocketpp::connection_hdl hdl, server::message_ptr msg) { std::string payload = msg->get_payload(); @@ -254,23 +192,16 @@ class InputMess { getUserInput(inputStr); // 创建并tokenize输入容器 inputContainer = Text(inputStr); - tokenizeInput(vocabDir, inputContainer); + BaseDisModel::tokenizeInput(vocabDir, inputContainer, MaxTokenLength); // 将输入压入队列 shared_queue.push("input", inputContainer); + // 发送 token 数量到 Output 模块 int tokenCnt = inputContainer.getTokenCnt(); inputServer.send(hdl, std::to_string(tokenCnt), websocketpp::frame::opcode::text); - } else { - // 获取客户端类型 - int maxIndex = std::stoi(payload); - // Determine the generated token. - int tokenIndex = inputContainer.getTokenCnt() - 1; - std::string tok = inputContainer.getStr(maxIndex); - printIterInfo(tokenIndex, tok); - - // Append the generated token into the input and output container. - inputContainer.appendTokenIdx(maxIndex); + } else { + BaseDisModel::appendToken(inputContainer, payload); shared_queue.push("input", inputContainer); } } @@ -281,9 +212,7 @@ class Comp { public: Comp(InputQueue &queue, MemRefContainer *resultContainerPtr) : shared_queue(queue), resultContainerPtr(resultContainerPtr), - paramsContainer({ParamSize}) { - - } + paramsContainer({ParamSize}) {} void init() { loadAllParameters(); } void run() { while (true) { diff --git a/examples/SplitLlama/llamaMHA.h b/examples/SplitLlama/llamaMHA.h index e14d5dd5..929175dc 100644 --- a/examples/SplitLlama/llamaMHA.h +++ b/examples/SplitLlama/llamaMHA.h @@ -1,10 +1,11 @@ #ifndef LLAMAMHA_H // 作用:防止llamaMHA.h被重复引用 #define LLAMAMHA_H +#include "BaseDisModel.h" +#include "SharedQueueTemp.h" #include #include #include #include -#include "BaseDisModel.h" #include #include #include @@ -23,7 +24,6 @@ #include #include #include -#include "SharedQueueTemp.h" using namespace buddy; using websocketpp::lib::bind; @@ -46,9 +46,9 @@ extern "C" void _mlir_ciface_forward2(MemRef *, MemRef *, MemRef *, MemRef *); //--------------------- MHAMess (主线程) --------------------- -class MHAQueue : public SharedQueueTemp{ +class MHAQueue : public SharedQueueTemp { public: - MHAQueue() : SharedQueueTemp({"input", "input0", "input1", "output"}){} + MHAQueue() : SharedQueueTemp({"input", "input0", "input1", "output"}) {} }; class MHAMess { @@ -146,17 +146,12 @@ class MHAMess { MemRef subResultContainer1({SubMaxTokenLength, HiddenSize}); resultContainer.splitMemRef(std::move(resultContainer), subResultContainer0, subResultContainer1, 0, - 20); - auto it = hdlsSymbol.find("AddMess0"); - if (it != hdlsSymbol.end()) { - send_data(hdlsSymbol["AddMess0"], dataId++, - {subResultContainer0.getDataVector()}); - send_data(hdlsSymbol["AddMess1"], dataId++, - {subResultContainer1.getDataVector()}); - std::cout << "转发成功" << std::endl; - } else { - std::cout << "AddMess0未连接, 丢弃结果: " << "result" << std::endl; - } + 20); + std::map>> sendMap = { + {"AddMess0", {subResultContainer0.getDataVector()}}, + {"AddMess1", {subResultContainer1.getDataVector()}} + }; + BaseDisModel::sendToClient(sendMap, hdlsSymbol, dataId, mhaServer); } }); inputClient_thread.join(); @@ -191,36 +186,6 @@ class MHAMess { // 表示最近从其他服务器得到的数据块在数据组内的序号 uint8_t currentSequence; - void send_data(websocketpp::connection_hdl hdl, uint32_t dataId, - const std::vector> &data) { - const uint8_t total = data.size(); - - if (mhaServer.get_con_from_hdl(hdl)->get_state() != - websocketpp::session::state::open) - return; - - for (uint8_t i = 0; i < total; ++i) { - const auto &subdata = data[i]; - - // 构造协议头 - std::vector packet(10); // 4+1+1+4=10字节头 - memcpy(packet.data(), &dataId, 4); - packet[4] = total; - packet[5] = i; - uint32_t num = subdata.size(); - memcpy(packet.data() + 6, &num, 4); - - // 添加浮点数据 - const uint8_t *binaryData = - reinterpret_cast(subdata.data()); - packet.insert(packet.end(), binaryData, - binaryData + subdata.size() * sizeof(float)); - - mhaServer.send(hdl, packet.data(), packet.size(), - websocketpp::frame::opcode::binary); - } - } - std::vector getFloatData(client::message_ptr msg) { if (msg->get_opcode() != websocketpp::frame::opcode::binary) { std::cout << "忽略非二进制消息" << std::endl; @@ -341,7 +306,7 @@ class Comp { MemRef rmsInput1 = sharedQueue.pop>("input1"); MemRef input0({1, MaxTokenLength, HiddenSize}); input0.concatenateMemRefs(rmsInput0, rmsInput1, input0, 1); - MemRef resultContainer({MaxTokenLength, HiddenSize}); + MemRef resultContainer({MaxTokenLength, HiddenSize}); _mlir_ciface_forward2(&resultContainer, ¶msContainers[index], &input0, ¤tInput2, ¤tInput3, ¤tInput1); std::cout << "第" << index << "次forward2 computed." << std::endl; @@ -385,8 +350,8 @@ class Comp { 0, 4096, 33554432, 0, 4096, 67633152, 0, 4096, 33554432, 0, 4096, 67633152, 0, 131076096}; size_t group_len = sizeof(paramSize_group) / sizeof(paramSize_group[0]); - BaseDisModel::getParameters(paramSize_group, group_len, 33554432, - splitNum, paramsContainers); + BaseDisModel::getParameters(paramSize_group, group_len, 33554432, splitNum, + paramsContainers); } void updateParams() { diff --git a/examples/SplitLlama/llamaMLP.h b/examples/SplitLlama/llamaMLP.h index 5294a451..ac306a9f 100644 --- a/examples/SplitLlama/llamaMLP.h +++ b/examples/SplitLlama/llamaMLP.h @@ -1,10 +1,11 @@ #ifndef LLAMAMLP_H // 作用:防止llamaMLP.h被重复引用 #define LLAMAMLP_H +#include "BaseDisModel.h" +#include "SharedQueueTemp.h" #include #include #include #include -#include "BaseDisModel.h" #include #include #include @@ -23,7 +24,6 @@ #include #include #include -#include "SharedQueueTemp.h" using namespace buddy; using websocketpp::lib::bind; @@ -48,7 +48,7 @@ extern "C" void _mlir_ciface_forward5(MemRef *, MemRef *, class MLPQueue : public SharedQueueTemp { public: - MLPQueue() : SharedQueueTemp({"input0", "input1", "output"}){} + MLPQueue() : SharedQueueTemp({"input0", "input1", "output"}) {} }; class MLPMess { @@ -141,16 +141,10 @@ class MLPMess { resultContainer.splitMemRef(std::move(resultContainer), subResultContainer0, subResultContainer1, 0, 20); - auto it = hdlsSymbol.find("AddMess0"); - if (it != hdlsSymbol.end()) { - send_data(hdlsSymbol["AddMess0"], dataId++, - {subResultContainer0.getDataVector()}); - send_data(hdlsSymbol["AddMess1"], dataId++, - {subResultContainer1.getDataVector()}); - std::cout << "转发成功." << std::endl; - } else { - std::cout << "AddMess0未连接, 丢弃结果." << std::endl; - } + std::map>> sendMap = { + {"AddMess0", {subResultContainer0.getDataVector()}}, + {"AddMess1", {subResultContainer1.getDataVector()}}}; + BaseDisModel::sendToClient(sendMap, hdlsSymbol, dataId, mlpServer); } }); rmsClient_thread.join(); @@ -175,38 +169,6 @@ class MLPMess { std::atomic dataId; std::mutex dataMutex; - void send_data(websocketpp::connection_hdl hdl, uint32_t dataId, - const std::vector> &data) { - const uint8_t total = data.size(); - - auto con = mlpServer.get_con_from_hdl(hdl); - if (!con || con->get_state() != websocketpp::session::state::open) { - std::cerr << "连接已关闭,无法发送数据。" << std::endl; - return; - } - - for (uint8_t i = 0; i < total; ++i) { - const auto &subdata = data[i]; - - // 构造协议头 - std::vector packet(10); // 4+1+1+2=8字节头 - memcpy(packet.data(), &dataId, 4); - packet[4] = total; - packet[5] = i; - uint32_t num = subdata.size(); - memcpy(packet.data() + 6, &num, 4); - - // 添加浮点数据 - const uint8_t *binaryData = - reinterpret_cast(subdata.data()); - packet.insert(packet.end(), binaryData, - binaryData + subdata.size() * sizeof(float)); - - mlpServer.send(hdl, packet.data(), packet.size(), - websocketpp::frame::opcode::binary); - } - } - void on_server_message(websocketpp::connection_hdl hdl, server::message_ptr msg) { std::string payload = msg->get_payload(); @@ -289,7 +251,8 @@ class MLPMess { //------------------------------------------------------------------------------ class Comp { public: - Comp(MLPQueue &queue, const std::string splitNum = "0") : sharedQueue(queue), splitNum(splitNum) {} + Comp(MLPQueue &queue, const std::string splitNum = "0") + : sharedQueue(queue), splitNum(splitNum) {} void init() { loadAllParameters(); } @@ -303,7 +266,7 @@ class Comp { _mlir_ciface_forward5(&resultContainer, ¶msContainers[index], &input0); std::cout << "第" << index << "次forward5 computed." << std::endl; - sharedQueue.push("output",resultContainer); + sharedQueue.push("output", resultContainer); index = (index + 1) % 32; } } @@ -339,20 +302,21 @@ class Comp { 0, 4096, 33554432, 0, 4096, 67633152, 0, 4096, 33554432, 0, 4096, 67633152, 0, 131076096}; size_t group_len = sizeof(paramSize_group) / sizeof(paramSize_group[0]); - BaseDisModel::getParameters(paramSize_group, group_len, 67633152, - splitNum, paramsContainers); - // /// Define directories of vacabulary and parameter file. - // std::string llamaBuildDir = LLAMA_EXAMPLE_BUILD_PATH; - - // for (int i = 0; i < 194; i++) { // N 为需要生成的数量 - // if (paramSize_group[i] == 67633152) { - // std::string paramsDir = llamaBuildDir + "/subgraph" + - // std::to_string(i) + "_arg" + splitNum + ".data"; - // MemRef paramsContainer({paramSize_group[i]}); - // loadParameters(paramsDir, paramsContainer); - // paramsContainers.push_back(std::move(paramsContainer)); - // } - // } + BaseDisModel::getParameters(paramSize_group, group_len, 67633152, splitNum, + paramsContainers); + // /// Define directories of vacabulary and parameter file. + // std::string llamaBuildDir = LLAMA_EXAMPLE_BUILD_PATH; + + // for (int i = 0; i < 194; i++) { // N 为需要生成的数量 + // if (paramSize_group[i] == 67633152) { + // std::string paramsDir = llamaBuildDir + "/subgraph" + + // std::to_string(i) + "_arg" + splitNum + + // ".data"; + // MemRef paramsContainer({paramSize_group[i]}); + // loadParameters(paramsDir, paramsContainer); + // paramsContainers.push_back(std::move(paramsContainer)); + // } + // } } }; diff --git a/examples/SplitLlama/llamaOutput.h b/examples/SplitLlama/llamaOutput.h index c9fde1ac..0ff74693 100644 --- a/examples/SplitLlama/llamaOutput.h +++ b/examples/SplitLlama/llamaOutput.h @@ -1,6 +1,7 @@ #ifndef LLAMAOUTPUT_H // 作用:防止llamaOutput.h被重复引用 #define LLAMAOUTPUT_H #include "BaseDisModel.h" +#include "SharedQueueTemp.h" #include #include #include @@ -23,7 +24,6 @@ #include #include #include -#include "SharedQueueTemp.h" using namespace buddy; using websocketpp::lib::bind; @@ -40,6 +40,7 @@ constexpr size_t HiddenSize = 4096; constexpr size_t HiddenSize0 = 128; constexpr size_t HiddenSize1 = 41; constexpr size_t ParamSize = 131076096; +constexpr size_t separatorTokenIndex = 2; /// Print [Log] label in bold blue format. void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } @@ -48,12 +49,6 @@ void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } extern "C" void _mlir_ciface_forward193(MemRef *, MemRef *, MemRef *); -/// Find the index of the max value. -int findMaxIndex(const float *start, const float *end) { - return std::distance(start, std::max_element(start, end)); -} - - //--------------------- OutputMess (主线程) --------------------- class OutputQueue : public SharedQueueTemp { public: @@ -119,28 +114,18 @@ class OutputMess { while (true) { resultContainer = sharedQueue.pop>("output"); std::lock_guard lock(hdlMutex); // 加锁保护符号表 - int tokenIndex = currentToken + tokenCnt - 1; - currentToken++; - // Determine the generated token. - const float *startPtr = - resultContainer.getData() + tokenIndex * MaxVocabSize; - const float *endPtr = startPtr + MaxVocabSize; - int maxIndex = findMaxIndex(startPtr, endPtr); - // Stop if a separator token (2, ) or line break token (13 <0x0A>) - // is generated. - if (maxIndex == 2) { - break; - } - // Append the generated token into the input and output container. - outputContainer.appendTokenIdx(maxIndex); + int generate = BaseDisModel::generatedToken(resultContainer, outputContainer, + currentToken, tokenCnt, MaxVocabSize, separatorTokenIndex); + if (generate < 0) break; + if (currentToken == (MaxTokenLength - tokenCnt)) { std::cout << "\033[33;1m[Output]\033[0m " << outputContainer.revertLlama() << std::endl; currentToken = 0; } else if (currentToken < (MaxTokenLength - tokenCnt)) { - inputClient.send(inputHdl, std::to_string(maxIndex), + inputClient.send(inputHdl, std::to_string(generate), websocketpp::frame::opcode::text); std::cout << "第" << currentToken << "次Token推理完成." << std::endl; } else { diff --git a/examples/SplitLlama/llamaRMS.h b/examples/SplitLlama/llamaRMS.h index 0f7b175c..090dccb6 100644 --- a/examples/SplitLlama/llamaRMS.h +++ b/examples/SplitLlama/llamaRMS.h @@ -1,10 +1,11 @@ #ifndef LLAMARMS_H // 作用:防止llamaRMS.h被重复引用 #define LLAMARMS_H +#include "BaseDisModel.h" +#include "SharedQueueTemp.h" #include #include #include #include -#include "BaseDisModel.h" #include #include #include @@ -23,7 +24,6 @@ #include #include #include -#include "SharedQueueTemp.h" using namespace buddy; using websocketpp::lib::bind; @@ -44,23 +44,19 @@ constexpr size_t HiddenSize1 = 41; extern "C" void _mlir_ciface_forward1(MemRef *, MemRef *, MemRef *); - //--------------------- RMSMess (主线程) --------------------- class RMSQueue : public SharedQueueTemp { public: - RMSQueue() : SharedQueueTemp({"input", "output"}) {} - }; class RMSMess { public: RMSMess(const std::string name, RMSQueue &queue, const uint16_t &port, const std::string &uri) - : queue(queue), rmsServer(), name(name), inputClient(), - hdlsSymbol(), + : queue(queue), rmsServer(), name(name), inputClient(), hdlsSymbol(), resultContainer(MemRef({1, SubMaxTokenLength, HiddenSize})), dataId(0) { /// 服务器初始化 @@ -112,22 +108,12 @@ class RMSMess { while (true) { resultContainer = queue.pop>("output"); std::lock_guard lock(symbolMutex); // 加锁保护符号表 - if (hdlsSymbol.find("MHAMess0") != hdlsSymbol.end()) { - send_data(hdlsSymbol["MHAMess0"], dataId++, - {resultContainer.getDataVector()}); - send_data(hdlsSymbol["MHAMess1"], dataId++, - {resultContainer.getDataVector()}); - std::cout << name << "转发" << "MHAMess" << "成功" << std::endl; - } else if (hdlsSymbol.find("MLPMess0") != hdlsSymbol.end()) { - send_data(hdlsSymbol["MLPMess0"], dataId++, - {resultContainer.getDataVector()}); - send_data(hdlsSymbol["MLPMess1"], dataId++, - {resultContainer.getDataVector()}); - std::cout << name << "转发" << "MLPMess" << "成功" << std::endl; - } else { - std::cout << "MHAMess0或MLPMess0未连接, 丢弃结果: " << "result" - << std::endl; - } + std::map>> sendMap = { + {"MHAMess0", {resultContainer.getDataVector()}}, + {"MHAMess1", {resultContainer.getDataVector()}}, + {"MLPMess0", {resultContainer.getDataVector()}}, + {"MLPMess1", {resultContainer.getDataVector()}}}; + BaseDisModel::sendToClient(sendMap, hdlsSymbol, dataId, rmsServer); } }); @@ -157,36 +143,6 @@ class RMSMess { // 是否是第一个rms模块 bool isFirst; - void send_data(websocketpp::connection_hdl hdl, uint32_t dataId, - const std::vector> &data) { - const uint8_t total = data.size(); - - if (rmsServer.get_con_from_hdl(hdl)->get_state() != - websocketpp::session::state::open) - return; - - for (uint8_t i = 0; i < total; ++i) { - const auto &subdata = data[i]; - - // 构造协议头 - std::vector packet(10); // 4+1+1+2=8字节头 - memcpy(packet.data(), &dataId, 4); - packet[4] = total; - packet[5] = i; - uint32_t num = subdata.size(); - memcpy(packet.data() + 6, &num, 4); - - // 添加浮点数据 - const uint8_t *binaryData = - reinterpret_cast(subdata.data()); - packet.insert(packet.end(), binaryData, - binaryData + subdata.size() * sizeof(float)); - - rmsServer.send(hdl, packet.data(), packet.size(), - websocketpp::frame::opcode::binary); - } - } - std::vector getFloatData(client::message_ptr msg) { if (msg->get_opcode() != websocketpp::frame::opcode::binary) { std::cout << "忽略非二进制消息" << std::endl; @@ -262,14 +218,9 @@ class RMSMess { std::cout << "接收到AddMess数据" << std::endl; { std::lock_guard lockMutex(symbolMutex); // 加锁保护符号表 - auto it = hdlsSymbol.find("AddMess"); - if (it != hdlsSymbol.end()) { - send_data(hdlsSymbol["AddMess"], dataId++, - {subResultContainer.getDataVector()}); - std::cout << name << "转发AddMess成功." << std::endl; - } else { - std::cout << "AddMess未连接, 丢弃结果." << std::endl; - } + std::map>> sendMap = { + {"AddMess", {subResultContainer.getDataVector()}}}; + BaseDisModel::sendToClient(sendMap, hdlsSymbol, dataId, rmsServer); } queue.push>("input", subResultContainer); } @@ -284,14 +235,9 @@ class RMSMess { std::cout << "接收到InputMess数据" << std::endl; { std::lock_guard lockMutex(symbolMutex); // 加锁保护符号表 - auto it = hdlsSymbol.find("AddMess"); - if (it != hdlsSymbol.end()) { - send_data(hdlsSymbol["AddMess"], dataId++, - {subResultContainer.getDataVector()}); - std::cout << name << "转发AddMess成功." << std::endl; - } else { - std::cout << "AddMess未连接, 丢弃结果." << std::endl; - } + std::map>> sendMap = { + {"AddMess", {subResultContainer.getDataVector()}}}; + BaseDisModel::sendToClient(sendMap, hdlsSymbol, dataId, rmsServer); } queue.push>("input", subResultContainer); } @@ -304,7 +250,7 @@ class Comp { public: Comp(RMSQueue &queue, const int rmsNum) : queue(queue), rmsNum(rmsNum) {} - void init() {loadAllParameters();} + void init() { loadAllParameters(); } void run() { while (true) { @@ -323,7 +269,7 @@ class Comp { uint32_t index = 0; const int rmsNum; - void loadAllParameters(){ + void loadAllParameters() { constexpr size_t paramSize_group[] = { 131072064, 4096, 33554432, 0, 4096, 67633152, 0, 4096, 33554432, 0, 4096, 67633152, 0, 4096, 33554432, 0, 4096, 67633152,