diff --git a/tools/fsdp_sft.py b/tools/fsdp_sft.py
index 8caded648..1ea454e2f 100644
--- a/tools/fsdp_sft.py
+++ b/tools/fsdp_sft.py
@@ -50,9 +50,9 @@
 from xtuner._lite.parallel import (LengthGroupedSampler, ParallelSampler,
                                    get_dp_mesh, get_dp_world_size,
                                    get_sp_group, get_sp_mesh,
-                                   get_sp_world_size,
-                                   reduce_sequence_parallel_loss,
-                                   setup_parallel, split_for_sequence_parallel)
+                                   get_sp_world_size, reduce_sp_loss_for_debug,
+                                   rescale_sp_loss, setup_parallel,
+                                   split_for_sequence_parallel)
 from xtuner._lite.parallel.fsdp import (RECOMPUTE_MODULES, LoadWoInit,
                                         all_required_grad_wrap_policy,
                                         checkpoint_check_fn, dp_lazy_init,
@@ -729,9 +729,12 @@ def warmup_fn(x):
             loss = outputs.loss

             if get_sp_world_size() > 1:
-                tokens_cal_loss = (labels != -100).sum()
-                loss = reduce_sequence_parallel_loss(
-                    loss, tokens_cal_loss, sp_group)
+                loss = rescale_sp_loss(outputs.loss, labels, sp_group)
+
+                if args.debug:
+                    loss_debug = reduce_sp_loss_for_debug(
+                        outputs.loss, labels, sp_group)
+                    logger.info(loss_debug)

             avg_iter_loss = loss / iters_per_step

diff --git a/xtuner/_lite/parallel/sequence/__init__.py b/xtuner/_lite/parallel/sequence/__init__.py
index 2dbdcaacf..404c45b83 100644
--- a/xtuner/_lite/parallel/sequence/__init__.py
+++ b/xtuner/_lite/parallel/sequence/__init__.py
@@ -8,7 +8,7 @@
                    pad_for_sequence_parallel)
 from .ops import (gather_for_sequence_parallel, gather_forward_split_backward,
                   split_for_sequence_parallel, split_forward_gather_backward)
-from .reduce_loss import reduce_sequence_parallel_loss
+from .rescale_loss import reduce_sp_loss_for_debug, rescale_sp_loss

 __all__ = [
     'sequence_parallel_wrapper', 'pre_process_for_sequence_parallel_attn',
@@ -16,5 +16,5 @@
     'init_dist', 'gather_for_sequence_parallel',
     'split_forward_gather_backward', 'gather_forward_split_backward',
     'pad_cumulative_len_for_sequence_parallel', 'pad_for_sequence_parallel',
-    'reduce_sequence_parallel_loss'
+    'rescale_sp_loss', 'reduce_sp_loss_for_debug'
 ]
diff --git a/xtuner/_lite/parallel/sequence/reduce_loss.py b/xtuner/_lite/parallel/sequence/reduce_loss.py
deleted file mode 100644
index d19fb4731..000000000
--- a/xtuner/_lite/parallel/sequence/reduce_loss.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import torch
-import torch.distributed as dist
-
-from ..setup import get_sp_group
-
-
-class _ReduceLoss(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, mean_loss, loss_scale, process_group):
-        ctx.mode = process_group
-        if loss_scale == 0:
-            # convert nan to 0 just for logging
-            mean_loss = torch.nan_to_num(mean_loss)
-        loss_sum = mean_loss * loss_scale
-        dist.all_reduce(loss_sum, group=process_group)
-        dist.all_reduce(loss_scale, group=process_group)
-        loss = loss_sum / loss_scale
-        return loss
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return grad_output, None, None
-
-
-def reduce_sequence_parallel_loss(mean_loss,
-                                  loss_scale,
-                                  sp_group: dist.ProcessGroup = None):
-    if dist.get_world_size(sp_group) == 1:
-        return mean_loss
-    if sp_group is None:
-        # avoid bc breaking
-        sp_group = get_sp_group()
-    return _ReduceLoss.apply(mean_loss, loss_scale, sp_group)
diff --git a/xtuner/_lite/parallel/sequence/rescale_loss.py b/xtuner/_lite/parallel/sequence/rescale_loss.py
new file mode 100644
index 000000000..e39a7770e
--- /dev/null
+++ b/xtuner/_lite/parallel/sequence/rescale_loss.py
@@ -0,0 +1,55 @@
+import copy
+
+import torch
+import torch.distributed as dist
+
+from ..setup import get_sp_group
+
+
+def rescale_sp_loss(loss_per_sp_rank,
+                    labels_per_sp_rank,
+                    sp_group: dist.ProcessGroup = None,
+                    ignore_index=-100):
+    if sp_group is None:
+        sp_group = get_sp_group()
+
+    if (sp_group is None) or (dist.get_world_size(sp_group) == 1):
+        return loss_per_sp_rank
+
+    shift_labels = labels_per_sp_rank[..., 1:].view(-1)
+    active_tokens = (shift_labels != ignore_index).long().sum()
+    global_active_tokens = copy.deepcopy(active_tokens)
+    dist.all_reduce(global_active_tokens, group=sp_group)
+    loss_weight = active_tokens / global_active_tokens * dist.get_world_size(
+        group=sp_group)
+
+    if active_tokens == 0:
+        # convert nan to 0 just for logging
+        loss_per_sp_rank = torch.nan_to_num(loss_per_sp_rank)
+
+    return loss_per_sp_rank * loss_weight
+
+
+def reduce_sp_loss_for_debug(loss_per_sp_rank,
+                             labels_per_sp_rank,
+                             sp_group: dist.ProcessGroup = None,
+                             ignore_index=-100):
+    # Reduce the loss across sp ranks to check whether the training
+    # loss differs when using sp. This function is only for debugging.
+    if sp_group is None:
+        sp_group = get_sp_group()
+
+    if (sp_group is None) or (dist.get_world_size(sp_group) == 1):
+        return loss_per_sp_rank
+
+    shift_labels = labels_per_sp_rank[..., 1:].view(-1)
+    active_tokens = (shift_labels != ignore_index).long().sum()
+    if active_tokens == 0:
+        # convert nan to 0 just for logging
+        loss_per_sp_rank = torch.nan_to_num(loss_per_sp_rank)
+
+    loss_sum = loss_per_sp_rank * active_tokens
+    global_active_tokens = copy.deepcopy(active_tokens)
+    dist.all_reduce(loss_sum, group=sp_group)
+    dist.all_reduce(global_active_tokens, group=sp_group)
+    return loss_sum / global_active_tokens
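
Why the rescaling is equivalent: rescale_sp_loss multiplies each sp rank's mean loss by active_tokens / global_active_tokens * sp_world_size, so a plain mean over the sp group (the mean the gradient all-reduce effectively takes) recovers the same global token-mean loss that reduce_sp_loss_for_debug computes with explicit all-reduces. Unlike the deleted _ReduceLoss, only the integer token counts are all-reduced; the loss tensor itself is scaled locally. Below is a minimal single-process sketch of that equivalence; the toy tensors and variable names are illustrative only, not part of the patch.

# Single-process sketch: check that rescale_sp_loss's weighting matches
# the global token-mean loss returned by reduce_sp_loss_for_debug.
import torch

sp_world_size = 2

# Per-token losses each sp rank would see after ignore_index labels are
# filtered out (rank 0 holds 3 active tokens, rank 1 holds 1).
per_token_loss = [torch.tensor([0.5, 1.5, 2.0]), torch.tensor([4.0])]

mean_loss = [t.mean() for t in per_token_loss]  # each rank's outputs.loss
active = [t.numel() for t in per_token_loss]    # active tokens per rank
global_active = sum(active)                     # what the count all_reduce yields

# rescale_sp_loss: loss_i * (n_i / N) * sp_world_size on each rank.
rescaled = [l * n / global_active * sp_world_size
            for l, n in zip(mean_loss, active)]

# Averaging the rescaled losses across the sp group ...
sp_mean = sum(rescaled) / sp_world_size

# ... matches the debug reduction: sum_i(loss_i * n_i) / N.
debug_mean = sum(l * n for l, n in zip(mean_loss, active)) / global_active

assert torch.allclose(sp_mean, debug_mean)  # both equal 2.0 here
assert torch.isclose(sp_mean, torch.cat(per_token_loss).mean())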