15 changes: 9 additions & 6 deletions tools/fsdp_sft.py
@@ -50,9 +50,9 @@
 from xtuner._lite.parallel import (LengthGroupedSampler, ParallelSampler,
                                    get_dp_mesh, get_dp_world_size,
                                    get_sp_group, get_sp_mesh,
-                                   get_sp_world_size,
-                                   reduce_sequence_parallel_loss,
-                                   setup_parallel, split_for_sequence_parallel)
+                                   get_sp_world_size, reduce_sp_loss_for_debug,
+                                   rescale_sp_loss, setup_parallel,
+                                   split_for_sequence_parallel)
 from xtuner._lite.parallel.fsdp import (RECOMPUTE_MODULES, LoadWoInit,
                                         all_required_grad_wrap_policy,
                                         checkpoint_check_fn, dp_lazy_init,
@@ -729,9 +729,12 @@ def warmup_fn(x):

             loss = outputs.loss
             if get_sp_world_size() > 1:
-                tokens_cal_loss = (labels != -100).sum()
-                loss = reduce_sequence_parallel_loss(
-                    loss, tokens_cal_loss, sp_group)
+                loss = rescale_sp_loss(outputs.loss, labels, sp_group)
+
+                if args.debug:
+                    loss_debug = reduce_sp_loss_for_debug(
+                        outputs.loss, labels, sp_group)
+                    logger.info(loss_debug)

             avg_iter_loss = loss / iters_per_step

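Why the rescale is needed: under sequence parallelism, each rank's `outputs.loss` is a mean over only the active tokens in its own shard, so averaging the per-rank losses directly over-weights shards that hold few labels. `rescale_sp_loss` weights each rank's loss by `active_tokens / global_active_tokens * sp_world_size`, so that a plain mean of the rescaled losses (or of their gradients) across the SP group recovers the token-weighted global mean. A minimal single-process sketch of that arithmetic (the values are illustrative, not from the PR):

# Two hypothetical SP ranks: per-rank mean losses and active-token counts.
per_rank_loss = [2.0, 4.0]   # mean loss computed on each rank's shard
active_tokens = [30, 10]     # tokens with labels != -100 on each shard
sp_world = len(per_rank_loss)
total_tokens = sum(active_tokens)

# Reference: the token-weighted mean a single rank would compute on the
# full, un-split sequence.
global_mean = sum(l * n
                  for l, n in zip(per_rank_loss, active_tokens)) / total_tokens

# rescale_sp_loss multiplies each rank's loss by n_i / N * sp_world_size;
# a plain mean across ranks then matches the global token-weighted mean.
rescaled = [l * (n / total_tokens) * sp_world
            for l, n in zip(per_rank_loss, active_tokens)]
assert abs(sum(rescaled) / sp_world - global_mean) < 1e-9  # both are 2.5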
4 changes: 2 additions & 2 deletions xtuner/_lite/parallel/sequence/__init__.py
@@ -8,13 +8,13 @@
                     pad_for_sequence_parallel)
 from .ops import (gather_for_sequence_parallel, gather_forward_split_backward,
                   split_for_sequence_parallel, split_forward_gather_backward)
-from .reduce_loss import reduce_sequence_parallel_loss
+from .rescale_loss import reduce_sp_loss_for_debug, rescale_sp_loss
 
 __all__ = [
     'sequence_parallel_wrapper', 'pre_process_for_sequence_parallel_attn',
     'post_process_for_sequence_parallel_attn', 'split_for_sequence_parallel',
     'init_dist', 'gather_for_sequence_parallel',
     'split_forward_gather_backward', 'gather_forward_split_backward',
     'pad_cumulative_len_for_sequence_parallel', 'pad_for_sequence_parallel',
-    'reduce_sequence_parallel_loss'
+    'rescale_sp_loss', 'reduce_sp_loss_for_debug'
 ]
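With this re-export in place, call sites can pull both helpers from the package root, exactly as the tools/fsdp_sft.py hunk above does:

from xtuner._lite.parallel import reduce_sp_loss_for_debug, rescale_sp_loss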
34 changes: 0 additions & 34 deletions xtuner/_lite/parallel/sequence/reduce_loss.py

This file was deleted.

55 changes: 55 additions & 0 deletions xtuner/_lite/parallel/sequence/rescale_loss.py
@@ -0,0 +1,55 @@
import copy

import torch
import torch.distributed as dist

from ..setup import get_sp_group


def rescale_sp_loss(loss_per_sp_rank,
                    labels_per_sp_rank,
                    sp_group: dist.ProcessGroup = None,
                    ignore_index=-100):
    # Rescale each rank's mean loss by its share of active tokens, so that
    # a subsequent mean across the SP group equals the token-weighted
    # global mean.
    if sp_group is None:
        sp_group = get_sp_group()

    if (sp_group is None) or (dist.get_world_size(sp_group) == 1):
        return loss_per_sp_rank

    # Count this rank's active (non-ignored) tokens after the causal shift.
    shift_labels = labels_per_sp_rank[..., 1:].view(-1)
    active_tokens = (shift_labels != ignore_index).long().sum()
    global_active_tokens = copy.deepcopy(active_tokens)
    dist.all_reduce(global_active_tokens, group=sp_group)
    loss_weight = active_tokens / global_active_tokens * dist.get_world_size(
        group=sp_group)

    if active_tokens == 0:
        # A shard with no labels yields a NaN mean loss (0 / 0);
        # convert nan to 0 just for logging.
        loss_per_sp_rank = torch.nan_to_num(loss_per_sp_rank)

    return loss_per_sp_rank * loss_weight


def reduce_sp_loss_for_debug(loss_per_sp_rank,
                             labels_per_sp_rank,
                             sp_group: dist.ProcessGroup = None,
                             ignore_index=-100):
    # Reduce the loss across the SP group to check whether the training
    # loss differs when sequence parallelism is enabled. This function is
    # only used for debugging.
    if sp_group is None:
        sp_group = get_sp_group()

    if (sp_group is None) or (dist.get_world_size(sp_group) == 1):
        return loss_per_sp_rank

    shift_labels = labels_per_sp_rank[..., 1:].view(-1)
    active_tokens = (shift_labels != ignore_index).long().sum()
    if active_tokens == 0:
        # convert nan to 0 just for logging
        loss_per_sp_rank = torch.nan_to_num(loss_per_sp_rank)

    # Token-weighted mean over the whole group:
    # sum_i(loss_i * n_i) / sum_i(n_i).
    loss_sum = loss_per_sp_rank * active_tokens
    global_active_tokens = copy.deepcopy(active_tokens)
    dist.all_reduce(loss_sum, group=sp_group)
    dist.all_reduce(global_active_tokens, group=sp_group)
    return loss_sum / global_active_tokens
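
One edge case both functions guard against: an SP rank whose shard contains no labels produces a NaN mean loss (a mean over zero tokens). `torch.nan_to_num` zeroes it, and because that rank's `active_tokens` is 0, its weight in `rescale_sp_loss` (and its `loss_sum` contribution in `reduce_sp_loss_for_debug`) is also 0, so the reduced result is unaffected. A minimal single-process sketch of that path (the token counts are illustrative):

import torch

loss = torch.tensor(float('nan'))   # mean over zero active tokens -> NaN
active_tokens = torch.tensor(0)

loss = torch.nan_to_num(loss)       # 0.0, "just for logging"
weight = active_tokens / 40 * 2     # n_i / N * sp_world_size == 0.0
print(loss * weight)                # tensor(0.) -- no contribution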