diff --git a/fserver/csrc/public.hpp b/fserver/csrc/public.hpp index 59d2b92..7299ef9 100644 --- a/fserver/csrc/public.hpp +++ b/fserver/csrc/public.hpp @@ -55,12 +55,11 @@ void RequestHandler(const AFTensorMeta& req_meta, AFTensorServer* server) { std::lock_guard lock(mu_); meta_map_[handler_counter_] = req_meta; - q_[req_meta.sender_rank].emplace_back(handler_counter_, - std::move(tensors), + q_[req_meta.sender_rank].emplace_back(handler_counter_, std::move(tensors), keys); q_signal_.fetch_or(1 << req_meta.sender_rank); + ++handler_counter_; } - ++handler_counter_; } std::vector get_batch() { @@ -183,10 +182,10 @@ void init() { q_signal_.store(0); ps::StartPS(0, role_, group_size_ * node_rank_ + gpu_ + offset, true); if (role_ == Node::WORKER) { - fworker_ = new AFTensorWorker(instance_id_ ); + fworker_ = new AFTensorWorker(instance_id_); barrier(true, true); } else if (role_ == Node::SERVER) { - fserver_ = new AFTensorServer(instance_id_ ); + fserver_ = new AFTensorServer(instance_id_); fserver_->SetRequestHandle(RequestHandler); ps::RegisterExitCallback([]() { delete fserver_; }); barrier(true, true); diff --git a/include/ps/internal/backend.h b/include/ps/internal/backend.h index 1c84aa0..63dfb39 100644 --- a/include/ps/internal/backend.h +++ b/include/ps/internal/backend.h @@ -109,6 +109,9 @@ class Backend { static Backend* backend_impl = nullptr; if (backend_impl == nullptr) { std::unique_lock lock(backends_mutex_); + if (backend_impl) { + return backend_impl; + } std::string backend_type = "GPU"; backend_type = Environment::Get()->find("STEPMESH_BAKCEND", backend_type); PS_CHECK_NE(backends_.find(backend_type), backends_.end())