50 changes: 29 additions & 21 deletions example/gpt2/main.cc
@@ -64,6 +64,7 @@ DEFINE_int32(
DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, i.e. the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of model chunks per PP stage.");
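// e.g. (hypothetical values): pipeline_parallel=4 with virtual_pipeline_parallel=2 splits the
// model into 4 * 2 = 8 chunks, with each PP stage holding 2 of them.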

// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
@@ -187,15 +188,35 @@ void Train(const nn::parallel::Rank &rank) {
LOG(FATAL) << "Rank " << rank.GlobalRank() << ": Datatype " << FLAGS_dtype << " not supported.";
}

// NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
// before wrapping the model with DistributedDataParallel (DDP).
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.
if (ddp_world_size > 1) {
auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
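// e.g. (hypothetical values): total_batch_size=524288 tokens, batch_size=4, sequence_length=1024,
// ddp_world_size=8 -> num_micro_batches = 524288 / (4 * 1024 * 8) = 16.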

// TODO(dcj): support more complex optimizer later
auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);

if (pp_world_size > 1) {
// NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
// when sequence parallelism (SP) is enabled, we need to divide the sequence length by sp_world_size.
auto shapes = std::vector<std::vector<int64_t>>{
{FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};
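// e.g. (hypothetical values): batch_size=4, sequence_length=1024, sp_world_size=2, n_embd=768
// -> the tensor exchanged at a stage boundary has shape {4, 1024 / 2, 768} = {4, 512, 768}.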

model = std::make_shared<nn::parallel::PipelineParallel>(
model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared<optimizers::SGD>(optimizer),
rank.thread_rank(), std::dynamic_pointer_cast<GPT2>(model)->GetChunkSize());
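// With DDP layered on top of PP, each pipeline chunk is wrapped in DistributedDataParallel
// individually, rather than wrapping the PipelineParallel module as a whole.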
if (ddp_world_size > 1) {
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
(*mutable_chunks)[chunk_id]
= std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank.thread_rank());
}
}
} else if (ddp_world_size > 1) {
// NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
// before wrapping the model with DistributedDataParallel (DDP).
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
}

auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
pp_world_size > 1 ? FLAGS_batch_size * num_micro_batches : FLAGS_batch_size,
ddp_rank, ddp_world_size);
@@ -216,9 +237,6 @@ void Train(const nn::parallel::Rank &rank) {
tokenizer = std::make_unique<Tokenizer>(FLAGS_tokenizer_bin);
}

// TODO(dcj): support more complex optimizer later
auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
@@ -227,17 +245,6 @@
loss_fn->To(device);
LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";

if (pp_world_size > 1) {
// NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
// when sequence parallelism (SP) is enabled, we need to divide by sp_world_size.
auto shapes = std::vector<std::vector<int64_t>>{
{FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};

model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, std::make_shared<optimizers::SGD>(optimizer),
rank.thread_rank());
}

LOG(INFO) << "start training";

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
@@ -293,6 +300,7 @@ void Train(const nn::parallel::Rank &rank) {
auto logits = model->Forward({x, y})[0];
LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward";
auto loss = loss_fn->Forward({logits, y})[0];
// FIXME(jym): verify gradient accumulation precision
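// Dividing each micro-batch loss by grad_accum_steps makes the gradient accumulated over
// grad_accum_steps micro-batches correspond to the gradient of the mean loss over the window.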
loss = loss / grad_accum_steps;

// disable autocast for the current step (backward is not under autocast)
@@ -356,7 +364,7 @@ int main(int argc, char *argv[]) {
google::InitGoogleLogging(argv[0]);

nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
FLAGS_pipeline_parallel);
FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel);

LOG(INFO) << nn::parallel::global::ProcessGroupOverview();
