46 changes: 26 additions & 20 deletions example/gpt2/main.cc
@@ -188,15 +188,35 @@ void Train(const nn::parallel::Rank &rank) {
LOG(FATAL) << "Rank " << rank.GlobalRank() << ": Datatype " << FLAGS_dtype << " not supported.";
}

// NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
// before wrapping the model with DistributedDataParallel (DDP).
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.
if (ddp_world_size > 1) {
auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);

// TODO(dcj): support more complex optimizer later
auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);

if (pp_world_size > 1) {
// NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
// when sequence parallelism (SP) is enabled, we need to divide by sp_world_size.
auto shapes = std::vector<std::vector<int64_t>>{
{FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};

model = std::make_shared<nn::parallel::PipelineParallel>(
model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared<optimizers::SGD>(optimizer),
rank.thread_rank(), std::dynamic_pointer_cast<GPT2>(model)->GetChunkSize());
if (ddp_world_size > 1) {
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
(*mutable_chunks)[chunk_id]
= std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank.thread_rank());
}
}
} else if (ddp_world_size > 1) {
// NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
// before wrapping the model with DistributedDataParallel (DDP).
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
}

auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
pp_world_size > 1 ? FLAGS_batch_size * num_micro_batches : FLAGS_batch_size,
ddp_rank, ddp_world_size);
@@ -217,9 +237,6 @@ void Train(const nn::parallel::Rank &rank) {
tokenizer = std::make_unique<Tokenizer>(FLAGS_tokenizer_bin);
}

// TODO(dcj): support more complex optimizer later
auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
Expand All @@ -228,17 +245,6 @@ void Train(const nn::parallel::Rank &rank) {
loss_fn->To(device);
LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";

if (pp_world_size > 1) {
// NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
// when sequence parallelism (SP) is enabled, we need to divide by sp_world_size.
auto shapes = std::vector<std::vector<int64_t>>{
{FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};

model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, std::make_shared<optimizers::SGD>(optimizer),
rank.thread_rank());
}

LOG(INFO) << "start training";

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
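The reordered setup above computes `num_micro_batches` before any parallel wrapper is constructed, and the data loader's batch size depends on it when pipeline parallelism is enabled. The following is a minimal, self-contained sketch of that arithmetic; the flag values are illustrative assumptions, not values taken from this PR.

```cpp
// Worked example of the micro-batch arithmetic used above.
// All constants are illustrative assumptions, not values taken from this PR.
#include <cstdint>
#include <iostream>

int main() {
    const std::int64_t total_batch_size = 524288; // tokens consumed per optimizer step (assumed)
    const std::int64_t batch_size = 4;            // sequences per micro-batch (assumed)
    const std::int64_t sequence_length = 1024;    // tokens per sequence (assumed)
    const std::int64_t ddp_world_size = 2;        // data-parallel replicas (assumed)

    // Mirrors: FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size)
    const std::int64_t num_micro_batches
        = total_batch_size / (batch_size * sequence_length * ddp_world_size);
    std::cout << "num_micro_batches = " << num_micro_batches << "\n"; // prints 64

    // With pipeline parallelism, the data loader above requests batch_size * num_micro_batches
    // sequences per iteration so one load covers a full pipeline schedule.
    std::cout << "loader batch = " << batch_size * num_micro_batches << "\n"; // prints 256
    return 0;
}
```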
366 changes: 158 additions & 208 deletions example/gpt2/net.cc

Large diffs are not rendered by default.

68 changes: 43 additions & 25 deletions example/gpt2/net.h
@@ -7,8 +7,9 @@

#include "glog/logging.h"

#include "infini_train/include/device.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
#include "infini_train/include/nn/parallel/pp/pipeline_stage.h"
#include "infini_train/include/tensor.h"

struct GPT2Config {
@@ -71,15 +72,51 @@ class Block : public infini_train::nn::CloneableModule<Block> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
};

class GPT2 : public infini_train::nn::CloneableModule<GPT2> {
class GPT2FirstStage : public infini_train::nn::CloneableModule<GPT2FirstStage> {
public:
static constexpr char kWTELayerName[] = "wte";
static constexpr char kWPELayerName[] = "wpe";

explicit GPT2FirstStage(const GPT2Config &config);

std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

private:
const GPT2Config config_;
};

class GPT2Chunk : public infini_train::nn::CloneableModule<GPT2Chunk> {
public:
static constexpr char kHLayerName[] = "h";

GPT2Chunk(const GPT2Config &config, int start_layer, int end_layer);

std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

private:
const GPT2Config config_;
};

class GPT2LastStage : public infini_train::nn::CloneableModule<GPT2LastStage> {
public:
static constexpr char kLnFLayerName[] = "ln_f";
static constexpr char kTransformerLayerName[] = "transformer";
static constexpr char kLMHeadLayerName[] = "lm_head";

explicit GPT2LastStage(const GPT2Config &config);

std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

private:
const GPT2Config config_;
};

class GPT2 : public infini_train::nn::CloneableModule<GPT2> {
public:
static constexpr char kTransformerLayerName[] = "transformer";

enum class ModelType : int8_t {
kGPT2,
kGPT2Medium,
@@ -92,31 +129,12 @@ class GPT2 : public infini_train::nn::CloneableModule<GPT2> {
std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

std::vector<std::shared_ptr<infini_train::nn::Module>> BuildChunks(int pp_rank) override;

static std::shared_ptr<GPT2> FromPretrained(ModelType model_type);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath);

private:
GPT2Config config_;
};

class GPT2Chunk : public infini_train::nn::CloneableModule<GPT2Chunk> {
public:
GPT2Chunk(GPT2 *parent, int layer_begin, int chunk_layers, bool has_embedding, bool has_lm_head,
const GPT2Config &config)
: parent_(parent), layer_begin_(layer_begin), chunk_layers_(chunk_layers), has_embedding_(has_embedding),
has_lm_head_(has_lm_head), config_(config) {}

std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
int GetChunkSize() const;

private:
GPT2 *parent_ = nullptr;
int layer_begin_ = 0;
int chunk_layers_ = 0;
bool has_embedding_ = false;
bool has_lm_head_ = false;

GPT2Config config_;
const GPT2Config config_;
const infini_train::nn::parallel::StageInfo stage_info_;
};
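The header now splits GPT2 into GPT2FirstStage, GPT2Chunk, and GPT2LastStage, and GPT2 itself exposes `BuildChunks(int pp_rank)`. The sketch below is a hypothetical consumer of that interface, not code from this PR; the include path, checkpoint path, and `pp_rank` value are assumptions, and only the class and method names come from the declarations above.

```cpp
// Hypothetical usage sketch of the split declared above; not code from this PR.
#include <memory>
#include <vector>

#include "example/gpt2/net.h" // assumed path to the header shown above

int main() {
    // FromLLMC is declared above; the checkpoint filename here is made up.
    std::shared_ptr<GPT2> model = GPT2::FromLLMC("gpt2_124M.bin");

    // Each pipeline rank asks the model for the module chunks it owns.
    const int pp_rank = 0; // assumed: first pipeline stage
    std::vector<std::shared_ptr<infini_train::nn::Module>> chunks = model->BuildChunks(pp_rank);

    // Per the declarations above, the first stage starts with GPT2FirstStage (wte/wpe),
    // middle chunks are GPT2Chunk (a contiguous range of transformer blocks), and the
    // last stage ends with GPT2LastStage (ln_f + lm_head); the exact partition is
    // decided inside BuildChunks and is not spelled out in this header.
    return chunks.empty() ? 1 : 0;
}
```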
47 changes: 26 additions & 21 deletions example/llama3/main.cc
@@ -168,16 +168,35 @@ void Train(const nn::parallel::Rank &rank) {
LOG(FATAL) << "Rank " << rank.GlobalRank() << ": Datatype " << FLAGS_dtype << " not supported.";
}

// NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
// before wrapping the model with DistributedDataParallel (DDP).
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.
if (ddp_world_size > 1) {
auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);

// TODO(dcj): support more complex optimizer later
auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate);

if (pp_world_size > 1) {
// NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
// when sequence parallelism (SP) is enabled, we need to divide by sp_world_size.
auto shapes = std::vector<std::vector<int64_t>>{
{FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};

model = std::make_shared<nn::parallel::PipelineParallel>(
model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared<optimizers::Adam>(optimizer),
rank.thread_rank(), std::dynamic_pointer_cast<LLaMA3>(model)->GetChunkSize());
if (ddp_world_size > 1) {
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
(*mutable_chunks)[chunk_id]
= std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank.thread_rank());
}
}
} else if (ddp_world_size > 1) {
// NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
// before wrapping the model with DistributedDataParallel (DDP).
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
}

auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);

DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
pp_world_size > 1 ? FLAGS_batch_size * num_micro_batches : FLAGS_batch_size,
ddp_rank, ddp_world_size);
@@ -197,27 +216,13 @@ void Train(const nn::parallel::Rank &rank) {
tokenizer = std::make_unique<Tokenizer>(FLAGS_tokenizer_bin);
}

// TODO(dcj): support more complex optimizer later
auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate);

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(std::make_shared<VocabParallelCrossEntropyLoss>())
: std::static_pointer_cast<nn::Module>(std::make_shared<nn::CrossEntropyLoss>());
loss_fn->To(device);
LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";

if (pp_world_size > 1) {
// NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
// when sequence parallelism (SP) is enabled, we need to divide by sp_world_size.
auto shapes = std::vector<std::vector<int64_t>>{
{FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};

model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, std::make_shared<optimizers::Adam>(optimizer),
rank.thread_rank());
}

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
const bool last_step = step == FLAGS_num_iteration;

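Both main.cc files now build the pipeline boundary shapes before wrapping the model, dividing the sequence length by `sp_world_size` so the shapes stay correct under sequence parallelism. The snippet below works through that shape with made-up values; the batch size, sequence length, SP degree, and hidden size are assumptions, not values from this PR.

```cpp
// Worked example of the boundary-shape computation used in both main.cc files.
// The batch size, sequence length, SP degree, and hidden size are assumed values.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    const std::int64_t batch_size = 4;         // assumed
    const std::int64_t sequence_length = 1024; // assumed
    const std::int64_t sp_world_size = 2;      // sequence-parallel degree (assumed)
    const std::int64_t n_embd = 768;           // hidden size of the assumed model config

    // Mirrors: {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}
    const std::vector<std::int64_t> boundary_shape{batch_size, sequence_length / sp_world_size, n_embd};

    // Prints "4 512 768": the activations crossing each stage boundary already have
    // their sequence dimension split by SP, so PipelineParallel allocates matching buffers.
    for (const auto d : boundary_shape) {
        std::cout << d << " ";
    }
    std::cout << "\n";
    return 0;
}
```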