diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index d0d64a25..4a34c464 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -64,6 +64,7 @@ DEFINE_int32( DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages."); +DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); @@ -187,15 +188,35 @@ void Train(const nn::parallel::Rank &rank) { LOG(FATAL) << "Rank " << rank.GlobalRank() << ": Datatype " << FLAGS_dtype << " not supported."; } - // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions - // before wrapping the model with DistributedDataParallel (DDP). - // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors - // are created during the conversion. - if (ddp_world_size > 1) { + auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size); + + // TODO(dcj): support more complex optimizer later + auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate); + + if (pp_world_size > 1) { + // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct + // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size. 
+ auto shapes = std::vector>{ + {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; + + model = std::make_shared( + model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared(optimizer), + rank.thread_rank(), std::dynamic_pointer_cast(model)->GetChunkSize()); + if (ddp_world_size > 1) { + auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); + for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { + (*mutable_chunks)[chunk_id] + = std::make_shared(mutable_chunks->at(chunk_id), rank.thread_rank()); + } + } + } else if (ddp_world_size > 1) { + // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions + // before wrapping the model with DistributedDataParallel (DDP). + // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors + // are created during the conversion. model = std::make_shared(model, rank.thread_rank()); } - auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size); DistributedDataLoader train_loader(std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), pp_world_size > 1 ? FLAGS_batch_size * num_micro_batches : FLAGS_batch_size, ddp_rank, ddp_world_size); @@ -216,9 +237,6 @@ void Train(const nn::parallel::Rank &rank) { tokenizer = std::make_unique(FLAGS_tokenizer_bin); } - // TODO(dcj): support more complex optimizer later - auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate); - auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast( @@ -227,17 +245,6 @@ void Train(const nn::parallel::Rank &rank) { loss_fn->To(device); LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training"; - if (pp_world_size > 1) { - // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct - // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size. 
- auto shapes = std::vector>{ - {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; - - model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, - pp_rank, std::make_shared(optimizer), - rank.thread_rank()); - } - LOG(INFO) << "start training"; for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { @@ -293,6 +300,7 @@ void Train(const nn::parallel::Rank &rank) { auto logits = model->Forward({x, y})[0]; LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward"; auto loss = loss_fn->Forward({logits, y})[0]; + // FIXME(jym): verify gradient accumulation precision loss = loss / grad_accum_steps; // disable autocast for the current step (backward is not under autocast) @@ -356,7 +364,7 @@ int main(int argc, char *argv[]) { google::InitGoogleLogging(argv[0]); nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, - FLAGS_pipeline_parallel); + FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel); LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); diff --git a/example/gpt2/net.cc b/example/gpt2/net.cc index df7fbead..69e4278e 100644 --- a/example/gpt2/net.cc +++ b/example/gpt2/net.cc @@ -4,9 +4,12 @@ #include #include #include +#include +#include #include #include #include +#include #include "glog/logging.h" @@ -175,116 +178,149 @@ Block::Forward(const std::vector> &x) { return {x2}; } -GPT2::GPT2(const GPT2Config &config) : config_(config) { - int pp_size = nn::parallel::global::GetPipelineParallelSize(); - auto [is_first_stage, is_last_stage, start_layer, end_layer] - = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, nn::parallel::pp_rank); +GPT2FirstStage::GPT2FirstStage(const GPT2Config &config) : config_(config) { + modules_[kWTELayerName] = std::make_shared( + config_.vocab_size, config_.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); + modules_[kWPELayerName] = 
std::make_shared(config_.block_size, config_.n_embd); +} + +std::vector> +GPT2FirstStage::Forward(const std::vector> &input) { + // (B, T) + auto x1 = input[0]; + CHECK_LE(x1->Dims()[1], config_.block_size) + << "Cannot forward sequence of length " << x1->Dims()[1] << ", block size is only " << config_.block_size; + const auto device = x1->GetDevice(); + // (T_local) + // NOTE(zbl): Slice pos sequence when SP is enabled + auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); + auto sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled(); + int tp_rank = 0; + if (tp_world_size > 1) { + auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get( + nn::parallel::GetTensorParallelProcessGroupName(device->rank().GlobalRank())); + tp_rank = tp_group->GetGroupRank(device->rank().GlobalRank()); + } + int64_t t_local = sequence_parallel_enabled ? x1->Dims()[1] / tp_world_size : x1->Dims()[1]; + int64_t start = sequence_parallel_enabled ? tp_rank * t_local : 0; + auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); + + // (B, T) -> Embedding(V_local, C) -> (B, T, C) + auto tok_emb = modules_[kWTELayerName]->Forward({x1})[0]; + + // (T) -> Embedding(T_max, C) -> (T, C) + auto pos_emb = modules_[kWPELayerName]->Forward({pos})[0]; + // (B, T, C) + return {tok_emb + pos_emb}; +} + +GPT2Chunk::GPT2Chunk(const GPT2Config &config, int start_layer, int end_layer) : config_(config) { + std::vector> h; + for (int64_t i = start_layer; i < end_layer; ++i) { + auto layer = std::make_shared(config); + h.push_back(layer); + } + modules_[kHLayerName] = std::make_shared(std::move(h)); +} + +std::vector> +GPT2Chunk::Forward(const std::vector> &x) { + auto x1 = x[0]; + // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) + for (auto &h : *std::dynamic_pointer_cast(modules_[kHLayerName])) { x1 = h->Forward({x1})[0]; } + return {x1}; +} + +GPT2LastStage::GPT2LastStage(const GPT2Config &config) : 
config_(config) { + modules_[kLnFLayerName] = std::make_shared(std::vector{config_.n_embd}); + // don't init this one, we will tie weights + modules_[kLMHeadLayerName] = std::make_shared( + /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, + /*bias=*/false, + // NOTE(zbl): each tp_rank would get sharded [B, T, V_local] as logits + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); +} + +std::vector> +GPT2LastStage::Forward(const std::vector> &x) { + // (B, T, C) -> Layernorm -> (B, T, C) + auto x1 = modules_[kLnFLayerName]->Forward(x); + + // TODO(dcj): add inference-time mini-optimization + // (B, T, C) -> Linear(C, V) -> (B, T, V) + return modules_[kLMHeadLayerName]->Forward(x1); +} + +GPT2::GPT2(const GPT2Config &config) + : config_(config), stage_info_(nn::parallel::PipelineParallel::GetStageInfo( + config_.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank, + nn::parallel::global::GetVirtualPipelineParallelSize())) { auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); // NOTE(zbl): VocabParallelEmbedding requires vocab_size % tp_size == 0 // Megatron-LM has an optional argument `--make-vocab-size-divisible-by`, would do padding to vocab // Here we introduce padding by default, might need modify Tokenizer correspondingly later CHECK_EQ(config.vocab_size % tp_world_size, 0) << "Vocab size should be divisible by TP world size"; - { - std::unordered_map> transformer; - if (is_first_stage) { - transformer[kWTELayerName] = std::make_shared( - config_.vocab_size, config_.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); - transformer[kWPELayerName] = std::make_shared(config_.block_size, config_.n_embd); - } - { - std::vector> h; - for (int64_t i = start_layer; i < end_layer; ++i) { h.push_back(std::make_shared(config_)); } - transformer[kHLayerName] = std::make_shared(std::move(h)); - 
if (is_last_stage) { - transformer[kLnFLayerName] = std::make_shared(std::vector{config_.n_embd}); - } + std::unordered_map> transformer; + if (stage_info_.is_first_stage) { + modules_[kPPFirstStageName] = std::make_shared(config_); + transformer[GPT2FirstStage::kWTELayerName] + = modules_[kPPFirstStageName]->mutable_module(GPT2FirstStage::kWTELayerName); + transformer[GPT2FirstStage::kWPELayerName] + = modules_[kPPFirstStageName]->mutable_module(GPT2FirstStage::kWPELayerName); + } - modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); + { + std::map>> start_layer_to_layer_size_and_chunk; + for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) { + const auto [start_layer, end_layer] = stage_info_.layer_ranges_per_chunk[chunk_idx]; + auto chunk = std::make_shared(config_, start_layer, end_layer); + start_layer_to_layer_size_and_chunk[start_layer] = std::make_pair(end_layer - start_layer, chunk); } - if (is_last_stage) { - // don't init this one, we will tie weights - modules_[kLMHeadLayerName] = std::make_shared( - /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, - /*bias=*/false, - // NOTE(zbl): each tp_rank would get sharded [B, T, V_local] as logits - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + std::vector> h; + int chunk_idx = 0; + for (auto &[start_layer, layer_size_and_chunk] : start_layer_to_layer_size_and_chunk) { + auto [layer_size, chunk] = layer_size_and_chunk; + for (int idx = 0; idx < layer_size; ++idx) { + h.push_back(chunk->mutable_module(GPT2Chunk::kHLayerName)->mutable_module(std::to_string(idx))); + } + modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)] = std::move(chunk); + ++chunk_idx; } + transformer[GPT2Chunk::kHLayerName] = std::make_shared(std::move(h)); + } - // FIXME(jym): Assigning the parameter values of wte to LMHead, which is not real 
tying operation - if (pp_size == 1) { - // https://paperswithcode.com/method/weight-tying - *mutable_module(kTransformerLayerName) - ->mutable_module(kWTELayerName) - ->mutable_parameter(nn::parallel::VocabParallelEmbedding::kParamWeightName) - = module(kLMHeadLayerName).parameter(nn::parallel::ColumnParallelLinear::kParamWeightName); - } + if (stage_info_.is_last_stage) { + modules_[kPPLastStageName] = std::make_shared(config_); + transformer[GPT2LastStage::kLnFLayerName] + = modules_[kPPLastStageName]->mutable_module(GPT2LastStage::kLnFLayerName); + modules_[GPT2LastStage::kLMHeadLayerName] + = modules_[kPPLastStageName]->mutable_module(GPT2LastStage::kLMHeadLayerName); + } + modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); + + // FIXME(jym): Assigning the parameter values of wte to LMHead, which is not real tying operation + if (nn::parallel::global::GetPipelineParallelSize() == 1) { + // https://paperswithcode.com/method/weight-tying + *mutable_module(kTransformerLayerName) + ->mutable_module(GPT2FirstStage::kWTELayerName) + ->mutable_parameter(nn::parallel::VocabParallelEmbedding::kParamWeightName) + = module(GPT2LastStage::kLMHeadLayerName).parameter(nn::parallel::ColumnParallelLinear::kParamWeightName); } } std::vector> GPT2::Forward(const std::vector> &x) { - int pp_rank = nn::parallel::pp_rank; - int pp_size = nn::parallel::global::GetPipelineParallelSize(); - bool is_first_stage = (pp_rank == 0); - bool is_last_stage = (pp_rank == pp_size - 1); - - // (B, T) - auto x1 = x[0]; - const auto device = x1->GetDevice(); - - const auto t - = x1->Dims()[1] * (is_first_stage ? 
1 : nn::parallel::global::GetSequenceParallelSize()); // full_seq_len - CHECK_LE(t, config_.block_size) << "Cannot forward sequence of length " << t << ", block size is only " - << config_.block_size; - // forward the GPT2 model itself - auto &transformer = modules_[kTransformerLayerName]; - if (is_first_stage) { - // (T_local) - // NOTE(zbl): Slice pos sequence when SP is enabled - auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); - auto sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled(); - int tp_rank = 0; - if (tp_world_size > 1) { - auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get( - nn::parallel::GetTensorParallelProcessGroupName(device->rank().GlobalRank())); - tp_rank = tp_group->GetGroupRank(device->rank().GlobalRank()); - } - int64_t t_local = sequence_parallel_enabled ? (t / tp_world_size) : t; - int64_t start = sequence_parallel_enabled ? tp_rank * t_local : 0; - auto pos = nn::init::Arange(start, start + t_local, infini_train::DataType::kINT64, device); - - // (B, T) -> Embedding(V_local, C) -> (B, T, C) - auto tok_emb = transformer->mutable_module(kWTELayerName)->Forward({x1})[0]; - // (T) -> Embedding(T_max, C) -> (T, C) - auto pos_emb = transformer->mutable_module(kWPELayerName)->Forward({pos})[0]; - // (B, T, C) - x1 = tok_emb + pos_emb; - } - - // (B, T, C) -> transformer -> (B, T, C) - auto h_modules = transformer->mutable_module(kHLayerName); - CHECK_EQ(h_modules->type(), nn::ModuleList::kType) << "Failed to get ModuleList from transformer"; - auto h_layers = std::dynamic_pointer_cast(h_modules); - // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) - for (auto &h : *h_layers) { x1 = h->Forward({x1})[0]; } - - if (is_last_stage) { - // (B, T, C) -> Layernorm -> (B, T, C) - auto x3 = transformer->mutable_module(kLnFLayerName)->Forward({x1}); - - // TODO(dcj): add inference-time mini-optimization - // (B, T, C) -> Linear(C, V) -> (B, T, V) - auto logits = 
modules_[kLMHeadLayerName]->Forward(x3); - // (B, T, V_original) - return logits; + auto x1 = modules_[kPPFirstStageName]->Forward(x); + for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) { + x1 = modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)]->Forward(x1); } - return {x1}; + return modules_[kPPLastStageName]->Forward(x1); } std::shared_ptr GPT2::FromPretrained(ModelType model_type) { @@ -351,9 +387,17 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { CHECK_EQ(n_embd % n_head, 0) << "n_embd must be divisible by n_head."; CHECK_EQ(n_head % tp_size, 0) << "n_head must be divisible by TP world size."; + // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ========== int pp_size = nn::parallel::global::GetPipelineParallelSize(); - auto [is_first_stage, is_last_stage, start_layer, end_layer] - = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, nn::parallel::pp_rank); + int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); + auto pp_rank = nn::parallel::pp_rank; + auto [is_first_stage, is_last_stage, layer_ranges_per_chunk] + = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, pp_rank, vpp_size); + // ========== layer to chunk ========== + std::vector owned_layers(n_layer, false); + for (const auto &[start, end] : layer_ranges_per_chunk) { + for (int i = start; i < end; ++i) { owned_layers[i] = true; } + } auto tp_rank = nn::parallel::tp_rank; // calculate xx_size_per_partition @@ -379,29 +423,29 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { // local: (vocab_size_per_partition, n_embd) if (is_first_stage) { auto &transformer_wte_weight - = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kWTELayerName, + = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, GPT2FirstStage::kWTELayerName, nn::parallel::VocabParallelEmbedding::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, 
static_cast(transformer_wte_weight->DataPtr()), model_vocab_size, n_embd, v_start, vpp); } else if (pp_size > 1 && is_last_stage) { - auto &lm_head_weight = state_dict[std::format("{}.{}", GPT2::kLMHeadLayerName, + auto &lm_head_weight = state_dict[std::format("{}.{}", GPT2LastStage::kLMHeadLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(lm_head_weight->DataPtr()), model_vocab_size, n_embd, v_start, vpp); } else { - size_t wte_bytes = vocab_size * n_embd * sizeof(float); + size_t wte_bytes = model_vocab_size * n_embd * sizeof(float); ifs.seekg(wte_bytes, std::ios::cur); } if (tp_size == 1) { // Skip padded vocab part when TP is not enabled - ifs.ignore((padded_vocab_size - vocab_size) * n_embd * sizeof(float)); + ifs.ignore((padded_vocab_size - model_vocab_size) * n_embd * sizeof(float)); } if (is_first_stage) { // transformer.wpe.weight - auto &transformer_wpe_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, - GPT2::kWPELayerName, nn::Embedding::kParamWeightName)]; + auto &transformer_wpe_weight = state_dict[std::format( + "{}.{}.{}", GPT2::kTransformerLayerName, GPT2FirstStage::kWPELayerName, nn::Embedding::kParamWeightName)]; ReadMatrixAllFloat(ifs, static_cast(transformer_wpe_weight->DataPtr()), block_size, n_embd); } else { size_t wpe_bytes = block_size * n_embd * sizeof(float); @@ -409,13 +453,14 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.ln_1.weight + int local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kLn1LayerName, + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, + std::to_string(local_layer_index), 
Block::kLn1LayerName, nn::LayerNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t ln_1_w_bytes = n_embd * sizeof(float); ifs.seekg(ln_1_w_bytes, std::ios::cur); @@ -423,13 +468,14 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.ln_1.bias + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kLn1LayerName, + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, + std::to_string(local_layer_index), Block::kLn1LayerName, nn::LayerNorm::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t ln_1_b_bytes = n_embd * sizeof(float); ifs.seekg(ln_1_b_bytes, std::ios::cur); @@ -437,12 +483,12 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.attn.c_attn.weight (ColumnParallelLinear, but actually applies on "rows") + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kAttnLayerName, - CausalSelfAttention::kCAttnLayerName, + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + GPT2Chunk::kHLayerName, std::to_string(local_layer_index), + Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; // NOTE(zbl): In the .bin model file, Q/K/V is concated along last dim, // i.e. 
[Q|K|V].T = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn].T @@ -471,6 +517,8 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { /*dst=*/dst + (2 * local_C) * cols_all, /*rows=*/rows_all, /*cols=*/cols_all, /*row_start=*/2 * n_embd + tp_rank * local_C, /*row_cnt=*/local_C); + + ++local_layer_index; } else { size_t c_attn_w_bytes = qkv_out * n_embd * sizeof(float); ifs.seekg(c_attn_w_bytes, std::ios::cur); @@ -478,12 +526,12 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.attn.c_attn.bias (ColumnParallelLinear) + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kAttnLayerName, - CausalSelfAttention::kCAttnLayerName, + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + GPT2Chunk::kHLayerName, std::to_string(local_layer_index), + Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)]; // NOTE(zbl): Same as c_attn.weight, the bias for Q/K/V is concated // i.e. 
[Q|K|V] = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn] @@ -511,6 +559,8 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { /*dst=*/dst + (2 * local_C), /*len=*/len_all, /*start=*/2 * n_embd + tp_rank * local_C, /*cnt=*/local_C); + + ++local_layer_index; } else { size_t c_attn_b_bytes = qkv_out * sizeof(float); ifs.seekg(c_attn_b_bytes, std::ios::cur); @@ -518,15 +568,16 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.attn.c_proj.weight (RowParallelLinear, but actually applies on "columns") + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kAttnLayerName, - CausalSelfAttention::kCProjLayerName, + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + GPT2Chunk::kHLayerName, std::to_string(local_layer_index), + Block::kAttnLayerName, CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, n_embd, tp_rank * in_pp, in_pp); + ++local_layer_index; } else { size_t c_proj_w_bytes = n_embd * n_embd * sizeof(float); ifs.seekg(c_proj_w_bytes, std::ios::cur); @@ -534,14 +585,15 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.attn.c_proj.bias (RowParallelLinear, no shard on bias) + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kAttnLayerName, - CausalSelfAttention::kCProjLayerName, + if (owned_layers[idx]) { + auto &tensor = 
state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, + GPT2Chunk::kHLayerName, std::to_string(local_layer_index), + Block::kAttnLayerName, CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t c_proj_b_bytes = n_embd * sizeof(float); ifs.seekg(c_proj_b_bytes, std::ios::cur); @@ -549,13 +601,14 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.ln_2.weight + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kLn2LayerName, + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, + std::to_string(local_layer_index), Block::kLn2LayerName, nn::LayerNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t ln_2_w_bytes = n_embd * sizeof(float); ifs.seekg(ln_2_w_bytes, std::ios::cur); @@ -563,13 +616,14 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.ln_2.bias + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, - std::to_string(idx - start_layer), Block::kLn2LayerName, + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, + std::to_string(local_layer_index), Block::kLn2LayerName, nn::LayerNorm::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + 
++local_layer_index; } else { size_t ln_2_b_bytes = n_embd * sizeof(float); ifs.seekg(ln_2_b_bytes, std::ios::cur); @@ -577,13 +631,15 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.mlp.c_fc.weight (ColumnParallelLinear, but actually applies on "rows") + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), - Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; + if (owned_layers[idx]) { + auto &tensor + = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, + std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCFcLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, n_embd, fc_start, fc_pp); + ++local_layer_index; } else { size_t c_fc_w_bytes = fc_out * n_embd * sizeof(float); ifs.seekg(c_fc_w_bytes, std::ios::cur); @@ -591,13 +647,15 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.mlp.c_fc.bias (ColumnParallelLinear) + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), - Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)]; + if (owned_layers[idx]) { + auto &tensor + = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, + std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCFcLayerName, + nn::parallel::ColumnParallelLinear::kParamBiasName)]; 
ReadVectorShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, fc_start, fc_pp); + ++local_layer_index; } else { size_t c_fc_b_bytes = fc_out * sizeof(float); ifs.seekg(c_fc_b_bytes, std::ios::cur); @@ -605,14 +663,16 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.mlp.c_proj.weight (RowParallelLinear, but actually applies on "columns") + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), - Block::kMlpLayerName, MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; + if (owned_layers[idx]) { + auto &tensor + = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, + std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, fc_out, tp_rank * in4_pp, in4_pp); + ++local_layer_index; } else { size_t c_proj_w_bytes = fc_out * n_embd * sizeof(float); ifs.seekg(c_proj_w_bytes, std::ios::cur); @@ -620,13 +680,15 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.mlp.c_proj.bias (RowParallelLinear, no shard on bias) + local_layer_index = 0; for (int idx = 0; idx < n_layer; ++idx) { - bool owned = (idx >= start_layer && idx < end_layer); - if (owned) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2::kHLayerName, std::to_string(idx - start_layer), - Block::kMlpLayerName, MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)]; + if (owned_layers[idx]) { + auto &tensor + = state_dict[std::format("{}.{}.{}.{}.{}.{}", GPT2::kTransformerLayerName, GPT2Chunk::kHLayerName, + 
std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t c_proj_b_bytes = n_embd * sizeof(float); ifs.seekg(c_proj_b_bytes, std::ios::cur); @@ -635,12 +697,12 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { if (is_last_stage) { // transformer.ln_f.weight - auto &transformer_ln_f_weight = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, - GPT2::kLnFLayerName, nn::LayerNorm::kParamWeightName)]; + auto &transformer_ln_f_weight = state_dict[std::format( + "{}.{}.{}", GPT2::kTransformerLayerName, GPT2LastStage::kLnFLayerName, nn::LayerNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_weight->DataPtr()), n_embd); // transformer.ln_f.bias - auto &transformer_ln_f_bias = state_dict[std::format("{}.{}.{}", GPT2::kTransformerLayerName, - GPT2::kLnFLayerName, nn::LayerNorm::kParamBiasName)]; + auto &transformer_ln_f_bias = state_dict[std::format( + "{}.{}.{}", GPT2::kTransformerLayerName, GPT2LastStage::kLnFLayerName, nn::LayerNorm::kParamBiasName)]; ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_bias->DataPtr()), n_embd); } else { size_t ln_f_w_bytes = n_embd * sizeof(float); @@ -649,3 +711,5 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { } return local_gpt2; } + +int GPT2::GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); } diff --git a/example/gpt2/net.h b/example/gpt2/net.h index dba5cdfe..f52e4d4e 100644 --- a/example/gpt2/net.h +++ b/example/gpt2/net.h @@ -7,8 +7,9 @@ #include "glog/logging.h" -#include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" +#include "infini_train/include/nn/parallel/pp/pipeline_stage.h" #include "infini_train/include/tensor.h" struct GPT2Config { @@ -71,15 
+72,51 @@ class Block : public infini_train::nn::CloneableModule { Forward(const std::vector> &x) override; }; -class GPT2 : public infini_train::nn::CloneableModule { +class GPT2FirstStage : public infini_train::nn::CloneableModule { public: static constexpr char kWTELayerName[] = "wte"; static constexpr char kWPELayerName[] = "wpe"; + + explicit GPT2FirstStage(const GPT2Config &config); + + std::vector> + Forward(const std::vector> &x) override; + +private: + const GPT2Config config_; +}; + +class GPT2Chunk : public infini_train::nn::CloneableModule { +public: static constexpr char kHLayerName[] = "h"; + + GPT2Chunk(const GPT2Config &config, int start_layer, int end_layer); + + std::vector> + Forward(const std::vector> &x) override; + +private: + const GPT2Config config_; +}; + +class GPT2LastStage : public infini_train::nn::CloneableModule { +public: static constexpr char kLnFLayerName[] = "ln_f"; - static constexpr char kTransformerLayerName[] = "transformer"; static constexpr char kLMHeadLayerName[] = "lm_head"; + explicit GPT2LastStage(const GPT2Config &config); + + std::vector> + Forward(const std::vector> &x) override; + +private: + const GPT2Config config_; +}; + +class GPT2 : public infini_train::nn::CloneableModule { +public: + static constexpr char kTransformerLayerName[] = "transformer"; + enum class ModelType : int8_t { kGPT2, kGPT2Medium, @@ -95,6 +132,9 @@ class GPT2 : public infini_train::nn::CloneableModule { static std::shared_ptr FromPretrained(ModelType model_type); static std::shared_ptr FromLLMC(const std::string &filepath); + int GetChunkSize() const; + private: - GPT2Config config_; + const GPT2Config config_; + const infini_train::nn::parallel::StageInfo stage_info_; }; diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 09a6a179..fdea2162 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -62,7 +62,8 @@ DEFINE_int32( "When set > 1, enables data parallelism with device=cuda on the specified number of visible 
CUDA devices."); DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size"); DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel"); -DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, , specified the number of PP stages."); +DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages."); +DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); @@ -167,16 +168,35 @@ void Train(const nn::parallel::Rank &rank) { LOG(FATAL) << "Rank " << rank.GlobalRank() << ": Datatype " << FLAGS_dtype << " not supported."; } - // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions - // before wrapping the model with DistributedDataParallel (DDP). - // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors - // are created during the conversion. - if (ddp_world_size > 1) { + auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size); + + // TODO(dcj): support more complex optimizer later + auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate); + + if (pp_world_size > 1) { + // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct + // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size. 
+ auto shapes = std::vector>{ + {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; + + model = std::make_shared( + model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared(optimizer), + rank.thread_rank(), std::dynamic_pointer_cast(model)->GetChunkSize()); + if (ddp_world_size > 1) { + auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); + for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { + (*mutable_chunks)[chunk_id] + = std::make_shared(mutable_chunks->at(chunk_id), rank.thread_rank()); + } + } + } else if (ddp_world_size > 1) { + // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions + // before wrapping the model with DistributedDataParallel (DDP). + // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors + // are created during the conversion. model = std::make_shared(model, rank.thread_rank()); } - auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size); - DistributedDataLoader train_loader(std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), pp_world_size > 1 ? FLAGS_batch_size * num_micro_batches : FLAGS_batch_size, ddp_rank, ddp_world_size); @@ -196,9 +216,6 @@ void Train(const nn::parallel::Rank &rank) { tokenizer = std::make_unique(FLAGS_tokenizer_bin); } - // TODO(dcj): support more complex optimizer later - auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate); - auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? 
std::static_pointer_cast(std::make_shared()) @@ -206,17 +223,6 @@ void Train(const nn::parallel::Rank &rank) { loss_fn->To(device); LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training"; - if (pp_world_size > 1) { - // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct - // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size. - auto shapes = std::vector>{ - {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; - - model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, - pp_rank, std::make_shared(optimizer), - rank.thread_rank()); - } - for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { const bool last_step = step == FLAGS_num_iteration; @@ -270,6 +276,7 @@ void Train(const nn::parallel::Rank &rank) { auto logits = model->Forward({x, y})[0]; LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward"; auto loss = loss_fn->Forward({logits, y})[0]; + // FIXME(jym): verify gradient accumulation precision loss = loss / grad_accum_steps; // disable autocast for the current step (backward is not under autocast) @@ -333,7 +340,7 @@ int main(int argc, char *argv[]) { google::InitGoogleLogging(argv[0]); nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel, - FLAGS_pipeline_parallel); + FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel); LOG(INFO) << nn::parallel::global::ProcessGroupOverview(); diff --git a/example/llama3/net.cc b/example/llama3/net.cc index ad7ad4c7..a70a811a 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -324,50 +325,27 @@ std::vector> Block::Forward(const std::vector> transformer; - if (is_first_stage) { - transformer[kWTELayerName] = std::make_shared( - config.vocab_size, config.n_embd, 
nn::parallel::global::GetSequenceParallelEnabled()); - } +LLaMA3FirstStage::LLaMA3FirstStage(const LLaMA3Config &config) : config_(config) { + modules_[LLaMA3FirstStage::kWTELayerName] = std::make_shared( + config.vocab_size, config.n_embd, nn::parallel::global::GetSequenceParallelEnabled()); +} - std::vector> h_local; - for (int64_t i = start_layer; i < end_layer; ++i) { h_local.push_back(std::make_shared(config)); } - transformer[kHLayerName] = std::make_shared(std::move(h_local)); +std::vector> LLaMA3FirstStage::Forward(const std::vector> &x) { + return modules_[LLaMA3FirstStage::kWTELayerName]->Forward(x); +} - if (is_last_stage) { - transformer[kLnFLayerName] = std::make_shared(config.n_embd, config.norm_eps); - // NOTE(zbl): weight-tying is possible but torch script did not do so - modules_[kLMHeadLayerName] = std::make_shared( - /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, - /*bias=*/false, - // NOTE(zbl): each rank would get sharded [B, T, V_local] as logits - /*gather_output=*/false, - /*input_is_parallel=*/false, - /*skip_bias_add=*/false, - /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); +LLaMA3Chunk::LLaMA3Chunk(const LLaMA3Config &config, int start_layer, int end_layer) : config_(config) { + std::vector> h; + for (int64_t i = start_layer; i < end_layer; ++i) { + auto layer = std::make_shared(config); + h.push_back(layer); } - modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); + modules_[LLaMA3Chunk::kHLayerName] = std::make_shared(std::move(h)); } -std::vector> LLaMA3::Forward(const std::vector> &x) { - int pp_rank = nn::parallel::pp_rank; - int pp_size = nn::parallel::global::GetPipelineParallelSize(); - bool is_first_stage = (pp_rank == 0); - bool is_last_stage = (pp_rank == pp_size - 1); - - // (bs, seq_len) +std::vector> LLaMA3Chunk::Forward(const std::vector> &x) { auto x1 = x[0]; const auto device = x1->GetDevice(); - const auto t - = x1->Dims()[1] * (is_first_stage ? 
1 : nn::parallel::global::GetSequenceParallelSize()); // full_seq_len - CHECK_LE(t, config_.block_size) << "Cannot forward sequence of length " << t << ", block size is only " - << config_.block_size; - // Init freqs_cis on device only once // TODO(zbl): consider moving this to model construction if (buffers_[kFreqsCisName] == nullptr) { @@ -375,13 +353,8 @@ std::vector> LLaMA3::Forward(const std::vector Embedding(vocab_size, n_embd) -> (bs, seq_len, n_embd) - x1 = transformer->mutable_module(kWTELayerName)->Forward({x1})[0]; - } + // TODO(dcj): check if this shape is correct + const auto t = x1->Dims()[1] * nn::parallel::global::GetSequenceParallelSize(); // full_seq_len // TODO(zbl): dynamic start_pos int64_t start_pos = 0; @@ -393,25 +366,84 @@ std::vector> LLaMA3::Forward(const std::vector start_pos_ptr = nullptr; - auto h_modules = transformer->mutable_module(kHLayerName); - CHECK_EQ(h_modules->type(), nn::ModuleList::kType) << "Failed to get ModuleList from transformer"; - auto h_layers = std::dynamic_pointer_cast(h_modules); // (bs, seq_len, n_embd) -> transformer -> (bs, seq_len, n_embd) - for (auto &h : *h_layers) { x1 = h->Forward({x1, freqs_view, start_pos_ptr, mask})[0]; } + for (auto &h : *std::dynamic_pointer_cast(modules_[LLaMA3Chunk::kHLayerName])) { + x1 = h->Forward({x1, freqs_view, start_pos_ptr, mask})[0]; + } + return {x1}; +} + +LLaMA3LastStage::LLaMA3LastStage(const LLaMA3Config &config) : config_(config) { + modules_[kLnFLayerName] = std::make_shared(config.n_embd, config.norm_eps); + // NOTE(zbl): weight-tying is possible but torch script did not do so + modules_[kLMHeadLayerName] = std::make_shared( + /*in_features=*/config_.n_embd, /*out_features=*/config_.vocab_size, + /*bias=*/false, + // NOTE(zbl): each rank would get sharded [B, T, V_local] as logits + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); +} - if (is_last_stage) { 
- // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) - auto x2 = transformer->mutable_module(kLnFLayerName)->Forward({x1}); +std::vector> LLaMA3LastStage::Forward(const std::vector> &x) { + // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd) + auto x1 = modules_[kLnFLayerName]->Forward(x); - // TODO(zbl): add inference-time mini-optimization - // (bs, seq_len, n_embd) -> Linear(n_embd, vocab_size) -> (bs, seq_len, vocab_size) - auto logits = modules_[kLMHeadLayerName]->Forward(x2); + // TODO(zbl): add inference-time mini-optimization + // (bs, seq_len, n_embd) -> Linear(n_embd, vocab_size) -> (bs, seq_len, vocab_size) + return modules_[kLMHeadLayerName]->Forward(x1); +} - // (bs, seq_len, vocab_size) - return logits; +LLaMA3::LLaMA3(const LLaMA3Config &config) + : config_(config), stage_info_(nn::parallel::PipelineParallel::GetStageInfo( + config_.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank, + nn::parallel::global::GetVirtualPipelineParallelSize())) { + std::unordered_map> transformer; + if (stage_info_.is_first_stage) { + modules_[kPPFirstStageName] = std::make_shared(config_); + transformer[LLaMA3FirstStage::LLaMA3FirstStage::kWTELayerName] + = modules_[kPPFirstStageName]->mutable_module(LLaMA3FirstStage::LLaMA3FirstStage::kWTELayerName); } - return {x1}; + { + std::map>> start_layer_to_layer_size_and_chunk; + for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) { + const auto [start_layer, end_layer] = stage_info_.layer_ranges_per_chunk[chunk_idx]; + auto chunk = std::make_shared(config_, start_layer, end_layer); + start_layer_to_layer_size_and_chunk[start_layer] = std::make_pair(end_layer - start_layer, chunk); + } + std::vector> h; + int chunk_idx = 0; + for (auto &[start_layer, layer_size_and_chunk] : start_layer_to_layer_size_and_chunk) { + auto [layer_size, chunk] = layer_size_and_chunk; + for (int idx = 0; idx < layer_size; ++idx) { + h.push_back( + 
chunk->mutable_module(LLaMA3Chunk::LLaMA3Chunk::kHLayerName)->mutable_module(std::to_string(idx))); + } + modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)] = std::move(chunk); + ++chunk_idx; + } + transformer[LLaMA3Chunk::LLaMA3Chunk::kHLayerName] = std::make_shared(std::move(h)); + } + + if (stage_info_.is_last_stage) { + modules_[kPPLastStageName] = std::make_shared(config_); + transformer[LLaMA3LastStage::kLnFLayerName] + = modules_[kPPLastStageName]->mutable_module(LLaMA3LastStage::kLnFLayerName); + // NOTE(zbl): weight-tying is possible but torch script did not do so + modules_[LLaMA3LastStage::kLMHeadLayerName] + = modules_[kPPLastStageName]->mutable_module(LLaMA3LastStage::kLMHeadLayerName); + } + modules_[kTransformerLayerName] = std::make_shared(std::move(transformer)); +} + +std::vector> LLaMA3::Forward(const std::vector> &x) { + auto x1 = modules_[kPPFirstStageName]->Forward({x[0]}); + for (int chunk_idx = 0; chunk_idx < stage_info_.layer_ranges_per_chunk.size(); ++chunk_idx) { + x1 = modules_[kPPChunkNamePrefix + std::to_string(chunk_idx)]->Forward(x1); + } + return modules_[kPPLastStageName]->Forward(x1); } std::shared_ptr LLaMA3::FromPretrained(ModelType model_type) { @@ -465,9 +497,18 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { .use_scaled_rope = static_cast(use_scaled_rope), .norm_eps = norm_eps, .max_gen_batch_size = max_gen_bs}); + + // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ========== int pp_size = nn::parallel::global::GetPipelineParallelSize(); - auto [is_first_stage, is_last_stage, start_layer, end_layer] - = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, nn::parallel::pp_rank); + int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); + auto pp_rank = nn::parallel::pp_rank; + auto [is_first_stage, is_last_stage, layer_ranges_per_chunk] + = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, pp_rank, vpp_size); + // ========== layer to chunk 
========== + std::vector owned_layers(n_layer, false); + for (const auto &[start, end] : layer_ranges_per_chunk) { + for (int i = start; i < end; ++i) { owned_layers[i] = true; } + } const int tp_size = nn::parallel::global::GetTensorParallelSize(); const int tp_rank = nn::parallel::tp_rank; @@ -493,6 +534,12 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { LOG(INFO) << " max_gen_bs = " << max_gen_bs; LOG(INFO) << " version_major = " << version_major; LOG(INFO) << " version_minor = " << version_minor; + + LOG(INFO) << "Pipeline Parallel Chunks:"; + for (size_t i = 0; i < layer_ranges_per_chunk.size(); ++i) { + LOG(INFO) << " Chunk " << i << ": layers " << layer_ranges_per_chunk[i].first << " to " + << layer_ranges_per_chunk[i].second; + } } const int64_t head_dim = static_cast(n_embd) / static_cast(n_head); @@ -537,7 +584,7 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { // ========== Read Sharded Params ========== // transformer.wte.weight : (vocab_size, n_embd) -> local tp_rank: rows of [v_start : v_start+vpp) if (is_first_stage) { - auto &wte = state_dict[std::format("{}.{}.{}", kTransformerLayerName, kWTELayerName, + auto &wte = state_dict[std::format("{}.{}.{}", kTransformerLayerName, LLaMA3FirstStage::kWTELayerName, nn::parallel::VocabParallelEmbedding::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(wte->DataPtr()), /*rows=*/vocab_size, /*cols=*/n_embd, @@ -548,14 +595,14 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.ln_1.weight : Full version RMSNorm + int local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - bool owned = (i >= start_layer && i < end_layer); - - if (owned) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, - std::to_string(i - start_layer), Block::kLn1LayerName, + if (owned_layers[i]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, 
LLaMA3Chunk::kHLayerName, + std::to_string(local_layer_index), Block::kLn1LayerName, RMSNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t ln_1_bytes = n_embd * sizeof(float); ifs.seekg(ln_1_bytes, std::ios::cur); @@ -564,11 +611,11 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { // transformer.h.{i}.attn.c_attn.weight : ColumnParallelLinear, but actually applies on "rows" // W-qkv should be [Q(=n_embd) | K(=n_kv_head*head_dim) | V(=n_kv_head*head_dim)] × n_embd + local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - bool owned = (i >= start_layer && i < end_layer); - if (owned) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, - std::to_string(i - start_layer), Block::kAttnLayerName, + if (owned_layers[i]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, + std::to_string(local_layer_index), Block::kAttnLayerName, CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; @@ -596,6 +643,7 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { /*rows=*/attn_rows_all, /*cols=*/attn_cols, /*row_start=*/q_out_rows + kv_out_rows + tp_rank * kv_local_rows, /*row_cnt=*/kv_local_rows); + ++local_layer_index; } else { size_t qkv_bytes = static_cast(attn_rows_all) * attn_cols * sizeof(float); ifs.seekg(qkv_bytes, std::ios::cur); @@ -603,16 +651,17 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.attn.c_proj.weight : RowParallelLinear, but actually applies on "columns" + local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - bool owned = (i >= start_layer && i < end_layer); - if (owned) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, - std::to_string(i - start_layer), Block::kAttnLayerName, + if 
(owned_layers[i]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, + std::to_string(local_layer_index), Block::kAttnLayerName, CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/n_embd, /*cols=*/n_embd, /*col_start=*/tp_rank * in_pp, /*col_cnt=*/in_pp); + ++local_layer_index; } else { size_t c_proj_bytes = static_cast(n_embd) * n_embd * sizeof(float); ifs.seekg(c_proj_bytes, std::ios::cur); @@ -620,13 +669,14 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.ln_2.weight : Full version RMSNorm + local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - bool owned = (i >= start_layer && i < end_layer); - if (owned) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, - std::to_string(i - start_layer), Block::kLn2LayerName, + if (owned_layers[i]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, + std::to_string(local_layer_index), Block::kLn2LayerName, RMSNorm::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; } else { size_t ln_2_bytes = static_cast(n_embd) * sizeof(float); ifs.seekg(ln_2_bytes, std::ios::cur); @@ -634,15 +684,16 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.mlp.c_fc.weight : ColumnParallelLinear, but actually applies on "rows" + local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - bool owned = (i >= start_layer && i < end_layer); - if (owned) { + if (owned_layers[i]) { auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i - start_layer), + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, std::to_string(local_layer_index), 
Block::kMlpLayerName, MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/fc_out, /*cols=*/n_embd, /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); + ++local_layer_index; } else { size_t fc_bytes = static_cast(ffn_hidden) * n_embd * sizeof(float); ifs.seekg(fc_bytes, std::ios::cur); @@ -650,15 +701,16 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.mlp.c_fc2.weight : ColumnParallelLinear, but actually applies on "rows" + local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - bool owned = (i >= start_layer && i < end_layer); - if (owned) { + if (owned_layers[i]) { auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i - start_layer), + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCFc2LayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/fc_out, /*cols=*/n_embd, /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); + ++local_layer_index; } else { size_t fc2_bytes = static_cast(ffn_hidden) * n_embd * sizeof(float); ifs.seekg(fc2_bytes, std::ios::cur); @@ -666,15 +718,16 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { } // transformer.h.{i}.mlp.c_proj.weight : RowParallelLinear, but actually applies on "columns" + local_layer_index = 0; for (int i = 0; i < static_cast(n_layer); ++i) { - bool owned = (i >= start_layer && i < end_layer); - if (owned) { + if (owned_layers[i]) { auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", kTransformerLayerName, kHLayerName, std::to_string(i - start_layer), + "{}.{}.{}.{}.{}.{}", kTransformerLayerName, LLaMA3Chunk::kHLayerName, std::to_string(local_layer_index), Block::kMlpLayerName, MLP::kCProjLayerName, 
nn::parallel::RowParallelLinear::kParamWeightName)]; ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), /*rows=*/n_embd, /*cols=*/fc_out, /*col_start=*/tp_rank * in_fc_pp, /*col_cnt=*/in_fc_pp); + ++local_layer_index; } else { size_t c_proj_bytes = static_cast(n_embd) * ffn_hidden * sizeof(float); ifs.seekg(c_proj_bytes, std::ios::cur); @@ -685,9 +738,9 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { // lm_head.weight : (vocab_size, n_embd) -> ColumnParallelLinear, but actually applies on "rows" { if (is_last_stage) { - auto &ln_f - = state_dict[std::format("{}.{}.{}", kTransformerLayerName, kLnFLayerName, RMSNorm::kParamWeightName)]; - auto &lm_head = state_dict[std::format("{}.{}", kLMHeadLayerName, + auto &ln_f = state_dict[std::format("{}.{}.{}", kTransformerLayerName, LLaMA3LastStage::kLnFLayerName, + RMSNorm::kParamWeightName)]; + auto &lm_head = state_dict[std::format("{}.{}", LLaMA3LastStage::kLMHeadLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; ReadVectorAllFloat(ifs, static_cast(ln_f->DataPtr()), n_embd); ReadMatrixRowShardFloat(ifs, static_cast(lm_head->DataPtr()), diff --git a/example/llama3/net.h b/example/llama3/net.h index ec0199aa..9bd7f9da 100644 --- a/example/llama3/net.h +++ b/example/llama3/net.h @@ -10,6 +10,7 @@ #include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" #include "infini_train/include/tensor.h" struct LLaMA3Config { @@ -108,15 +109,50 @@ class Block : public infini_train::nn::CloneableModule { Forward(const std::vector> &x) override; }; -class LLaMA3 : public infini_train::nn::CloneableModule { +class LLaMA3FirstStage : public infini_train::nn::CloneableModule { public: static constexpr char kWTELayerName[] = "wte"; + + explicit LLaMA3FirstStage(const LLaMA3Config &config); + + std::vector> + Forward(const std::vector> &x) override; + +private: + const LLaMA3Config config_; +}; + 
+class LLaMA3Chunk : public infini_train::nn::CloneableModule { +public: static constexpr char kHLayerName[] = "h"; + static constexpr char kFreqsCisName[] = "freqs_cis"; + + LLaMA3Chunk(const LLaMA3Config &config, int start_layer, int end_layer); + + std::vector> + Forward(const std::vector> &x) override; + +private: + const LLaMA3Config config_; +}; + +class LLaMA3LastStage : public infini_train::nn::CloneableModule { +public: static constexpr char kLnFLayerName[] = "ln_f"; - static constexpr char kTransformerLayerName[] = "transformer"; static constexpr char kLMHeadLayerName[] = "lm_head"; - static constexpr char kFreqsCisName[] = "freqs_cis"; + explicit LLaMA3LastStage(const LLaMA3Config &config); + + std::vector> + Forward(const std::vector> &x) override; + +private: + const LLaMA3Config config_; +}; + +class LLaMA3 : public infini_train::nn::CloneableModule { +public: + static constexpr char kTransformerLayerName[] = "transformer"; enum class ModelType : int8_t { // TODO(zbl): more model type from huggingface @@ -135,6 +171,9 @@ class LLaMA3 : public infini_train::nn::CloneableModule { static std::shared_ptr FromPretrained(ModelType model_type); static std::shared_ptr FromLLMC(const std::string &filepath); + int GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); } + private: - LLaMA3Config config_; + const LLaMA3Config config_; + const infini_train::nn::parallel::StageInfo stage_info_; }; diff --git a/infini_train/include/device.h b/infini_train/include/device.h index feeaa081..36357a09 100644 --- a/infini_train/include/device.h +++ b/infini_train/include/device.h @@ -25,6 +25,8 @@ class DeviceManager; class Device { public: + virtual ~Device() = default; + DeviceType Type() const; int8_t Index() const; @@ -57,7 +59,7 @@ class CpuDevice : public Device { #ifdef USE_CUDA class CudaDevice : public Device { public: - ~CudaDevice(); + ~CudaDevice() override; void SetDevice() const override; void Synchronize() const override; diff --git 
a/infini_train/include/nn/modules/container.h b/infini_train/include/nn/modules/container.h index d1356712..4da8b3e6 100644 --- a/infini_train/include/nn/modules/container.h +++ b/infini_train/include/nn/modules/container.h @@ -35,10 +35,12 @@ class ModuleList : public CloneableModule { std::vector> Forward(const std::vector> &input_tensors) override; - auto begin() { return module_list_.begin(); } - auto end() { return module_list_.end(); } - auto begin() const { return module_list_.begin(); } - auto end() const { return module_list_.end(); } + std::vector>::iterator begin(); + std::vector>::iterator end(); + std::vector>::const_iterator begin() const; + std::vector>::const_iterator end() const; + + std::shared_ptr &operator[](std::size_t idx); private: std::vector> module_list_; diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index 08c736e0..9bc78bcc 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -25,6 +25,10 @@ class Module : public std::enable_shared_from_this { public: static constexpr char kUndefinedType[] = "Undefined"; + static constexpr char kPPFirstStageName[] = "__pp_first_stage"; + static constexpr char kPPLastStageName[] = "__pp_last_stage"; + static constexpr char kPPChunkNamePrefix[] = "__pp_chunk_"; + explicit Module(); explicit Module(const std::string &type); Module(const Module &) = default; diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h index f177b9dc..480c1286 100644 --- a/infini_train/include/nn/parallel/global.h +++ b/infini_train/include/nn/parallel/global.h @@ -27,7 +27,7 @@ class GlobalEnv { static GlobalEnv &Instance(); void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, - int pipeline_parallel_size); + int pipeline_parallel_size, int virtual_pipeline_parallel_size); int nnodes() const; @@ -51,6 +51,8 @@ class GlobalEnv { int 
pipeline_parallel_size() const; + int virtual_pipeline_parallel_size() const; + Layout layout() const; private: @@ -75,6 +77,7 @@ class GlobalEnv { int data_parallel_size_ = 1; int pipeline_parallel_size_ = 1; + int virtual_pipeline_parallel_size_ = 1; mutable std::mutex mutex_; bool initialized_ = false; @@ -83,9 +86,9 @@ class GlobalEnv { }; inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, - int pipeline_parallel_size) { + int pipeline_parallel_size, int virtual_pipeline_parallel) { GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled, - pipeline_parallel_size); + pipeline_parallel_size, virtual_pipeline_parallel); } inline int GetNnodes() { return GlobalEnv::Instance().nnodes(); } @@ -100,6 +103,7 @@ inline int GetSequenceParallelSize() { return GlobalEnv::Instance().sequence_par inline bool GetSequenceParallelEnabled() { return GlobalEnv::Instance().sequence_parallel_enabled(); } inline int GetDataParallelSize() { return GlobalEnv::Instance().data_parallel_size(); } inline int GetPipelineParallelSize() { return GlobalEnv::Instance().pipeline_parallel_size(); } +inline int GetVirtualPipelineParallelSize() { return GlobalEnv::Instance().virtual_pipeline_parallel_size(); } // ========================= // Layout Helper Functions diff --git a/infini_train/include/nn/parallel/pp/pipeline_parallel.h b/infini_train/include/nn/parallel/pp/pipeline_parallel.h index c58f1da9..2ddd0918 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_parallel.h +++ b/infini_train/include/nn/parallel/pp/pipeline_parallel.h @@ -18,28 +18,39 @@ class PipelineSchedule; extern thread_local int pp_rank; +struct StageInfo { + bool is_first_stage; + bool is_last_stage; + + // Layer index ranges for chunks assigned to this pipeline stage. 
+ // Each element is a pair: (inclusive_start_layer, exclusive_end_layer) + std::vector> layer_ranges_per_chunk; +}; + class PipelineParallel : public Module { public: PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches, const std::vector> &recv_shape, int rank, - const std::shared_ptr &optimizer, int device_id); + const std::shared_ptr &optimizer, int device_id, int vpp); float TrainStep(const std::vector> &input, const std::vector> &target, const std::shared_ptr &loss_fn, DataType dtype); - static std::tuple GetStageInfo(int total_layers, int pp_size, int pp_rank); + static StageInfo GetStageInfo(int total_layers, int pp_size, int pp_rank, int chunks_per_stage = 1); + + std::vector> *mutable_chunks(); private: + void BuildPipelineStage(const std::shared_ptr &optimizer, + const std::vector> &recv_shape, int device_id, + std::vector> &&chunks); + + void SetupSchedule(int num_micro_batches); + int num_stages_ = -1; int rank_ = -1; - std::shared_ptr pipeline_stage_ = nullptr; std::shared_ptr schedule_ = nullptr; - - void BuildPipelineStage(const std::shared_ptr &model, const std::shared_ptr &optimizer, - const std::vector> &recv_shape, int device_id); - - void SetupSchedule(int num_micro_batches); + std::shared_ptr pipeline_stage_ = nullptr; }; - } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/pp/pipeline_schedule.h b/infini_train/include/nn/parallel/pp/pipeline_schedule.h index 4174716a..3a8a9f75 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_schedule.h +++ b/infini_train/include/nn/parallel/pp/pipeline_schedule.h @@ -28,35 +28,34 @@ class PipelineSchedule { virtual float StepMicroBatches(const std::vector> &arg_mbs, const std::vector> &target_mbs, - const std::shared_ptr &loss_fn, DataType dtype) - = 0; + const std::shared_ptr &loss_fn, DataType dtype); - std::vector> ReceiveFromPrev(); - std::vector> SendToNext(const std::vector> &tensors); + std::vector> ReceiveFromPrev(int 
peer_rank); + std::vector> SendToNext(const std::vector> &tensors, int peer_rank); protected: int num_micro_batches_ = -1; std::shared_ptr stage_ = nullptr; }; -class ScheduleGPipe : public PipelineSchedule { +class PipelineParallelScheduler { public: - ScheduleGPipe(std::shared_ptr stage, int num_stages, int num_micro_batches) - : PipelineSchedule(std::move(stage), num_stages, num_micro_batches){}; - - float StepMicroBatches(const std::vector> &arg_mbs, - const std::vector> &target_mbs, - const std::shared_ptr &loss_fn, DataType dtype) override; -}; - -class Schedule1F1B : public PipelineSchedule { -public: - Schedule1F1B(std::shared_ptr stage, int num_stages, int num_micro_batches) - : PipelineSchedule(std::move(stage), num_stages, num_micro_batches){}; - - float StepMicroBatches(const std::vector> &arg_mbs, - const std::vector> &target_mbs, - const std::shared_ptr &loss_fn, DataType dtype) override; + struct Task { + int step; + int microbatch_id; + int global_chunk_id; + int local_chunk_idx; + bool is_forward; + int stage_id; + bool is_first_chunk; + bool is_last_chunk; + }; + + static Task CreateTask(int step, int mb, int global_chunk, int num_stages, int total_chunks, bool is_forward); + + static std::vector GenerateGPipeSchedule(int n, int num_stages, int vpp_size); + + static std::vector GenerateInterleaved1F1BSchedule(int n, int num_stages, int vpp_size); }; } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/pp/pipeline_stage.h b/infini_train/include/nn/parallel/pp/pipeline_stage.h index b7679d8c..52a776ea 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_stage.h +++ b/infini_train/include/nn/parallel/pp/pipeline_stage.h @@ -16,21 +16,25 @@ namespace infini_train::nn::parallel { class PipelineStage { public: - PipelineStage(const std::shared_ptr &model, int stage_index, int num_stages, - const std::vector> &recv_shape, std::shared_ptr optimizer, - int device_id); - - std::vector> ForwardOneChunk(const 
std::vector> &inputs); - - bool IsFirstStage() const { return stage_index_ == 0; } - bool IsLastStage() const { return stage_index_ == num_stages_ - 1; } - int stage_index() const { return stage_index_; } - int prev_rank() const { return prev_rank_; } - int next_rank() const { return next_rank_; } - int num_stages() const { return num_stages_; } - const Device *device() const { return device_; } - const std::vector> &recv_shape() const { return recv_shape_; } - std::shared_ptr optimizer() { return optimizer_; } + PipelineStage(int stage_index, int num_stages, const std::vector> &recv_shape, + std::shared_ptr optimizer, int device_id, std::vector> &&chunks); + + std::vector> ForwardOneChunk(const std::vector> &inputs, + int local_chunk_idx = 0); + + bool IsFirstStage() const; + bool IsLastStage() const; + + int stage_index() const; + int prev_rank() const; + int next_rank() const; + int num_stages() const; + + const Device *device() const; + const std::vector> &recv_shape() const; + std::shared_ptr optimizer(); + const std::vector> &chunks(); + std::vector> *mutable_chunks(); private: int stage_index_ = -1; @@ -38,7 +42,7 @@ class PipelineStage { int prev_rank_ = -1; int next_rank_ = -1; const Device *device_ = nullptr; - std::shared_ptr model_ = nullptr; + std::vector> chunks_; std::shared_ptr optimizer_ = nullptr; std::vector> recv_shape_; }; diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index bf852a37..d739f67d 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -30,6 +30,8 @@ namespace infini_train::nn::parallel { class ProcessGroup { public: + virtual ~ProcessGroup() = default; + virtual int GetGroupRank(int global_rank) const; // Asynchronous communication APIs (Compute / Communication stream decoupled) @@ -90,7 +92,7 @@ class ProcessGroupNCCL final : public ProcessGroup { public: explicit ProcessGroupNCCL(const std::string 
&process_group_name, const std::vector &device_indices); - ~ProcessGroupNCCL(); + ~ProcessGroupNCCL() override; // Asynchronous communication APIs (Compute / Communication stream decoupled) std::shared_ptr AllReduce(const std::shared_ptr &tensor, function::ReduceOpType reduce_op, diff --git a/infini_train/src/nn/modules/container.cc b/infini_train/src/nn/modules/container.cc index 396306dc..6df46663 100644 --- a/infini_train/src/nn/modules/container.cc +++ b/infini_train/src/nn/modules/container.cc @@ -41,4 +41,12 @@ ModuleList::ModuleList(std::vector> &&layers) std::vector> ModuleList::Forward(const std::vector> &input_tensors) { LOG(FATAL) << "Not implemented"; } + +std::vector>::iterator ModuleList::begin() { return module_list_.begin(); } +std::vector>::iterator ModuleList::end() { return module_list_.end(); } +std::vector>::const_iterator ModuleList::begin() const { return module_list_.begin(); } +std::vector>::const_iterator ModuleList::end() const { return module_list_.end(); } + +std::shared_ptr &ModuleList::operator[](std::size_t idx) { return module_list_.at(idx); } + } // namespace infini_train::nn diff --git a/infini_train/src/nn/modules/module.cc b/infini_train/src/nn/modules/module.cc index 3f757fa2..4e0c6a28 100644 --- a/infini_train/src/nn/modules/module.cc +++ b/infini_train/src/nn/modules/module.cc @@ -20,10 +20,22 @@ const std::string &Module::type() const { return type_; } std::vector> Module::Parameters() const { std::vector> params; - for (auto &[_, param] : parameters_) { params.push_back(param); } - for (auto &[_, module] : modules_) { - for (auto &param : module->Parameters()) { params.push_back(param); } + std::unordered_set visited; + + auto AddIfUnvisited = [&](const std::shared_ptr &param) { + if (visited.insert(param.get()).second) { + params.push_back(param); + } + }; + + // Add parameters of this module + for (const auto &[_, param] : parameters_) { AddIfUnvisited(param); } + + // Recursively add parameters of submodules + for (const auto
&[_, module] : modules_) { + for (const auto &param : module->Parameters()) { AddIfUnvisited(param); } } + return params; } @@ -100,6 +112,9 @@ std::unordered_map> Module::StateDict() con for (auto &[name, param] : parameters_) { state.emplace(name, param); } for (auto &[name, buffer] : buffers_) { state.emplace(name, buffer); } for (auto &[name, module] : modules_) { + if (name.starts_with("__pp")) { + continue; + } for (auto &[sub_name, param] : module->StateDict()) { state.emplace(name + "." + sub_name, param); } } return state; diff --git a/infini_train/src/nn/parallel/distributed_data_parallel.cc b/infini_train/src/nn/parallel/distributed_data_parallel.cc index 9b8704a0..a25a7d16 100644 --- a/infini_train/src/nn/parallel/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/distributed_data_parallel.cc @@ -56,5 +56,4 @@ DistributedDataParallel::Forward(const std::vector> &inp } return outputs; } - } // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc index 4c00da44..39cd95dd 100644 --- a/infini_train/src/nn/parallel/global.cc +++ b/infini_train/src/nn/parallel/global.cc @@ -90,7 +90,7 @@ GlobalEnv &GlobalEnv::Instance() { } void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, - int pipeline_parallel_size) { + int pipeline_parallel_size, int virtual_pipeline_parallel_size) { std::lock_guard lock(mutex_); CHECK(!initialized_) << "Repeated initialization of GlobalEnv!"; @@ -106,6 +106,7 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq tensor_parallel_size_ = tensor_parallel_size; sequence_parallel_enabled_ = sequence_parallel_enabled; pipeline_parallel_size_ = pipeline_parallel_size; + virtual_pipeline_parallel_size_ = virtual_pipeline_parallel_size; data_parallel_size_ = world_size_ / tensor_parallel_size_ / pipeline_parallel_size_; layout_.sizes[DP] = data_parallel_size_; @@ -171,6 +172,11
@@ int GlobalEnv::pipeline_parallel_size() const { return pipeline_parallel_size_; } +int GlobalEnv::virtual_pipeline_parallel_size() const { + CHECK(initialized_) << "GlobalEnv is not initialized!"; + return virtual_pipeline_parallel_size_; +} + Layout GlobalEnv::layout() const { CHECK(initialized_) << "GlobalEnv is not initialized!"; return layout_; diff --git a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc index 6da3b1eb..bbdceaea 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc @@ -3,6 +3,7 @@ #include #include +#include #include "infini_train/include/nn/modules/container.h" #include "infini_train/include/nn/modules/module.h" @@ -17,14 +18,15 @@ constexpr char kModuleName[] = "module"; thread_local int pp_rank = 0; -void PipelineParallel::BuildPipelineStage(const std::shared_ptr &module, - const std::shared_ptr &optimizer, - const std::vector> &recv_shape, int device_id) { - pipeline_stage_ = std::make_shared(module, rank_, num_stages_, recv_shape, optimizer, device_id); +void PipelineParallel::BuildPipelineStage(const std::shared_ptr &optimizer, + const std::vector> &recv_shape, int device_id, + std::vector> &&chunks) { + pipeline_stage_ + = std::make_shared(rank_, num_stages_, recv_shape, optimizer, device_id, std::move(chunks)); } void PipelineParallel::SetupSchedule(int num_micro_batches) { - schedule_ = std::make_shared(pipeline_stage_, num_stages_, num_micro_batches); + schedule_ = std::make_shared(pipeline_stage_, num_stages_, num_micro_batches); } float PipelineParallel::TrainStep(const std::vector> &input, @@ -39,33 +41,69 @@ float PipelineParallel::TrainStep(const std::vector> &in return schedule_->Step(stage_input, stage_target, loss_fn, dtype); } -std::tuple PipelineParallel::GetStageInfo(int total_layers, int pp_size, int pp_rank) { - bool is_first_stage = (pp_rank == 0); - bool is_last_stage = (pp_rank == 
pp_size - 1); - - int layers_per_stage = total_layers / pp_size; - int remainder = total_layers % pp_size; - int start_layer, end_layer; - if (pp_rank < remainder) { - start_layer = pp_rank * (layers_per_stage + 1); - end_layer = start_layer + layers_per_stage + 1; - } else { - start_layer = pp_rank * layers_per_stage + remainder; - end_layer = start_layer + layers_per_stage; +StageInfo PipelineParallel::GetStageInfo(int total_layers, int pp_size, int rank, int chunks_per_stage) { + bool is_first_stage = (rank == 0); + bool is_last_stage = (rank == pp_size - 1); + + std::vector> layer_ranges_per_chunk; + + int layers_per_chunk = total_layers / (pp_size * chunks_per_stage); + int remainder = total_layers % (pp_size * chunks_per_stage); + + for (int local_chunk_idx = 0; local_chunk_idx < chunks_per_stage; ++local_chunk_idx) { + int global_chunk_idx = local_chunk_idx * pp_size + rank; + + if (global_chunk_idx * layers_per_chunk >= total_layers) { + break; + } + + int chunk_start = global_chunk_idx * layers_per_chunk; + int chunk_end = chunk_start + layers_per_chunk; + + if (global_chunk_idx < remainder) { + // Assign an additional layer to each of the first remainder chunks + chunk_start = global_chunk_idx * (layers_per_chunk + 1); + chunk_end = chunk_start + (layers_per_chunk + 1); + } else { + chunk_start = remainder * (layers_per_chunk + 1) + (global_chunk_idx - remainder) * layers_per_chunk; + chunk_end = chunk_start + layers_per_chunk; + } + + chunk_end = std::min(chunk_end, total_layers); + if (chunk_start < chunk_end) { + layer_ranges_per_chunk.push_back({chunk_start, chunk_end}); + } } - return {is_first_stage, is_last_stage, start_layer, end_layer}; + return {is_first_stage, is_last_stage, layer_ranges_per_chunk}; } PipelineParallel::PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches, const std::vector> &recv_shape, int pp_rank, - const std::shared_ptr &optimizer, int device_id) + const std::shared_ptr &optimizer, int 
device_id, int chunk_size) : num_stages_(num_stages), rank_(pp_rank) { modules_[kModuleName] = std::move(module); - BuildPipelineStage(module, optimizer, recv_shape, device_id); + int stage_id = pp_rank; + int stage_size = num_stages; + + std::vector> chunks; + for (int chunk_id = 0; chunk_id < chunk_size; ++chunk_id) { + std::vector> chunk_parts; + if (chunk_id == 0 && stage_id == 0) { + chunk_parts.push_back(module->mutable_module(kPPFirstStageName)); + } + chunk_parts.push_back(module->mutable_module(kPPChunkNamePrefix + std::to_string(chunk_id))); + if (chunk_id == chunk_size - 1 && stage_id == stage_size - 1) { + chunk_parts.push_back(module->mutable_module(kPPLastStageName)); + } + chunks.push_back(std::make_shared(std::move(chunk_parts))); + } + + BuildPipelineStage(optimizer, recv_shape, device_id, std::move(chunks)); SetupSchedule(num_micro_batches); } +std::vector> *PipelineParallel::mutable_chunks() { return pipeline_stage_->mutable_chunks(); } } // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc index f6081f56..95dd3bbc 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc @@ -12,6 +12,7 @@ #include "infini_train/include/device.h" #include "infini_train/include/nn/init.h" #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/pp/pipeline_stage.h" #include "infini_train/include/nn/parallel/pp/send_recv.h" #include "infini_train/include/optimizer.h" @@ -19,30 +20,30 @@ namespace infini_train::nn::parallel { -float PipelineSchedule::Step(std::shared_ptr input, std::shared_ptr target, - const std::shared_ptr &loss_fn, DataType dtype) { - std::vector> micro_batches(num_micro_batches_); - std::vector> target_mbs(num_micro_batches_); - if (stage_->IsFirstStage()) { - micro_batches = 
input->Split(input->Dims()[0] / num_micro_batches_); - } +void PrintScheduleTable(const std::vector &schedule, int n, int num_stages, + int vpp_size) { + int total_global_chunks = num_stages * vpp_size; - if (stage_->IsLastStage()) { - target_mbs = target->Split(target->Dims()[0] / num_micro_batches_); - } + LOG(INFO) << std::format("=== Schedule Table ===\n" + "n: {}, stages: {}, vpp: {}, total_chunks: {}", + n, num_stages, vpp_size, total_global_chunks); + LOG(INFO) << ""; + LOG(INFO) << "Step | Type | Microbatch | Global Chunk | Local Chunk | Stage"; + LOG(INFO) << "-----|-----------|------------|--------------|-------------|-------"; - const auto &optimizer = stage_->optimizer(); + for (const auto &task : schedule) { + int owning_stage = task.global_chunk_id % num_stages; + int local_chunk = task.global_chunk_id / num_stages; - optimizer->ZeroGrad(); + std::string type_str = task.is_forward ? "Forward" : "Backward"; - float lossf = StepMicroBatches(micro_batches, target_mbs, loss_fn, dtype); - - optimizer->Step(); - - return lossf; + auto s_info = std::format("{:4} | {:<9} | {:>10} | {:>12} | {:>11} | {:>5}", task.step, type_str, + task.microbatch_id, task.global_chunk_id, local_chunk, owning_stage); + LOG(INFO) << s_info; + } } -std::vector> PipelineSchedule::ReceiveFromPrev() { +std::vector> PipelineSchedule::ReceiveFromPrev(int peer_rank) { std::vector> recv_tensors; auto &shapes = stage_->recv_shape(); for (size_t i = 0; i < shapes.size(); ++i) { @@ -52,81 +53,240 @@ std::vector> PipelineSchedule::ReceiveFromPrev() { recv_tensors.push_back(tensor); } - return IRecv(recv_tensors, stage_->device(), stage_->stage_index(), stage_->prev_rank()); + return IRecv(recv_tensors, stage_->device(), stage_->stage_index(), peer_rank); } -std::vector> PipelineSchedule::SendToNext(const std::vector> &tensors) { - return ISend(tensors, stage_->device(), stage_->stage_index(), stage_->next_rank(), stage_->recv_shape()); +std::vector> PipelineSchedule::SendToNext(const 
std::vector> &tensors, + int peer_rank) { + return ISend(tensors, stage_->device(), stage_->stage_index(), peer_rank, stage_->recv_shape()); } -float ScheduleGPipe::StepMicroBatches(const std::vector> &microbatch_inputs, - const std::vector> &microbatch_targets, - const std::shared_ptr &loss_fn, DataType dtype) { - const auto n = num_micro_batches_; - if (n == 0) { - return 0.0f; - } +PipelineParallelScheduler::Task PipelineParallelScheduler::CreateTask(int step, int mb, int global_chunk, + int num_stages, int total_chunks, + bool is_forward) { + PipelineParallelScheduler::Task task; + task.step = step; + task.microbatch_id = mb; + task.global_chunk_id = global_chunk; + task.local_chunk_idx = global_chunk / num_stages; + task.is_forward = is_forward; + task.stage_id = global_chunk % num_stages; + task.is_last_chunk = (global_chunk == total_chunks - 1); + task.is_first_chunk = (global_chunk == 0); + return task; +} - std::vector>> outputs(n); +std::vector PipelineParallelScheduler::GenerateGPipeSchedule(int n, int num_stages, + int vpp_size) { + std::vector schedule; + int total_global_chunks = num_stages * vpp_size; + int total_steps = n + total_global_chunks - 1; // ======== Forward Pass ======== - for (int mb = 0; mb < n; ++mb) { - infini_train::AutocastGuard autocast_guard(stage_->device()->Type(), dtype); + for (int step = 0; step < total_steps; ++step) { + for (int mb = 0; mb < n; ++mb) { + int global_chunk_id = step - mb; + if (global_chunk_id >= 0 && global_chunk_id < total_global_chunks) { + auto is_forward = true; + auto task = CreateTask(step, mb, global_chunk_id, num_stages, total_global_chunks, is_forward); + schedule.push_back(task); + } + } + } - std::vector> inputs; - if (stage_->IsFirstStage()) { - inputs = {microbatch_inputs[mb]}; - } else { - inputs = ReceiveFromPrev(); + // ======== Backward Pass ======== + for (int step = 0; step < total_steps; ++step) { + for (int mb = 0; mb < n; ++mb) { + int global_chunk_id = (total_steps - 1 - step) - mb; + if
(global_chunk_id >= 0 && global_chunk_id < total_global_chunks) { + auto is_forward = false; + auto task + = CreateTask(step + total_steps, mb, global_chunk_id, num_stages, total_global_chunks, is_forward); + schedule.push_back(task); + } + } + } + + // sorted according to step, local_chunk_idx + std::sort(schedule.begin(), schedule.end(), [](const Task &a, const Task &b) { + if (a.step != b.step) { + return a.step < b.step; } - outputs[mb] = stage_->ForwardOneChunk(inputs); + return a.local_chunk_idx < b.local_chunk_idx; + }); + + return schedule; +} + +std::vector +PipelineParallelScheduler::GenerateInterleaved1F1BSchedule(int n, int num_stages, int vpp_size) { + std::vector schedule; - if (!stage_->IsLastStage()) { - outputs[mb] = SendToNext(outputs[mb]); + if (n <= 0 || num_stages <= 0 || vpp_size <= 0) { + return schedule; + } + + int total_global_chunks = num_stages * vpp_size; + + int warmup_steps = total_global_chunks - 1; + int total_steps = 2 * warmup_steps + n; + + // ================ Warm-up ================ + for (int step = 0; step < warmup_steps; ++step) { + for (int mb = 0; mb < n; ++mb) { + int forward_global_chunk = step - mb; + if (forward_global_chunk >= 0 && forward_global_chunk < total_global_chunks) { + auto is_forward = true; + auto task = CreateTask(step, mb, forward_global_chunk, num_stages, total_global_chunks, is_forward); + schedule.push_back(task); + } } } - // ======== Backward Pass ======== - float total_loss = 0.0f; - if (!stage_->IsLastStage()) { + // ================ Steady ================ + for (int step = warmup_steps; step < warmup_steps + n; ++step) { + int stable_step = step - warmup_steps; + for (int mb = 0; mb < n; ++mb) { - auto out_tensor = outputs[mb][0]; + // Forward + int forward_global_chunk = step - mb; + if (forward_global_chunk >= 0 && forward_global_chunk < total_global_chunks) { + auto is_forward = true; + auto task = CreateTask(step, mb, forward_global_chunk, num_stages, total_global_chunks, is_forward); + 
schedule.push_back(task); + } - auto dummy_gradient - = std::make_shared(out_tensor->Dims(), out_tensor->Dtype(), out_tensor->GetDevice()); + // Backward + int backward_global_chunk = (total_global_chunks - 1) - (stable_step - mb); - out_tensor->Backward(dummy_gradient); + if (backward_global_chunk >= 0 && backward_global_chunk < total_global_chunks) { + auto is_forward = false; + auto task = CreateTask(step, mb, backward_global_chunk, num_stages, total_global_chunks, is_forward); + schedule.push_back(task); + } } - } else { + } + + // ================ Cool-down ================ + for (int step = warmup_steps + n; step < total_steps; ++step) { for (int mb = 0; mb < n; ++mb) { - auto target = microbatch_targets[mb]; - auto output = outputs[mb][0]; + int backward_step = step - (warmup_steps); + int backward_global_chunk = (total_global_chunks - 1) - (backward_step - mb); + if (backward_global_chunk >= 0 && backward_global_chunk < total_global_chunks) { + auto is_forward = false; + auto task = CreateTask(step, mb, backward_global_chunk, num_stages, total_global_chunks, is_forward); + schedule.push_back(task); + } + } + } + + return schedule; +} + +float PipelineSchedule::StepMicroBatches(const std::vector> &microbatch_inputs, + const std::vector> &microbatch_targets, + const std::shared_ptr &loss_fn, DataType dtype) { + int n = num_micro_batches_; + int num_stages = stage_->num_stages(); + int stage_idx = stage_->stage_index(); + int vpp_size = global::GetVirtualPipelineParallelSize(); + + auto schedule = PipelineParallelScheduler::GenerateGPipeSchedule(n, num_stages, vpp_size); + + static bool has_printed = false; + if (!has_printed && stage_idx == 0) { + PrintScheduleTable(schedule, n, num_stages, vpp_size); + has_printed = true; + } + + float total_loss = 0.0f; - if (!target || !output) { - LOG(FATAL) << "Output or target is null at mb=" << mb; + std::vector>>> activations( + vpp_size, std::vector>>(n)); + + for (size_t i = 0; i < schedule.size(); ++i) { + const auto &task =
schedule[i]; + if (task.stage_id != stage_idx) { + continue; + } + + int mb = task.microbatch_id; + if (task.is_forward) { + infini_train::AutocastGuard autocast_guard(stage_->device()->Type(), dtype); + + std::vector> inputs; + + if (task.is_first_chunk) { + inputs = {microbatch_inputs[mb]}; + } else { + if (stage_->IsFirstStage()) { + inputs = ReceiveFromPrev(num_stages - 1); + } else { + inputs = ReceiveFromPrev(stage_->prev_rank()); + } } - std::shared_ptr loss; - { - infini_train::AutocastGuard autocast_guard(stage_->device()->Type(), dtype); + activations[task.local_chunk_idx][mb] = stage_->ForwardOneChunk(inputs, task.local_chunk_idx); - auto target_on_device = target->To(output->GetDevice()); - loss = loss_fn->Forward({output, std::make_shared(target_on_device)})[0]; - if (!loss) { - LOG(FATAL) << "[ERROR] loss is nullptr at mb = " << mb; + if (!task.is_last_chunk) { + if (stage_->IsLastStage()) { + SendToNext(activations[task.local_chunk_idx][mb], 0); + } else { + SendToNext(activations[task.local_chunk_idx][mb], stage_->next_rank()); } - loss = loss / n; } + } else { + if (task.is_last_chunk) { + auto target = microbatch_targets[mb]; + std::shared_ptr loss; + { + infini_train::AutocastGuard autocast_guard(stage_->device()->Type(), dtype); + + auto target_on_device = target->To(activations[task.local_chunk_idx][mb][0]->GetDevice()); + loss = loss_fn->Forward( + {activations[task.local_chunk_idx][mb][0], std::make_shared(target_on_device)})[0]; + loss = loss / n; + } + total_loss + += static_cast(loss->To(DeviceManager::Instance()->GetDefaultDevice()).DataPtr())[0]; - auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); - total_loss += static_cast(loss_cpu.DataPtr())[0]; + loss->Backward(); + } else { + auto out_tensor = activations[task.local_chunk_idx][mb][0]; - loss->Backward(); + auto dummy_gradient + = std::make_shared(out_tensor->Dims(), out_tensor->Dtype(), out_tensor->GetDevice()); + + out_tensor->Backward(dummy_gradient); + } } 
} return total_loss; } +float PipelineSchedule::Step(std::shared_ptr input, std::shared_ptr target, + const std::shared_ptr &loss_fn, DataType dtype) { + std::vector> micro_batches(num_micro_batches_); + std::vector> target_mbs(num_micro_batches_); + if (stage_->IsFirstStage()) { + micro_batches = input->Split(input->Dims()[0] / num_micro_batches_); + } + + if (stage_->IsLastStage()) { + target_mbs = target->Split(target->Dims()[0] / num_micro_batches_); + } + + const auto &optimizer = stage_->optimizer(); + + optimizer->ZeroGrad(); + + float lossf = StepMicroBatches(micro_batches, target_mbs, loss_fn, dtype); + + optimizer->Step(); + + return lossf; +} + } // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/pp/pipeline_stage.cc b/infini_train/src/nn/parallel/pp/pipeline_stage.cc index bdb8cba9..582b9bd2 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_stage.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_stage.cc @@ -10,18 +10,35 @@ namespace infini_train::nn::parallel { -PipelineStage::PipelineStage(const std::shared_ptr &model, int stage_index /* pp_rank */, - int num_stages /* pp_size */, const std::vector> &recv_shape, - std::shared_ptr optimizer, int device_id) - : model_(model), stage_index_(stage_index), num_stages_(num_stages), - prev_rank_(stage_index > 0 ? stage_index - 1 : -1), +PipelineStage::PipelineStage(int stage_index /* pp_rank */, int num_stages /* pp_size */, + const std::vector> &recv_shape, std::shared_ptr optimizer, + int device_id, std::vector> &&chunks) + : stage_index_(stage_index), num_stages_(num_stages), prev_rank_(stage_index > 0 ? stage_index - 1 : -1), next_rank_(stage_index < num_stages - 1 ? 
stage_index + 1 : -1), recv_shape_(recv_shape), optimizer_(std::move(optimizer)), - device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(device_id)) {} + device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(device_id)), + chunks_(std::move(chunks)) {} -std::vector> -PipelineStage::ForwardOneChunk(const std::vector> &inputs) { - return model_->Forward(inputs); +std::vector> PipelineStage::ForwardOneChunk(const std::vector> &inputs, + int local_chunk_idx) { + if (local_chunk_idx < 0 || local_chunk_idx >= static_cast(chunks_.size())) { + LOG(FATAL) << "PipelineStage::ForwardOneChunk: local_chunk_idx=" << local_chunk_idx << " out of range [0, " + << chunks_.size() << ")"; + } + return chunks_[local_chunk_idx]->Forward(inputs); } +bool PipelineStage::IsFirstStage() const { return stage_index_ == 0; } +bool PipelineStage::IsLastStage() const { return stage_index_ == num_stages_ - 1; } + +int PipelineStage::stage_index() const { return stage_index_; } +int PipelineStage::prev_rank() const { return prev_rank_; } +int PipelineStage::next_rank() const { return next_rank_; } +int PipelineStage::num_stages() const { return num_stages_; } + +const Device *PipelineStage::device() const { return device_; } +const std::vector> &PipelineStage::recv_shape() const { return recv_shape_; } +std::shared_ptr PipelineStage::optimizer() { return optimizer_; } +const std::vector> &PipelineStage::chunks() { return chunks_; } +std::vector> *PipelineStage::mutable_chunks() { return &chunks_; } } // namespace infini_train::nn::parallel diff --git a/scripts/test_config.json b/scripts/test_config.json index 93d9f955..f2d9d9be 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -141,6 +141,30 @@ }, { "id": "7", + "args": { + "dtype": "float32", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "pipeline_parallel": 4, + "virtual_pipeline_parallel": 2 + } + }, + { + "id": 
"7_bfloat16", + "args": { + "dtype": "bfloat16", + "nthread_per_process": 4, + "num_iteration": 10, + "batch_size": 10, + "total_batch_size": 5120, + "pipeline_parallel": 4, + "virtual_pipeline_parallel": 2 + } + }, + { + "id": "8", "args": { "dtype": "float32", "nthread_per_process": 8, @@ -149,11 +173,12 @@ "total_batch_size": 5120, "tensor_parallel": 2, "sequence_parallel": true, - "pipeline_parallel": 2 + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2 } }, { - "id": "7_bfloat16", + "id": "8_bfloat16", "args": { "dtype": "bfloat16", "nthread_per_process": 8, @@ -162,7 +187,8 @@ "total_batch_size": 5120, "tensor_parallel": 2, "sequence_parallel": true, - "pipeline_parallel": 2 + "pipeline_parallel": 2, + "virtual_pipeline_parallel": 2 } } ]