From cbe91f46a81ddc2d1259e44969521320b29aee57 Mon Sep 17 00:00:00 2001 From: Dmitry Date: Mon, 28 Apr 2025 10:14:52 +0000 Subject: [PATCH 1/2] CCE and CCE_minus loss added --- kernels/__init__.py | 0 kernels/cut_cross_entropy/__init__.py | 12 + kernels/cut_cross_entropy/cce.py | 205 ++++++ kernels/cut_cross_entropy/cce_backward.py | 670 ++++++++++++++++++ kernels/cut_cross_entropy/cce_lse_forward.py | 371 ++++++++++ kernels/cut_cross_entropy/constants.py | 2 + kernels/cut_cross_entropy/doc.py | 58 ++ kernels/cut_cross_entropy/indexed_dot.py | 158 +++++ .../cut_cross_entropy/linear_cross_entropy.py | 121 ++++ kernels/cut_cross_entropy/tl_autotune.py | 596 ++++++++++++++++ kernels/cut_cross_entropy/tl_utils.py | 90 +++ kernels/cut_cross_entropy/torch_compile.py | 82 +++ kernels/cut_cross_entropy/utils.py | 55 ++ .../fused_linear_cross_entropy/__init__.py | 1 + .../fused_linear_ce_loss.py | 543 ++++++++++++++ .../nn/sequential/bert4rec/lightning.py | 118 +++ .../models/nn/sequential/sasrec/lightning.py | 124 ++++ 17 files changed, 3206 insertions(+) create mode 100644 kernels/__init__.py create mode 100644 kernels/cut_cross_entropy/__init__.py create mode 100644 kernels/cut_cross_entropy/cce.py create mode 100644 kernels/cut_cross_entropy/cce_backward.py create mode 100644 kernels/cut_cross_entropy/cce_lse_forward.py create mode 100644 kernels/cut_cross_entropy/constants.py create mode 100644 kernels/cut_cross_entropy/doc.py create mode 100644 kernels/cut_cross_entropy/indexed_dot.py create mode 100644 kernels/cut_cross_entropy/linear_cross_entropy.py create mode 100644 kernels/cut_cross_entropy/tl_autotune.py create mode 100644 kernels/cut_cross_entropy/tl_utils.py create mode 100644 kernels/cut_cross_entropy/torch_compile.py create mode 100644 kernels/cut_cross_entropy/utils.py create mode 100644 kernels/fused_linear_cross_entropy/__init__.py create mode 100644 kernels/fused_linear_cross_entropy/fused_linear_ce_loss.py diff --git a/kernels/__init__.py 
b/kernels/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kernels/cut_cross_entropy/__init__.py b/kernels/cut_cross_entropy/__init__.py new file mode 100644 index 000000000..046057d13 --- /dev/null +++ b/kernels/cut_cross_entropy/__init__.py @@ -0,0 +1,12 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +from cut_cross_entropy.linear_cross_entropy import ( + LinearCrossEntropy, + LinearCrossEntropyImpl, + linear_cross_entropy, +) + +__all__ = [ + "LinearCrossEntropy", + "LinearCrossEntropyImpl", + "linear_cross_entropy", +] \ No newline at end of file diff --git a/kernels/cut_cross_entropy/cce.py b/kernels/cut_cross_entropy/cce.py new file mode 100644 index 000000000..1185511ed --- /dev/null +++ b/kernels/cut_cross_entropy/cce.py @@ -0,0 +1,205 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +# This software includes modifications +from dataclasses import dataclass +from typing import cast + +import torch + +from cut_cross_entropy.cce_backward import cce_backward_kernel +from cut_cross_entropy.cce_lse_forward import cce_lse_forward_kernel +from cut_cross_entropy.constants import IGNORE_INDEX +from cut_cross_entropy.doc import CCE_OPTS_DOC, LINEAR_CROSS_ENTROPY_DOC, add_doc_start +from cut_cross_entropy.indexed_dot import indexed_neg_dot_forward_kernel +from cut_cross_entropy.utils import ( + _build_flat_valids, + _handle_eps, + handle_reduction_none, +) + + +@dataclass +class CCEParams: + targets: torch.Tensor + valids: torch.Tensor | None + softcap: float | None + reduction: str + filter_eps: float | None + shift: int + batch_shape: torch.Size + use_kahan: bool + item_inds: torch.Tensor | None + + +@torch.compile(fullgraph=True, dynamic=True) +def sort_logit_avg(logit_avg: torch.Tensor) -> torch.Tensor: + return torch.argsort(logit_avg).to(torch.int32) + + +class LinearCrossEntropyFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + e: torch.Tensor, + c: torch.Tensor, + bias: torch.Tensor | None, + 
params: CCEParams, + ) -> torch.Tensor: + needs_grad = e.requires_grad or c.requires_grad + return_logit_avg = needs_grad and params.filter_eps is not None + + ret = cce_lse_forward_kernel( + e=e, + c=c, + bias=bias, + valids=params.valids, + softcap=params.softcap, + return_logit_avg=return_logit_avg, + item_inds=params.item_inds + ) + if return_logit_avg: + assert isinstance(ret, tuple) + lse, logit_avg = ret + else: + assert isinstance(ret, torch.Tensor) + lse = ret + logit_avg = None + + neg_dot = indexed_neg_dot_forward_kernel( + e=e, + c=c, + inds=params.targets, + bias=bias, + shift=params.shift, + valids=params.valids, + softcap=params.softcap, + out_dtype=lse.dtype, + ) + + nll = neg_dot.add_(lse) + + reduction = params.reduction + if reduction == "mean": + loss = nll.mean() + elif reduction == "sum": + loss = nll.sum() + elif reduction == "none": + loss = handle_reduction_none(params.batch_shape, params.valids, params.shift, nll) + else: + raise ValueError(f"Unknown reduction {reduction}") + + ctx.save_for_backward(e, c, bias, lse, params.targets, params.valids, logit_avg) + ctx.params = params + + return loss + + @staticmethod + def backward( + ctx, grad_out: torch.Tensor + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, None]: + e, c, bias, lse, targets, valids, logit_avg = ctx.saved_tensors + + if logit_avg is not None: + vocab_ordering = sort_logit_avg(logit_avg) + else: + vocab_ordering = None + + params = cast(CCEParams, ctx.params) + reduction = params.reduction + if reduction == "mean": + grad_scale = 1 / lse.numel() + elif reduction == "sum": + grad_scale = 1.0 + elif reduction == "none": + grad_scale = 1.0 + grad_out = grad_out.view(-1) + else: + raise ValueError(f"Unknown reduction {reduction}") + + de, dc, dbias = cce_backward_kernel( + do=grad_out, + e=e, + c=c, + bias=bias, + lse=lse, + valids=valids, + softcap=params.softcap, + filter_eps=params.filter_eps, + targets=targets, + shift=params.shift, + 
@add_doc_start(LINEAR_CROSS_ENTROPY_DOC)
@add_doc_start(*(doc_str + "\n" for doc_str in CCE_OPTS_DOC))
def cce_linear_cross_entropy(
    e: torch.Tensor,
    c: torch.Tensor,
    targets: torch.Tensor,
    bias: torch.Tensor | None = None,
    ignore_index: int = IGNORE_INDEX,
    softcap: float | None = None,
    reduction: str = "mean",
    shift: bool | int = 0,
    filter_eps: float | str | None = "auto",
    use_kahan: bool = False,
    item_inds: torch.Tensor | None = None,
) -> torch.Tensor:
    # NOTE: the user-facing documentation is attached by the ``add_doc_start``
    # decorators, so this function deliberately carries no docstring of its own.
    #
    # ``item_inds`` (new, optional): per-row sampled class indices. When given,
    # the sampled forward/backward kernels are used instead of the dense ones.
    # The forward wrapper reads ``item_inds.size(1)`` as the number of samples,
    # so it is expected to be 2-D with one row per flattened row of ``e``
    # -- TODO confirm the exact layout against the caller.
    assert e.size()[0:-1] == targets.size()
    assert e.size(-1) == c.size(1)
    if not torch.cuda.is_bf16_supported():
        raise RuntimeError(
            "Cut Cross Entropy requires an ampere GPU or newer. "
            "Consider using torch_compile_linear_cross_entropy for scenarios where one is not available."
        )

    # Remember the caller-visible shape before everything is flattened; it is
    # needed to reshape the per-element loss for reduction="none".
    batch_shape = targets.size()

    e = e.contiguous()
    targets = targets.contiguous()

    shift = int(shift)
    valids = _build_flat_valids(targets, ignore_index, shift)

    e = e.flatten(0, -2)
    targets = targets.flatten()

    # The kernels require a 16-byte aligned targets pointer; padding by one
    # element and slicing it off again forces a fresh allocation.
    if (targets.data_ptr() % 16) != 0:
        targets = torch.nn.functional.pad(targets, (0, 1))[:-1]

    assert (targets.data_ptr() % 16) == 0

    params = CCEParams(
        targets,
        valids,
        softcap,
        reduction,
        _handle_eps(filter_eps, e.dtype),
        shift,
        batch_shape,
        use_kahan,
        # Previously omitted, which made the CCEParams constructor fail with a
        # missing-argument TypeError.
        item_inds,
    )

    return linear_cross_entropy_apply(e, c, bias, params)
@triton.jit
def _block_is_filtered(check_val: tl.tensor, filter_eps: tl.tensor) -> tl.tensor:
    # Reduce the element-wise comparison ``check_val < filter_eps`` over the
    # whole tile (axis=None) with ``tl_and_reduce_fn``. The backward kernel
    # calls this with ``tl.abs(d_accum)`` and returns early when the result is
    # true, skipping the tile's gradient contribution.
    # NOTE(review): ``tl_and_reduce_fn`` lives in tl_utils and is presumably a
    # logical AND combiner -- confirm there.
    return tl.reduce(check_val < filter_eps, None, tl_and_reduce_fn)
offs_v = pid_v * BLOCK_V + tl.arange(0, BLOCK_V) + if HAS_VOCAB_ORDERING: + offs_v = tl.load(VocabOrdering + offs_v, mask=offs_v < V, other=V) + + offs_d = tl.arange(0, BLOCK_D) + e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed) + c_ptrs = C + (offs_v[None, :] * stride_cv + offs_d[:, None] * stride_cd) + + accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32) + for d in range(0, tl.cdiv(D, BLOCK_D)): + e_mask = offs_b[:, None] < BMax + if not EVEN_D: + e_mask = e_mask & (offs_d[None, :] < (D - d * BLOCK_D)) + + e = tl.load(e_ptrs, mask=e_mask, other=0.0) + + c_mask = offs_v[None, :] < V + if not EVEN_D: + c_mask = c_mask & (offs_d[:, None] < (D - d * BLOCK_D)) + + c = tl.load(c_ptrs, mask=c_mask, other=0.0) + + accum = tl.dot(e, c, accum) + + e_ptrs += BLOCK_D * stride_ed + c_ptrs += BLOCK_D * stride_cd + + tl.debug_barrier() + + if HAS_BIAS: + bias = tl.load(Bias + offs_v * stride_biasv, mask=offs_v < V, other=0.0) + bias = bias.to(dtype=accum.dtype) + accum += bias[None, :] + + if HAS_SOFTCAP: + accum = tl_softcapping(accum, softcap) + + if HAS_VALIDS: + direct_offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B) + lse = tl.load(LSE + direct_offs_b, mask=direct_offs_b < B, other=float("inf")) + else: + lse = tl.load(LSE + offs_b, mask=offs_b < B, other=float("inf")) + + d_accum = tl.exp(accum - lse[:, None]) + d_accum = tl.where(offs_v[None, :] < V, d_accum, 0.0) + + if HAS_TARGETS: + if HAS_SHIFT: + target_offs_b = offs_b + shift + else: + target_offs_b = offs_b + + targets = tl.load(Targets + target_offs_b, mask=target_offs_b < BMax, other=V + 1) + is_target = targets[:, None] == offs_v[None, :] + d_accum += tl.where(is_target, -1.0, 0.0) + else: + is_target = None + + if FILTER_GRAD: + if _block_is_filtered(tl.abs(d_accum), filter_eps): + return + + if HAS_SOFTCAP: + d_accum = tl_softcapping_grad(d_accum, accum, softcap) + + if ITEM_DO: + d_out = tl.load(dOut) + else: + if HAS_SHIFT: + d_out_offs_b = offs_b + shift + else: + d_out_offs_b = 
offs_b + + d_out = tl.load(dOut + d_out_offs_b, mask=d_out_offs_b < BMax, other=0.0)[:, None] + + d_out = grad_scale * d_out + + d_accum = d_accum * d_out + + if COMPUTE_DBIAS: + tl.atomic_add(dBias + offs_v * stride_biasv, tl.sum(d_accum, 0), mask=offs_v < V) + + d_accum = d_accum.to(e_ptrs.dtype.element_ty) + + if COMPUTE_DE: + lock_offset = (pid_b // tl.cdiv(B, BLOCK_B * n_de_locks_0)) * n_de_locks_1 + + _mm_backward( + d_accum, + dE + (offs_b[:, None] * stride_eb), + dEC + (offs_b[:, None] * stride_eb) if USE_KAHAN else None, + offs_b[:, None] < BMax, + dELocks + lock_offset, + n_de_locks_1, + C + offs_v[:, None] * stride_cv, + offs_v[:, None] < V, + stride_ed, + stride_cd, + D, + MM_BACK_BLOCK_D, + MM_BACK_EVEN_D, + USE_KAHAN, + ) + + if COMPUTE_DC: + lock_offset = (pid_v // tl.cdiv(V, BLOCK_V * n_dc_locks_0)) * n_dc_locks_1 + + _mm_backward( + tl.trans(d_accum), + dC + (offs_v[:, None] * stride_cv), + dCC + (offs_v[:, None] * stride_cv) if USE_KAHAN else None, + offs_v[:, None] < V, + dCLocks + lock_offset, + n_dc_locks_1, + E + (offs_b[:, None] * stride_eb), + offs_b[:, None] < BMax, + stride_cd, + stride_ed, + D, + MM_BACK_BLOCK_D, + MM_BACK_EVEN_D, + USE_KAHAN, + ) + + +def _cce_back_block_d(args) -> int: + block_d = args["BLOCK_D"] + return 2 * block_d + + +_cce_backward_kernel = triton.jit(_cce_backward_kernel) +_cce_backward_kernel = triton.heuristics( # type: ignore + { + "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0, + "MM_BACK_BLOCK_D": lambda args: _cce_back_block_d(args), + "MM_BACK_EVEN_D": lambda args: (args["D"] % _cce_back_block_d(args)) == 0, + "HAS_VALIDS": lambda args: args["Valids"] is not None, + "HAS_BIAS": lambda args: args["Bias"] is not None, + "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None, + "FILTER_GRAD": lambda args: args["filter_eps"] is not None, + "HAS_TARGETS": lambda args: args["Targets"] is not None, + "HAS_SOFTCAP": lambda args: args["softcap"] is not None, + "HAS_SHIFT": lambda args: 
args["shift"] != 0, + "ITEM_DO": lambda args: args["dOut"].numel() == 1, + "GROUP_B": lambda args: 8, + "COMPUTE_DC": lambda args: args["dC"] is not None, + "COMPUTE_DE": lambda args: args["dE"] is not None, + "COMPUTE_DBIAS": lambda args: args["dBias"] is not None, + } +)(_cce_backward_kernel) +_cce_backward_kernel = cce_backward_autotune()(_cce_backward_kernel) # type: ignore + + +def _cce_sampled_backward_kernel( + E, + C, + Inds, + Bias, + LSE, + dOut, + grad_scale, + Valids, + VocabOrdering, + softcap, + Targets, + dE, + dEC, + dELocks, + dC, + dCC, + dCLocks, + dBias, + B, + D, + V, + SAMPLE_NUMS, + BMax, + n_de_locks_0, + n_de_locks_1, + n_dc_locks_0, + n_dc_locks_1, + stride_eb, + stride_ed, + stride_cv, + stride_cd, + stride_ib, + stride_is, + stride_biasv, + stride_vb, + filter_eps, + shift, + B_BIN, + BLOCK_B: tl.constexpr, + BLOCK_V: tl.constexpr, + BLOCK_D: tl.constexpr, + MM_BACK_BLOCK_D: tl.constexpr, + GROUP_B: tl.constexpr, + EVEN_D: tl.constexpr, + MM_BACK_EVEN_D: tl.constexpr, + ITEM_DO: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_VALIDS: tl.constexpr, + HAS_VOCAB_ORDERING: tl.constexpr, + FILTER_GRAD: tl.constexpr, + HAS_TARGETS: tl.constexpr, + HAS_SOFTCAP: tl.constexpr, + HAS_SHIFT: tl.constexpr, + USE_KAHAN: tl.constexpr, + COMPUTE_DC: tl.constexpr, + COMPUTE_DE: tl.constexpr, + COMPUTE_DBIAS: tl.constexpr, +): + pid = tl.program_id(axis=0) + idx = tl.program_id(axis=1) + offs_b = pid * BLOCK_B + tl.arange(0, BLOCK_B) + offs_d = tl.arange(0, BLOCK_D) + + # de_accum = tl.zeros((BLOCK_B, BLOCK_D), dtype=tl.float16) + # dc_accum = tl.zeros((BLOCK_B, BLOCK_D), dtype=tl.float16) + # inds_accum = tl.zeros((BLOCK_B, ), dtype=tl.int32) + # for idx in range(0, SAMPLE_NUMS): + e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed) + e_mask = (offs_b[:, None] < BMax) & (offs_d[None, :] < D) + + inds_ptrs = Inds + offs_b * stride_ib + idx + inds_mask = offs_b < BMax + inds = tl.load(inds_ptrs, mask=inds_mask, other=V) + c_ptrs = C + 
(inds[:, None] * stride_cv + offs_d[None, :] * stride_cd) + c_mask = (inds[:, None] < V) & (offs_d[None, :] < D) + + e = tl.load(e_ptrs, mask=e_mask, other=0.0) + c = tl.load(c_ptrs, mask=c_mask, other=0.0) + + dot_sum = tl.sum(e.to(tl.float32) * c.to(tl.float32), axis=1) + + if idx > 0: + dot_sum += tl.log(V - 1.0) + dot_sum -= tl.log(1.0 * SAMPLE_NUMS) + + lse = tl.load(LSE + offs_b, mask=offs_b < B, other=float("inf")) + + d_accum = tl.exp(dot_sum - lse) + d_accum = tl.where(inds < V, d_accum, 0.0) + + if HAS_TARGETS: + if HAS_SHIFT: + target_offs_b = offs_b + shift + else: + target_offs_b = offs_b + + targets = tl.load(Targets + target_offs_b, mask=target_offs_b < BMax, other=V + 1) + is_target = targets == idx + d_accum += tl.where(is_target, -1.0, 0.0) + else: + is_target = None + + + if ITEM_DO: + d_out = tl.load(dOut) + else: + if HAS_SHIFT: + d_out_offs_b = offs_b + shift + else: + d_out_offs_b = offs_b + + d_out = tl.load(dOut + d_out_offs_b, mask=d_out_offs_b < BMax, other=0.0) + + d_out = grad_scale * d_out + d_accum = d_accum * d_out + d_accum = d_accum.to(e_ptrs.dtype.element_ty) + + if COMPUTE_DE: + de_ptrs = dE + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed) + de_mask = (offs_b[:, None] < BMax) & (offs_d[None, :] < D) + de_out = d_accum[:, None] * c + tl.atomic_add(de_ptrs, de_out, mask=de_mask) + + if COMPUTE_DC: + dc_ptrs = dC + (inds[:, None] * stride_cv + offs_d[None, :] * stride_cd) + dc_mask = (inds[:, None] < V) & (offs_d[None, :] < D) + dc_out = d_accum[:, None] * e + tl.atomic_add(dc_ptrs, dc_out, mask=dc_mask) + + # if COMPUTE_DE: + # de_accum += d_accum[:, None] * c + + # if COMPUTE_DC: + # dc_accum += d_accum[:, None] * e + + + # if COMPUTE_DE: + # de_ptrs = dE + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed) + # de_mask = (offs_b[:, None] < BMax) & (offs_d[None, :] < D) + # tl.store(de_ptrs, de_accum, mask=de_mask) + + # if COMPUTE_DC: + # dc_ptrs = dC + (inds[:, None] * stride_cv + offs_d[None, :] * 
def cce_backward_kernel(
    do: torch.Tensor,
    e: torch.Tensor,
    c: torch.Tensor,
    bias: torch.Tensor | None,
    lse: torch.Tensor,
    valids: torch.Tensor | None,
    softcap: float | None,
    filter_eps: float | None,
    targets: torch.Tensor | None = None,
    shift: int = 0,
    vocab_ordering: torch.Tensor | None = None,
    grad_scale: float = 1.0,
    use_kahan: bool = False,
    item_inds: torch.Tensor | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
    """Launch the Triton backward pass and return ``(de, dc, dbias)``.

    :param do: Upstream gradient; either a single element or one per row of ``e``.
    :param e: Embeddings, 2-D, fp16 or bf16.
    :param c: Classifier matrix, 2-D with ``c.size(1) == e.size(1)``, fp16 or bf16.
    :param bias: Optional bias with one entry per row of ``c``.
    :param lse: Log-sum-exp values from the forward pass.
    :param valids: Optional 1-D index list selecting the rows of ``e`` to process.
    :param softcap: Optional soft-capping value forwarded to the kernel.
    :param filter_eps: Optional gradient-filter threshold forwarded to the kernel.
    :param targets: Target indices (dense path only; the sampled path ignores them).
    :param shift: Row offset applied to targets / ``do`` inside the kernel.
    :param vocab_ordering: Optional 1-D contiguous permutation of the rows of ``c``.
    :param grad_scale: Scalar multiplied into ``do`` inside the kernel.
    :param use_kahan: Allocate compensation buffers for Kahan summation (dense path).
    :param item_inds: Optional 2-D sampled indices; selects the sampled kernel.
    :returns: Gradients for ``e``, ``c`` and ``bias``; each is ``None`` when the
        corresponding input does not require grad.
    """
    assert do.numel() in (e.size(0), 1)
    assert c.size(1) == e.size(1)
    assert lse.size(0) == e.size(0) or (valids is not None and lse.size(0) == valids.size(0))
    assert e.dtype in (
        torch.float16,
        torch.bfloat16,
    ), "Backwards requires embeddings to be bf16 or fp16"
    assert c.dtype in (
        torch.float16,
        torch.bfloat16,
    ), "Backwards requires classifier to be bf16 or fp16"

    do = do.contiguous()
    lse = lse.contiguous()

    # Gradients are accumulated into by the kernels, so they start at zero.
    de = torch.zeros_like(e) if e.requires_grad else None
    dc = torch.zeros_like(c) if c.requires_grad else None

    if bias is not None and bias.requires_grad:
        # Accumulate the bias gradient in fp32; it is cast back at the end.
        dbias = torch.zeros_like(bias, dtype=torch.float32)
    else:
        dbias = None

    # The kernels address each gradient with the strides of its input.
    if de is not None:
        assert de.stride() == e.stride()

    if dc is not None:
        assert dc.stride() == c.stride()

    if dbias is not None:
        assert bias is not None
        assert dbias.stride() == bias.stride()

    if use_kahan:
        dec = torch.zeros_like(e) if de is not None else None
        dcc = torch.zeros_like(c) if dc is not None else None
    else:
        dec = None
        dcc = None

    if dec is not None:
        assert dec.stride() == e.stride()

    if dcc is not None:
        # dcc mirrors c (not e): it is addressed with c's strides in the kernel.
        assert dcc.stride() == c.stride()

    if valids is not None:
        assert valids.ndim == 1
        B = valids.size(0)
    else:
        B = e.size(0)

    if do.numel() > 1:
        # do and lse were made contiguous above; the kernel walks both with the
        # same row offsets.
        assert do.stride(0) == lse.stride(0), f"{do.stride()=}, {lse.stride()=}"

    if item_inds is None:

        def grid(META):
            return (triton.cdiv(B, META["BLOCK_B"]) * triton.cdiv(c.size(0), META["BLOCK_V"]),)

        if vocab_ordering is not None:
            assert vocab_ordering.ndim == 1
            assert vocab_ordering.numel() == c.size(0)
            assert vocab_ordering.stride(0) == 1

        # One lock per (128-row chunk, 64-wide slice of D) of each gradient.
        nd_locks = triton.cdiv(c.size(1), 64)
        if de is not None:
            de_locks = e.new_zeros((triton.cdiv(B, 128), nd_locks), dtype=torch.int32)
            de_lock_sizes = de_locks.size()
        else:
            de_locks = None
            de_lock_sizes = (None, None)

        if dc is not None:
            dc_locks = c.new_zeros((triton.cdiv(c.size(0), 128), nd_locks), dtype=torch.int32)
            dc_lock_sizes = dc_locks.size()
        else:
            dc_locks = None
            dc_lock_sizes = (None, None)

        _cce_backward_kernel[grid](
            e,
            c,
            bias,
            lse,
            do,
            grad_scale,
            valids,
            vocab_ordering,
            softcap,
            targets,
            de,
            dec,
            de_locks,
            dc,
            dcc,
            dc_locks,
            dbias,
            B,
            e.size(1),
            c.size(0),
            e.size(0),
            *de_lock_sizes,
            *dc_lock_sizes,
            e.stride(0),
            e.stride(1),
            c.stride(0),
            c.stride(1),
            1 if bias is None else bias.stride(0),
            1 if valids is None else valids.stride(0),
            filter_eps,
            shift=shift,
            B_BIN=b_bin_fn(B),
            USE_KAHAN=use_kahan,
        )
    else:
        # The sampled kernel steps through the sample axis with a unit offset,
        # so guarantee that layout here (no copy if already contiguous).
        item_inds = item_inds.contiguous()
        SAMPLE_NUMS = item_inds.size(1)

        def grid(META):
            # Second grid axis: one program per sample column.
            return (triton.cdiv(B, META["BLOCK_B"]), SAMPLE_NUMS)

        D = e.size(1)
        # The sampled kernel loads a whole row of D in one tile, so the tile
        # width must be a power of two that is >= D.
        BLOCK_D = triton.next_power_of_2(D)

        # The sampled kernel compares this tensor with the sample column index,
        # so all-zeros marks sample column 0 as the target for every row.
        targets_cce_sampled_loss = torch.zeros_like(lse)

        # NOTE(review): the sampled kernel uses atomics instead of locks/Kahan
        # buffers and does not appear to read Bias or write dBias -- confirm
        # whether bias is meant to be supported on this path.
        _cce_sampled_backward_kernel[grid](
            e,
            c,
            item_inds,
            bias,
            lse,
            do,
            grad_scale,
            valids,
            vocab_ordering,
            softcap,
            targets_cce_sampled_loss,
            de,
            None,  # dEC (Kahan buffer unused)
            None,  # dELocks
            dc,
            None,  # dCC (Kahan buffer unused)
            None,  # dCLocks
            dbias,
            B,
            D,
            c.size(0),
            SAMPLE_NUMS,
            e.size(0),
            None,  # n_de_locks_0
            None,  # n_de_locks_1
            None,  # n_dc_locks_0
            None,  # n_dc_locks_1
            e.stride(0),
            e.stride(1),
            c.stride(0),
            c.stride(1),
            item_inds.stride(0),
            item_inds.stride(1),
            1 if bias is None else bias.stride(0),
            1 if valids is None else valids.stride(0),
            filter_eps,
            shift=shift,
            B_BIN=b_bin_fn(B),
            USE_KAHAN=use_kahan,
            BLOCK_D=BLOCK_D,
        )

    if dbias is not None:
        assert bias is not None
        dbias = dbias.to(dtype=bias.dtype)

    return de, dc, dbias
def _cce_lse_forward_kernel(
    E,
    C,
    Bias,
    LSE,
    LA,
    Locks,
    Valids,
    softcap,
    B,
    V,
    D,
    BMax,
    stride_eb,
    stride_ed,
    stride_cv,
    stride_cd,
    stride_biasv,
    stride_lse_b,
    stride_vb,
    num_locks,
    # Meta-parameters
    B_BIN,
    HAS_BIAS: tl.constexpr,
    HAS_VALIDS: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_V: tl.constexpr,
    BLOCK_D: tl.constexpr,  #
    GROUP_B: tl.constexpr,  #
    EVEN_D: tl.constexpr,
    HAS_SOFTCAP: tl.constexpr,
    HAS_LA: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    # Each program computes one (BLOCK_B x BLOCK_V) tile of logits E @ C^T,
    # reduces it to a per-row log-sum-exp, and merges that into LSE under a
    # lock. Optionally it also accumulates a per-class logit average into LA.

    # Map the 1-D program id onto a (row-tile, class-tile) pair. Programs are
    # grouped so that GROUP_B consecutive row tiles are visited for each class
    # tile before moving on.
    pid = tl.program_id(axis=0)
    num_pid_b = tl.cdiv(B, BLOCK_B)
    num_pid_v = tl.cdiv(V, BLOCK_V)
    num_pid_in_group = GROUP_B * num_pid_v
    group_id = pid // num_pid_in_group
    first_pid_b = group_id * GROUP_B
    group_size_b = min(num_pid_b - first_pid_b, GROUP_B)
    pid_b = first_pid_b + ((pid % num_pid_in_group) % group_size_b)
    pid_v = (pid % num_pid_in_group) // group_size_b

    offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
    if HAS_VALIDS:
        # Translate compact row numbers into row indices of E; out-of-range
        # lanes get BMax so the E load mask below rejects them.
        offs_b = tl.load(Valids + stride_vb * offs_b, mask=offs_b < B, other=BMax)

    offs_v = pid_v * BLOCK_V + tl.arange(0, BLOCK_V)
    offs_d = tl.arange(0, BLOCK_D)
    e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed)
    c_ptrs = C + (offs_v[None, :] * stride_cv + offs_d[:, None] * stride_cd)

    # Accumulate the logits tile in fp32, BLOCK_D columns of D at a time.
    accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32)
    for d in range(0, tl.cdiv(D, BLOCK_D)):
        e_mask = offs_b[:, None] < BMax
        if not EVEN_D:
            # Last chunk may be partial when D is not a multiple of BLOCK_D.
            e_mask = e_mask & (offs_d[None, :] < (D - d * BLOCK_D))

        e = tl.load(e_ptrs, mask=e_mask, other=0.0)

        c_mask = offs_v[None, :] < V
        if not EVEN_D:
            c_mask = c_mask & (offs_d[:, None] < (D - d * BLOCK_D))

        c = tl.load(c_ptrs, mask=c_mask, other=0.0)

        accum = tl.dot(e, c, accum, input_precision=DOT_PRECISION)

        e_ptrs += BLOCK_D * stride_ed
        c_ptrs += BLOCK_D * stride_cd

    tl.debug_barrier()

    if HAS_BIAS:
        bias = tl.load(Bias + offs_v * stride_biasv, mask=offs_v < V, other=0.0)
        bias = bias.to(dtype=accum.dtype)
        accum += bias[None, :]

    # Columns past V must not contribute to the softmax: exp(-inf) == 0.
    logits = tl.where(offs_v[None, :] < V, accum, -float("inf"))
    if HAS_SOFTCAP:
        logits = tl_softcapping(logits, softcap)

    if HAS_LA:
        # Column sums divided by B, atomically added so that all row tiles
        # together produce the per-class average logit.
        this_avg_logit = tl.sum(logits, 0) / B
        tl.atomic_add(LA + offs_v, this_avg_logit, mask=offs_v < V)

    # Max-shifted log-sum-exp over this tile's classes.
    this_mx = tl.max(logits, axis=1)
    e = tl.exp(logits - this_mx[:, None])
    this_lse = this_mx + tl.log(tl.sum(e, axis=1))

    # LSE is indexed by compact row number, not by the Valids-translated one,
    # so recompute the direct offsets.
    offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
    o_mask = offs_b < B

    lse_ptrs = LSE + (stride_lse_b * offs_b)

    # Spin until this row range's lock is acquired; several class tiles update
    # the same LSE rows and the read-modify-write below is not atomic.
    this_locks = Locks + (pid_b // tl.cdiv(B, BLOCK_B * num_locks))
    while tl.atomic_cas(this_locks, 0, 1) == 1:
        pass

    # Merge this tile's partial result into the running log-sum-exp.
    lse = tl.load(lse_ptrs, mask=o_mask, other=0.0, eviction_policy="evict_last")
    lse = tl_logaddexp(lse, this_lse)
    tl.store(lse_ptrs, lse, mask=o_mask, eviction_policy="evict_last")

    # Make sure the store is issued before the lock is released.
    tl.debug_barrier()
    tl.atomic_xchg(this_locks, 0)


# Compile the kernel, then derive the constexpr flags from the launch arguments
# so callers only pass tensors/sizes.
_cce_lse_forward_kernel = triton.jit(_cce_lse_forward_kernel)
_cce_lse_forward_kernel = triton.heuristics(  # type: ignore
    {
        "EVEN_D": lambda args: args["D"] % args["BLOCK_D"] == 0,
        "HAS_BIAS": lambda args: args["Bias"] is not None,
        "HAS_VALIDS": lambda args: args["Valids"] is not None,
        "HAS_SOFTCAP": lambda args: args["softcap"] is not None,
        "HAS_LA": lambda args: args["LA"] is not None,
        "GROUP_B": lambda args: 8,
        # Follow torch's global fp32 matmul precision setting.
        "DOT_PRECISION": lambda args: "tf32"
        if torch.get_float32_matmul_precision() == "high"
        else "ieee",
    }
)(_cce_lse_forward_kernel)
def _cce_lse_sampled_forward_kernel(
    E,
    C,
    Inds,
    Bias,
    LSE,
    LA,
    Locks,
    Valids,
    softcap,
    B,
    V,
    D,
    SAMPLE_NUMS,
    BMax,
    stride_eb,
    stride_ed,
    stride_cv,
    stride_cd,
    stride_ib,
    stride_is,
    stride_biasv,
    stride_lse_b,
    stride_vb,
    num_locks,
    # Meta-parameters
    B_BIN,
    HAS_BIAS: tl.constexpr,
    HAS_VALIDS: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_V: tl.constexpr,
    BLOCK_D: tl.constexpr,  #
    GROUP_B: tl.constexpr,  #
    EVEN_D: tl.constexpr,
    HAS_SOFTCAP: tl.constexpr,
    HAS_LA: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    # Per-row log-sum-exp over SAMPLE_NUMS sampled classes.
    #
    # Each program handles BLOCK_B rows of E. For every sample column it
    # gathers the indexed row of C, takes the dot product with the row of E,
    # and folds it into a running (max, sum) pair, so no logits are stored.
    # BLOCK_D must cover the whole feature dimension D (one tile per row).
    # Bias, LA, Locks, Valids and softcap are accepted for signature parity
    # with the dense kernel but are not used here.
    pid = tl.program_id(axis=0)

    offs_b = pid * BLOCK_B + tl.arange(0, BLOCK_B)
    offs_d = tl.arange(0, BLOCK_D)
    row_mask = offs_b < BMax
    d_mask = offs_d[None, :] < D

    # The embedding tile does not depend on the sample column: load it once.
    e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed)
    e = tl.load(e_ptrs, mask=row_mask[:, None] & d_mask, other=0.0).to(tl.float32)

    # Running maximum and running sum of exp(logit - maximum).
    m = tl.full((BLOCK_B,), float("-inf"), dtype=tl.float32)
    denom = tl.zeros((BLOCK_B,), dtype=tl.float32)

    for idx in range(0, SAMPLE_NUMS):
        # Honour the sample-axis stride instead of assuming it is 1.
        inds_ptrs = Inds + offs_b * stride_ib + idx * stride_is
        # Masked-out rows read V, which the classifier mask below rejects.
        inds = tl.load(inds_ptrs, mask=row_mask, other=V)

        c_ptrs = C + (inds[:, None] * stride_cv + offs_d[None, :] * stride_cd)
        c_mask = (inds[:, None] < V) & d_mask
        c = tl.load(c_ptrs, mask=c_mask, other=0.0)

        dot_sum = tl.sum(e * c.to(tl.float32), axis=1)

        if idx > 0:
            # Every column except column 0 gets a log((V - 1) / SAMPLE_NUMS)
            # correction. NOTE(review): presumably column 0 holds the positive
            # item and the rest are uniformly sampled negatives -- confirm.
            dot_sum += tl.log(V - 1.0)
            dot_sum -= tl.log(1.0 * SAMPLE_NUMS)

        # Online log-sum-exp update: rescale the old sum to the new maximum.
        m_new = tl.maximum(m, dot_sum)
        denom = denom * tl.exp(m - m_new) + tl.exp(dot_sum - m_new)
        m = m_new

    lse = m + tl.log(denom)
    # Use the LSE stride that is passed in, and never write past the B entries
    # LSE is allocated with (B can be smaller than BMax when Valids is set).
    lse_ptrs = LSE + offs_b * stride_lse_b
    out_mask = row_mask & (offs_b < B)
    tl.store(lse_ptrs, lse, mask=out_mask)
@overload
def cce_lse_forward_kernel(
    e: torch.Tensor,
    c: torch.Tensor,
    bias: torch.Tensor | None = None,
    valids: torch.Tensor | None = None,
    softcap: float | None = None,
    return_logit_avg: Literal[False] = False,
    item_inds: torch.Tensor | None = None,
) -> torch.Tensor: ...


@overload
def cce_lse_forward_kernel(
    e: torch.Tensor,
    c: torch.Tensor,
    bias: torch.Tensor | None = None,
    valids: torch.Tensor | None = None,
    softcap: float | None = None,
    return_logit_avg: Literal[True] = True,
    item_inds: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ...


@overload
def cce_lse_forward_kernel(
    e: torch.Tensor,
    c: torch.Tensor,
    bias: torch.Tensor | None = None,
    valids: torch.Tensor | None = None,
    softcap: float | None = None,
    return_logit_avg: bool = False,
    item_inds: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor: ...


def cce_lse_forward_kernel(
    e: torch.Tensor,
    c: torch.Tensor,
    bias: torch.Tensor | None = None,
    valids: torch.Tensor | None = None,
    softcap: float | None = None,
    return_logit_avg: bool = False,
    item_inds: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
    """Compute the per-row log-sum-exp of the logits without materialising them.

    :param e: Contiguous 2-D embeddings.
    :param c: 2-D classifier with ``c.shape[1] == e.shape[1]``.
    :param bias: Optional 1-D bias with one entry per row of ``c``.
    :param valids: Optional 1-D index list; when given, one output per entry.
    :param softcap: Optional soft-capping value forwarded to the kernel.
    :param return_logit_avg: Also return the logit-average buffer.
    :param item_inds: Optional 2-D sampled class indices. When given, the
        sampled kernel is launched and the log-sum-exp runs over
        ``item_inds.size(1)`` sampled classes instead of all rows of ``c``.
    :returns: ``lse`` (fp32, one value per row), or ``(lse, logit_avg)`` when
        ``return_logit_avg`` is true.
    """
    # Check constraints.
    assert e.shape[1] == c.shape[1], "Incompatible dimensions"
    assert e.is_contiguous(), "Matrix A must be contiguous"
    if valids is not None:
        assert valids.ndim == 1
        B = valids.numel()
    else:
        B, _ = e.shape

    if bias is not None:
        assert bias.ndim == 1
        assert c.shape[0] == bias.shape[0]

    V, D = c.shape
    # Allocates output. -inf is the identity for log-add-exp accumulation.
    lse = e.new_full((B,), -float("inf"), dtype=torch.float32)

    if item_inds is None:
        # One spin-lock per 128 rows guards the LSE read-modify-write.
        locks = e.new_full(
            (triton.cdiv(B, 128),),
            0,
            dtype=torch.uint32,
        )

        if return_logit_avg:
            logit_avg = e.new_full((V,), 0.0, dtype=torch.float32)
        else:
            logit_avg = None

        # 1D launch kernel where each block gets its own program.
        def grid(META) -> tuple[int]:
            return (triton.cdiv(B, META["BLOCK_B"]) * triton.cdiv(V, META["BLOCK_V"]),)

        _cce_lse_forward_kernel[grid](
            e,
            c,
            bias,
            lse,  #
            logit_avg,
            locks,
            valids,
            softcap,
            B,
            V,
            D,  #
            e.size(0),
            e.stride(0),
            e.stride(1),  #
            c.stride(0),
            c.stride(1),  #
            1 if bias is None else bias.stride(0),
            lse.stride(0),
            1 if valids is None else valids.stride(0),
            num_locks=locks.size(0),
            B_BIN=b_bin_fn(B),
        )
    else:
        # Guarantee unit stride along the sample axis for the kernel
        # (no copy is made when the tensor is already contiguous).
        item_inds = item_inds.contiguous()
        SAMPLE_NUMS = item_inds.size(1)

        if return_logit_avg:
            # NOTE(review): the sampled kernel does not appear to write this
            # buffer, so it is returned as zeros -- confirm that is intended.
            logit_avg = e.new_full((SAMPLE_NUMS,), 0.0, dtype=torch.float32)
        else:
            logit_avg = None

        # 1D launch kernel where each block of rows gets its own program.
        def grid(META) -> tuple[int]:
            return (triton.cdiv(B, META["BLOCK_B"]),)

        # The sampled kernel loads a full row of D in one tile, so the tile
        # width is the smallest power of two that is >= D.
        BLOCK_D = triton.next_power_of_2(D)

        # NOTE(review): this path addresses rows of ``e`` directly and does not
        # remap them through ``valids`` -- verify behaviour when valids is set.
        _cce_lse_sampled_forward_kernel[grid](
            e,
            c,
            item_inds,
            bias,
            lse,  #
            logit_avg,
            None,  # Locks: each program owns its rows, no locking needed
            valids,
            softcap,
            B,
            V,
            D,  #
            SAMPLE_NUMS,
            e.size(0),
            e.stride(0),
            e.stride(1),  #
            c.stride(0),
            c.stride(1),  #
            item_inds.stride(0),
            item_inds.stride(1),
            1 if bias is None else bias.stride(0),
            lse.stride(0),
            1 if valids is None else valids.stride(0),
            num_locks=None,
            B_BIN=b_bin_fn(B),
            BLOCK_D=BLOCK_D,
        )

    if return_logit_avg:
        assert logit_avg is not None
        return lse, logit_avg
    else:
        return lse
+ :param ignore_index: If an input as a target of this value, it is ignored in the loss computation. + :param softcap: The value for logit softcapping. + :param reduction: The reduction to perform over the loss. Supports "mean", "sum", and "none". + :param shift: When non-zero, the embedding and targets will be shifted along the temporal axis to perform nth-next token prediction. + Specifically, this is used to efficiently compute the following + + ```python + shift_e = e[..., :-shift, :].flatten(0, -2) + shift_targets = targets[..., shift:].flatten() + + loss = F.cross_entropy((shift_e @ c.T), targets) + ``` + + If given a boolean value, False will be treated as zero and True will be treated as one. + + When this value is non-zero or True, e and targets must have shape (..., T, D) and (..., T), respectively. + + Integer values must be in [0, T) +""" + +CCE_OPTS_DOC = [ + """ + :param filter_eps: The threshold value used to determine which locations can be safely ignored + in gradient computation. The default value of "auto" will automatically choose a value + based on the input dtype.""", + """ + :param use_kahan: Uses Kahan summation to increase the precision of CCE's reduction along the vocab axis. This only + makes sense to set to True when filter_eps is None (or is a very very small value).""", +] + + +def add_doc_start(*docstr: str): + def add_doc(fn): + fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") + + return fn + + return add_doc \ No newline at end of file diff --git a/kernels/cut_cross_entropy/indexed_dot.py b/kernels/cut_cross_entropy/indexed_dot.py new file mode 100644 index 000000000..314d00864 --- /dev/null +++ b/kernels/cut_cross_entropy/indexed_dot.py @@ -0,0 +1,158 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. 
import torch
import triton
import triton.language as tl

from cut_cross_entropy.tl_autotune import indexed_dot_autotune
from cut_cross_entropy.tl_utils import b_bin_fn
from cut_cross_entropy.utils import softcapping


def _indexed_neg_dot_forward_kernel(
    E,
    C,
    Inds,
    Bias,
    Valids,
    Out,
    B,
    D,
    V,
    BMax,
    stride_eb,
    stride_ed,
    stride_cv,
    stride_cd,
    stride_ib,
    stride_biasv,
    stride_vb,
    shift,
    B_BIN,
    BLOCK_B: tl.constexpr,
    BLOCK_D: tl.constexpr,
    GROUP_B: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_VALIDS: tl.constexpr,
    EVEN_D: tl.constexpr,
    HAS_SHIFT: tl.constexpr,
):
    """Accumulate ``-(dot(E[row], C[Inds[row]]) + Bias[Inds[row]])`` into ``Out``.

    The launch grid has one program per (row-chunk, feature-chunk) pair. Each
    program computes a partial dot product over its ``BLOCK_D`` slice of the
    feature axis and atomically adds the negated partial result to
    ``Out[row]``, so ``Out`` must be zero-initialised by the caller.

    When ``HAS_VALIDS`` is set, row indices are read from ``Valids``; when
    ``HAS_SHIFT`` is set, the target index is read from ``Inds[row + shift]``.
    """
    # Map the flat program id onto (pid_b, pid_d), walking GROUP_B row-chunks
    # at a time.
    pid = tl.program_id(axis=0)
    num_b_chunks = tl.cdiv(B, BLOCK_B)
    num_d_chunks = tl.cdiv(D, BLOCK_D)
    num_d_in_group = GROUP_B * num_d_chunks
    group_id = pid // num_d_in_group
    first_pid_b = group_id * GROUP_B
    group_size_b = min(num_b_chunks - first_pid_b, GROUP_B)
    pid_b = first_pid_b + ((pid % num_d_in_group) % group_size_b)
    pid_d = (pid % num_d_in_group) // group_size_b

    offs_b = tl.arange(0, BLOCK_B) + pid_b * BLOCK_B
    if HAS_VALIDS:
        # Out-of-range slots get BMax, which the masks below reject.
        offs_b = tl.load(Valids + stride_vb * offs_b, mask=offs_b < B, other=BMax)

    offs_d = tl.arange(0, BLOCK_D) + pid_d * BLOCK_D
    e_ptrs = E + (stride_eb * offs_b[:, None] + stride_ed * offs_d[None, :])

    e_mask = offs_b[:, None] < BMax
    if not EVEN_D:
        e_mask = e_mask & (offs_d[None, :] < D)

    e = tl.load(e_ptrs, mask=e_mask, other=0.0)

    if HAS_SHIFT:
        offs_b = offs_b + shift

    # Masked rows get the sentinel V so that the C and Bias loads are masked.
    inds = tl.load(Inds + stride_ib * offs_b, mask=offs_b < BMax, other=V)

    c_ptrs = C + (inds[:, None] * stride_cv + offs_d[None, :] * stride_cd)

    c_mask = inds[:, None] < V
    if not EVEN_D:
        c_mask = c_mask & (offs_d[None, :] < D)

    c = tl.load(c_ptrs, mask=c_mask, other=0.0)

    # Output slots are always the dense row positions, even with Valids.
    offs_b = tl.arange(0, BLOCK_B) + pid_b * BLOCK_B
    out_ptrs = Out + offs_b
    dot = e.to(tl.float32) * c.to(tl.float32)
    neg_dot = -tl.sum(dot, 1)

    if HAS_BIAS:
        # Every feature-chunk program adds into the same Out slot, so the bias
        # must be contributed by exactly one of them (pid_d == 0). The other
        # chunks load 0.0. Without this the bias would be subtracted once per
        # feature chunk.
        bias = tl.load(
            Bias + inds * stride_biasv,
            mask=(inds < V) & (pid_d == 0),
            other=0.0,
        )
        bias = bias.to(tl.float32)
        neg_dot -= bias

    tl.atomic_add(out_ptrs, neg_dot.to(out_ptrs.dtype.element_ty), mask=offs_b < B)


_indexed_neg_dot_forward_kernel = triton.jit(_indexed_neg_dot_forward_kernel)
_indexed_neg_dot_forward_kernel = triton.heuristics(  # type: ignore
    {
        "EVEN_D": lambda args: args["D"] % args["BLOCK_D"] == 0,
        "HAS_BIAS": lambda args: args["Bias"] is not None,
        "HAS_VALIDS": lambda args: args["Valids"] is not None,
        "HAS_SHIFT": lambda args: args["shift"] != 0,
        "GROUP_B": lambda args: 8,
    }
)(_indexed_neg_dot_forward_kernel)
_indexed_neg_dot_forward_kernel = indexed_dot_autotune()(_indexed_neg_dot_forward_kernel)  # type: ignore


def indexed_neg_dot_forward_kernel(
    e: torch.Tensor,
    c: torch.Tensor,
    inds: torch.Tensor,
    bias: torch.Tensor | None = None,
    shift: int = 0,
    valids: torch.Tensor | None = None,
    softcap: float | None = None,
    out_dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Return the negated logit of the indexed class for every row.

    :param e: 2-D embeddings.
    :param c: 2-D classifier matrix with the same feature size as ``e``.
    :param inds: 1-D class indices, one per row of ``e``.
    :param bias: Optional per-class bias added to the logit before negation.
    :param shift: Offset applied to the row index when reading ``inds``.
    :param valids: Optional 1-D row-index tensor; when given, the output has
        ``valids.size(0)`` entries.
    :param softcap: When not None, ``softcapping`` is applied to the result.
    :param out_dtype: Output dtype; defaults to ``e.dtype``.
    :return: 1-D tensor of negated logits, accumulated in float32 and then
        cast to ``out_dtype``.
    """
    assert inds.ndim == 1
    assert e.ndim == 2
    assert c.ndim == 2
    assert inds.size(0) == e.size(0)
    assert c.size(1) == e.size(1)

    if valids is not None:
        assert valids.ndim == 1
        B = valids.size(0)
    else:
        B = e.size(0)

    # Zero-initialised because the kernel accumulates with atomic adds.
    out = e.new_zeros((B,), dtype=torch.float32)

    def grid(META) -> tuple[int]:
        return (triton.cdiv(B, META["BLOCK_B"]) * triton.cdiv(e.size(1), META["BLOCK_D"]),)

    _indexed_neg_dot_forward_kernel[grid](
        e,
        c,
        inds,
        bias,
        valids,
        out,
        B,
        e.size(1),
        c.size(0),
        e.size(0),
        e.stride(0),
        e.stride(1),
        c.stride(0),
        c.stride(1),
        inds.stride(0),
        1 if bias is None else bias.stride(0),
        1 if valids is None else valids.stride(0),
        shift=shift,
        B_BIN=b_bin_fn(B),
    )

    if softcap is not None:
        out = softcapping(out, softcap)

    if out_dtype is None:
        out_dtype = e.dtype

    out = out.to(out_dtype)

    return out


# diff --git a/kernels/cut_cross_entropy/linear_cross_entropy.py
# b/kernels/cut_cross_entropy/linear_cross_entropy.py
# new file mode 100644, index 000000000..ab66cd8a8
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
# This software includes modifications
import enum
import platform
from enum import auto
from typing import TYPE_CHECKING

import torch
import torch.nn as nn

from cut_cross_entropy.constants import IGNORE_INDEX
from cut_cross_entropy.doc import CCE_OPTS_DOC, LINEAR_CROSS_ENTROPY_DOC, add_doc_start
from cut_cross_entropy.torch_compile import torch_compile_linear_cross_entropy


class LinearCrossEntropyImpl(enum.IntEnum):
    # Available back-ends; the lower-cased member name is the string
    # accepted by ``linear_cross_entropy(impl=...)``.
    CCE = auto()
    TORCH_COMPILE = auto()
    CCE_EXACT = auto()


PLATFORM_SYSTEM = platform.system()

# The CCE implementation is only imported off macOS (or for type checkers);
# on macOS the default back-end falls back to torch.compile.
if TYPE_CHECKING or PLATFORM_SYSTEM != "Darwin":
    from cut_cross_entropy.cce import cce_linear_cross_entropy

    LCE_IMPL_DEFAULT = LinearCrossEntropyImpl.CCE
else:
    cce_linear_cross_entropy = None
    LCE_IMPL_DEFAULT = LinearCrossEntropyImpl.TORCH_COMPILE


@add_doc_start(LINEAR_CROSS_ENTROPY_DOC)
@add_doc_start(*(doc_str + " Only valid for the cce implementation.\n" for doc_str in CCE_OPTS_DOC))
def linear_cross_entropy(
    e: torch.Tensor,
    c: torch.Tensor,
    targets: torch.Tensor,
    bias: torch.Tensor | None = None,
    ignore_index: int = IGNORE_INDEX,
    softcap: float | None = None,
    reduction: str = "mean",
    shift: bool | int = 0,
    filter_eps: float | str | None = "auto",
    use_kahan: bool = False,
    impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
) -> torch.Tensor:
    """
    :param impl: The linear cross entropy implementation to use. Currently supports cce, torch_compile, and cce_exact.
    """

    # Enum members are addressed by their lower-cased name from here on.
    if isinstance(impl, LinearCrossEntropyImpl):
        impl = impl.name.lower()

    if isinstance(shift, int):
        out_of_range = shift < 0 or shift >= targets.size(-1)
        if out_of_range:
            raise ValueError(f"Shift must be in the range [0, {targets.size(-1)}). Got {shift}.")

    if impl == "torch_compile":
        return torch_compile_linear_cross_entropy(
            e, c, targets, bias, ignore_index, softcap, reduction, shift
        )

    if impl != "cce" and impl != "cce_exact":
        raise NotImplementedError(f"{impl} is not implemented.")

    # Both remaining back-ends are CCE variants.
    if platform.system() == "Darwin":
        raise RuntimeError(
            "CCE does not support MacOS. Please use torch_compile when running on MacOS instead."
        )

    if impl == "cce_exact":
        # The exact variant disables gradient filtering and uses Kahan summation.
        filter_eps, use_kahan = None, True

    assert cce_linear_cross_entropy is not None
    return cce_linear_cross_entropy(
        e, c, targets, bias, ignore_index, softcap, reduction, shift, filter_eps, use_kahan
    )


class LinearCrossEntropy(nn.Module):
    # Module wrapper that stores the loss options and forwards them to
    # ``linear_cross_entropy`` on every call.
    def __init__(
        self,
        ignore_index: int = IGNORE_INDEX,
        softcap: float | None = None,
        reduction: str = "mean",
        shift: bool | int = 0,
        filter_eps: float | str | None = "auto",
        use_kahan: bool = False,
        impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
    ):
        super().__init__()
        self.impl = impl
        self.ignore_index = ignore_index
        self.softcap = softcap
        self.reduction = reduction
        self.shift = shift
        self.filter_eps = filter_eps
        self.use_kahan = use_kahan

    def forward(
        self,
        e: torch.Tensor,
        c: torch.Tensor,
        targets: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Bundle the stored options and delegate to the functional form.
        options = dict(
            bias=bias,
            ignore_index=self.ignore_index,
            softcap=self.softcap,
            reduction=self.reduction,
            shift=self.shift,
            filter_eps=self.filter_eps,
            use_kahan=self.use_kahan,
            impl=self.impl,
        )
        return linear_cross_entropy(e, c, targets, **options)


# ---------------------------------------------------------------------------
# diff --git a/kernels/cut_cross_entropy/tl_autotune.py b/kernels/cut_cross_entropy/tl_autotune.py
# new file mode 100644, index 000000000..452886440
# ---------------------------------------------------------------------------
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+# This software includes modifications +import functools +import heapq +import os +from dataclasses import dataclass, field +from typing import Any, Callable + +import torch +import triton +from triton import Config, cdiv +from triton.runtime import autotuner, driver +from triton.testing import ( + get_dram_gbps, + get_max_simd_tflops, + get_max_tensorcore_tflops, + nvsmi, +) + +_AUTOTUNE: bool = os.getenv("CCE_AUTOTUNE", "0") != "0" + + +@dataclass +class NoneSupportRestorer: + reset_idx: list[int] + restore_idx: list[int] + _restore_copies: list[torch.Tensor | None] = field(default_factory=list, init=False) + + def pre_hook(self, args: list[torch.Tensor | None | Any]) -> None: + for i in self.reset_idx: + v = args[i] + if v is not None: + assert isinstance(v, torch.Tensor) + v.zero_() + + for i in self.reset_idx: + v = args[i] + if v is not None: + assert isinstance(v, torch.Tensor) + self._restore_copies.append(v.clone()) + else: + self._restore_copies.append(None) + + def post_hook(self, args: list[torch.Tensor | None | Any], _exception) -> None: + for j, i in enumerate(self.reset_idx): + v = args[i] + if v is not None: + old_v = self._restore_copies[j] + assert isinstance(v, torch.Tensor) + assert old_v is not None + + v.copy_(old_v) + + self._restore_copies = [] + + +@functools.wraps(triton.autotune) +def _cce_autotune(*args, **kwargs) -> Callable[..., autotuner.Autotuner]: + def decorator(fn): + reset_idx = [] + restore_idx = [] + arg_names = fn.arg_names + reset_to_zero = kwargs.pop("reset_to_zero", None) + if reset_to_zero is not None: + reset_idx = [arg_names.index(k) for k in reset_to_zero] + + restore_value = kwargs.pop("restore_value", None) + if restore_value is not None: + restore_idx = [arg_names.index(k) for k in restore_value] + + restorer = NoneSupportRestorer(reset_idx, restore_idx) + if len(reset_idx) > 0: + kwargs["pre_hook"] = restorer.pre_hook + + if len(restore_idx) > 0: + kwargs["post_hook"] = restorer.post_hook + + return 
triton.autotune(*args, **kwargs)(fn) + + return decorator + + +@functools.lru_cache() +def get_clock_rate_in_khz(): + try: + return nvsmi(["clocks.max.sm"])[0] * 1e3 + except FileNotFoundError: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + return pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM) * 1e3 + + +def get_tensorcore_tflops(device, num_ctas, num_warps, dtype): + """return compute throughput in TOPS""" + total_warps = num_ctas * min(num_warps, 4) + num_subcores = ( + driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + ) # on recent GPUs + tflops = ( + min(num_subcores, total_warps) + / num_subcores + * get_max_tensorcore_tflops(dtype, get_clock_rate_in_khz(), device) + ) + return tflops + + +def get_simd_tflops(device, num_ctas, num_warps, dtype): + """return compute throughput in TOPS""" + total_warps = num_ctas * min(num_warps, 4) + num_subcores = ( + driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + ) # on recent GPUs + tflops = ( + min(num_subcores, total_warps) + / num_subcores + * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device) + ) + return tflops + + +def get_tflops(device, num_ctas, num_warps, dtype): + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8 and dtype == torch.float32: + return get_simd_tflops(device, num_ctas, num_warps, dtype) + return get_tensorcore_tflops(device, num_ctas, num_warps, dtype) + + +def early_config_prune( + configs, + named_args, + *, + shared_memory_factor: float = 1.0, + max_num_warps: int | None = None, + **kwargs, +): + device = torch.cuda.current_device() + capability = torch.cuda.get_device_capability() + # BLOCK_B, BLOCK_V, BLOCK_D, SPLIT_K, num_warps, num_stages + dtsize = named_args["E"].element_size() + + if max_num_warps is not None: + configs = [config for config in configs if config.num_warps <= max_num_warps] + + # 1. 
make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_B, BLOCK_V, BLOCK_D, num_stages = ( + kw["BLOCK_B"], + kw["BLOCK_V"], + kw["BLOCK_D"], + config.num_stages, + ) + + max_shared_memory = driver.active.utils.get_device_properties(device)["max_shared_mem"] + required_shared_memory = ( + shared_memory_factor * (BLOCK_B + BLOCK_V) * BLOCK_D * num_stages * dtsize + ) + if required_shared_memory > max_shared_memory: + continue + + pruned_configs.append(config) + + configs = pruned_configs + + # group configs by (BLOCK_B,_N,_K, num_warps) + configs_map = {} + for config in configs: + kw = config.kwargs + BLOCK_B, BLOCK_V, BLOCK_D, num_warps, num_stages = ( + kw["BLOCK_B"], + kw["BLOCK_V"], + kw["BLOCK_D"], + config.num_warps, + config.num_stages, + ) + + key = (BLOCK_B, BLOCK_V, BLOCK_D, num_warps) + if key in configs_map: + configs_map[key].append((config, num_stages)) + else: + configs_map[key] = [(config, num_stages)] + + pruned_configs = [] + for k, v in configs_map.items(): + BLOCK_B, BLOCK_V, BLOCK_D, num_warps = k + if capability[0] >= 8: + # compute cycles (only works for ampere GPUs) + mmas = BLOCK_B * BLOCK_V * BLOCK_D / (16 * 8 * 16) + mma_cycles = mmas / min(4, num_warps) * 8 + + ldgsts_latency = 300 # Does this matter? 
+ optimal_num_stages = ldgsts_latency / mma_cycles + + # nearest stages, prefer large #stages + nearest = heapq.nsmallest( + 2, + v, + key=lambda x: 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 + else x[1] - optimal_num_stages, + ) + + for n in nearest: + pruned_configs.append(n[0]) + else: # Volta & Turing only supports num_stages <= 2 + random_config = v[0][0] + random_config.num_stages = 2 + pruned_configs.append(random_config) + return pruned_configs + + +def _total_ops_fn(B, V, D) -> float: + return 2 * B * V * D + 10 * B * V + + +def _total_store_fn(B, V, D, dtsize, num_cta_b, num_cta_v): + return B * dtsize + + +def estimate_matmul_time( + # backend, device, + num_warps, + num_stages, # + E, + B, + V, + D, # + BLOCK_B, + BLOCK_V, + BLOCK_D, + debug=False, + total_ops_fn=_total_ops_fn, + total_store_fn=_total_store_fn, + **kwargs, # +): + """return estimated running time in ms + = max(compute, loading) + store""" + device = torch.cuda.current_device() + dtype = E.dtype + dtsize = E.element_size() + + num_cta_b = cdiv(B, BLOCK_B) + num_cta_v = cdiv(V, BLOCK_V) + num_ctas = num_cta_b * num_cta_v + + # If the input is smaller than the block size + B, V = max(B, BLOCK_B), max(V, BLOCK_V) + + # time to compute + total_ops = total_ops_fn(B, V, D) + total_ops = total_ops / (1024 * 1024 * 1024) # GOPS + tput = get_tflops(device, num_ctas, num_warps, dtype) + compute_ms = total_ops / tput + + # time to load data + num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"] + active_cta_ratio = min(1, num_ctas / num_sm) + active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate + active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5% + dram_bw = get_dram_gbps(device) * ( + active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05 + ) # in GB/s + l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) 
+ # assume 80% of (following) loads are in L2 cache + load_a_dram = B * D * dtsize * (1 + 0.2 * (num_cta_v - 1)) + load_a_l2 = B * D * dtsize * 0.8 * (num_cta_v - 1) + load_b_dram = V * D * dtsize * (1 + 0.2 * (num_cta_b - 1)) + load_b_l2 = V * D * dtsize * 0.8 * (num_cta_b - 1) + # total + total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB + total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) + # loading time in ms + load_ms = total_dram / dram_bw + total_l2 / l2_bw + + # estimate storing time + store_bw = dram_bw * 0.4 # :o + store_dram = total_store_fn(B, V, D, dtsize, num_cta_b, num_cta_v) / (1024 * 1024) + store_ms = store_dram / store_bw + + total_time_ms = max(compute_ms, load_ms) + store_ms + if debug: + print( + f"{BLOCK_B=}, {BLOCK_V=}, {BLOCK_D=}, {num_warps=}, {num_stages=}, " + f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, " + f"loading time: {load_ms}ms, store time: {store_ms}ms, " + f"Activate CTAs: {active_cta_ratio*100}%" + ) + return total_time_ms + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config( + { + "BLOCK_B": block_m, + "BLOCK_V": block_n, + "BLOCK_D": block_k, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + return configs + + +def get_autotune_config(): + return [ + # basic configs for compute-bound matmuls + Config( + {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 128}, + num_stages=2, + num_warps=4, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 256, "BLOCK_D": 32}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_B": 256, "BLOCK_V": 128, "BLOCK_D": 32}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_B": 256, "BLOCK_V": 64, "BLOCK_D": 32}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 64, "BLOCK_V": 256, "BLOCK_D": 32}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 128, 
"BLOCK_V": 128, "BLOCK_D": 32}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 32}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 32}, + num_stages=4, + num_warps=8, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 64, "BLOCK_D": 32}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 64, "BLOCK_V": 128, "BLOCK_D": 32}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 32, "BLOCK_D": 32}, + num_stages=4, + num_warps=4, + ), + Config({"BLOCK_B": 64, "BLOCK_V": 32, "BLOCK_D": 32}, num_stages=5, num_warps=2), + # good for int8 + Config( + {"BLOCK_B": 128, "BLOCK_V": 256, "BLOCK_D": 128}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 256, "BLOCK_D": 128}, + num_stages=3, + num_warps=16, + ), + Config( + {"BLOCK_B": 256, "BLOCK_V": 128, "BLOCK_D": 128}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_B": 256, "BLOCK_V": 128, "BLOCK_D": 128}, + num_stages=3, + num_warps=16, + ), + Config( + {"BLOCK_B": 256, "BLOCK_V": 64, "BLOCK_D": 128}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 64, "BLOCK_V": 256, "BLOCK_D": 128}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 128}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 64, "BLOCK_D": 64}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 64, "BLOCK_V": 128, "BLOCK_D": 64}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 32, "BLOCK_D": 64}, + num_stages=4, + num_warps=4, + ), + Config({"BLOCK_B": 64, "BLOCK_V": 32, "BLOCK_D": 64}, num_stages=5, num_warps=2), + ] + get_configs_io_bound() + + +def _heuristics_from_config(config: Config) -> Callable[..., autotuner.Heuristics]: + return triton.heuristics({k: (lambda args, _v=v: _v) for k, v in config.all_kwargs().items()}) + + +def _cce_forward_best_config() -> Config: + return 
Config(dict(BLOCK_B=256, BLOCK_V=128, BLOCK_D=32), num_warps=8, num_stages=3) + + +def _cce_sampled_forward_best_config() -> Config: + # return Config(dict(BLOCK_B=32, BLOCK_V=128), num_warps=2, num_stages=3) + # return Config(dict(BLOCK_B=128, BLOCK_V=128), num_warps=16, num_stages=4) + return Config(dict(BLOCK_B=32, BLOCK_V=128), num_warps=16, num_stages=4) + + + +def cce_forward_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]: + if _AUTOTUNE: + return _cce_autotune( + configs=get_autotune_config(), + key=["V", "D", "B_BIN"], + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": estimate_matmul_time, + "top_k": 10, + }, + restore_value=["LSE"], + reset_to_zero=["LA"], + ) + else: + return _heuristics_from_config(_cce_forward_best_config()) + + +def cce_sampled_forward_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]: + if _AUTOTUNE: + return _cce_autotune( + configs=get_autotune_config(), + key=["V", "D", "B_BIN"], + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": estimate_matmul_time, + "top_k": 10, + }, + restore_value=["LSE"], + reset_to_zero=["LA"], + ) + else: + return _heuristics_from_config(_cce_sampled_forward_best_config()) + + +def _bw_total_ops_fn(B, V, D) -> float: + return 2 * B * V * D + 6 * B * V + 0.2 * (2 * B * V * D + 2 * B * V * D) + + +def _bw_total_store_fn(B, V, D, dtsize, num_cta_b, num_cta_v): + return 0.2 * (num_cta_v * B * D * dtsize + num_cta_b * D * V * dtsize) + + +def _cce_backward_best_config() -> Config: + return Config(dict(BLOCK_B=128, BLOCK_V=128, BLOCK_D=32), num_warps=4, num_stages=4) + + +def _cce_sampled_backward_best_config() -> Config: + # return Config(dict(BLOCK_B=32, BLOCK_V=128), num_warps=2, num_stages=5) + # return Config(dict(BLOCK_B=128, BLOCK_V=128), num_warps=16, num_stages=4) + return Config(dict(BLOCK_B=32, BLOCK_V=128), num_warps=16, num_stages=4) + +def cce_backward_autotune() -> Callable[..., 
autotuner.Autotuner | autotuner.Heuristics]: + if _AUTOTUNE: + return _cce_autotune( + configs=get_autotune_config(), + key=["V", "D", "B_BIN"], + prune_configs_by={ + "early_config_prune": functools.partial( + early_config_prune, shared_memory_factor=2.0 + ), + "perf_model": functools.partial( + estimate_matmul_time, + total_ops_fn=_bw_total_ops_fn, + total_store_fn=_bw_total_store_fn, + ), + "top_k": 5, + }, + reset_to_zero=["dE", "dC", "dEC", "dCC", "dBias"], + ) + else: + return _heuristics_from_config(_cce_backward_best_config()) + + +def cce_sampled_backward_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]: + if _AUTOTUNE: + return _cce_autotune( + configs=get_autotune_config(), + key=["V", "D", "B_BIN"], + prune_configs_by={ + "early_config_prune": functools.partial( + early_config_prune, shared_memory_factor=2.0 + ), + "perf_model": functools.partial( + estimate_matmul_time, + total_ops_fn=_bw_total_ops_fn, + total_store_fn=_bw_total_store_fn, + ), + "top_k": 5, + }, + reset_to_zero=["dE", "dC", "dEC", "dCC", "dBias"], + ) + else: + return _heuristics_from_config(_cce_sampled_backward_best_config()) + + +def _indexed_dot_best_config() -> Config: + return Config(dict(BLOCK_B=128, BLOCK_D=256), num_warps=16, num_stages=4) + + +def _indexed_dot_all_configs() -> list[Config]: + return [ + Config( + dict( + BLOCK_B=128, + BLOCK_D=128, + ), + num_warps=4, + num_stages=4, + ), + Config( + dict( + BLOCK_B=128, + BLOCK_D=128, + ), + num_warps=8, + num_stages=4, + ), + Config( + dict( + BLOCK_B=256, + BLOCK_D=256, + ), + num_warps=16, + num_stages=4, + ), + Config( + dict( + BLOCK_B=256, + BLOCK_D=128, + ), + num_warps=16, + num_stages=4, + ), + Config( + dict( + BLOCK_B=128, + BLOCK_D=256, + ), + num_warps=16, + num_stages=4, + ), + ] + + +def indexed_dot_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]: + if _AUTOTUNE: + return _cce_autotune( + configs=_indexed_dot_all_configs(), + key=["D", "B_BIN"], + 
reset_to_zero=["Out"], + ) + else: + return _heuristics_from_config(_indexed_dot_best_config()) \ No newline at end of file diff --git a/kernels/cut_cross_entropy/tl_utils.py b/kernels/cut_cross_entropy/tl_utils.py new file mode 100644 index 000000000..e35411c6f --- /dev/null +++ b/kernels/cut_cross_entropy/tl_utils.py @@ -0,0 +1,90 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +import triton +import triton.language as tl +from triton.language.extra import libdevice as tl_libdevice + + +@triton.jit +def tl_and_reduce_fn(a, b): + return a & b + + +@triton.jit +def tl_tanh(a: tl.tensor) -> tl.tensor: + return tl_libdevice.tanh(a) + + +@triton.jit +def tl_log1p(a: tl.tensor) -> tl.tensor: + return tl_libdevice.log1p(a) + + +@triton.jit +def tl_softcapping(v: tl.tensor, softcap: float) -> tl.tensor: + return tl_tanh(v / softcap) * softcap + + +@triton.jit +def tl_softcapping_grad(dv: tl.tensor, v: tl.tensor, softcap: float) -> tl.tensor: + v = v / softcap + return dv * (1 - v * v) + + +@triton.jit +def tl_logaddexp(a, b) -> tl.tensor: + minx = tl.minimum(a, b) + mx = tl.maximum(a, b) + return tl_log1p(tl.exp(minx - mx)) + mx + + +@triton.jit +def tl_2sum(a: tl.tensor, b: tl.tensor) -> tuple[tl.tensor, tl.tensor]: + s = a + b + + a_prime = s - b + b_prime = s - a_prime + + delta_a = a - a_prime + delta_b = b - b_prime + + t = delta_a + delta_b + return s, t + + +@triton.jit +def tl_lock_kahan_sum(ptrs, c_ptrs, v, mask, lock_ptr): + while tl.atomic_cas(lock_ptr, 0, 1) == 1: + pass + + s = tl.load(ptrs, mask=mask, other=0.0, eviction_policy="evict_last") + c = tl.load(c_ptrs, mask=mask, other=0.0, eviction_policy="evict_last") + + s, c = tl_2sum(s, c + v) + + tl.store(ptrs, s, mask=mask, eviction_policy="evict_last") + tl.store(c_ptrs, c, mask=mask, eviction_policy="evict_last") + + tl.debug_barrier() + tl.atomic_xchg(lock_ptr, 0) + + +@triton.jit +def tl_lock_add(ptrs, v, mask, lock_ptr): + while tl.atomic_cas(lock_ptr, 0, 1) == 1: + pass + + cur_v = 
tl.load(ptrs, mask=mask, other=0.0, eviction_policy="evict_last") + new_v = v + cur_v + tl.store(ptrs, new_v, mask=mask, eviction_policy="evict_last") + + tl.debug_barrier() + tl.atomic_xchg(lock_ptr, 0) + + +def b_bin_fn(b: int) -> int: + if b >= 1024: + return 1024 + elif b <= 128: + return 128 + else: + return 512 \ No newline at end of file diff --git a/kernels/cut_cross_entropy/torch_compile.py b/kernels/cut_cross_entropy/torch_compile.py new file mode 100644 index 000000000..a2e7d124b --- /dev/null +++ b/kernels/cut_cross_entropy/torch_compile.py @@ -0,0 +1,82 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +import torch +import torch.nn.functional as F + +from cut_cross_entropy.constants import IGNORE_INDEX +from cut_cross_entropy.doc import LINEAR_CROSS_ENTROPY_DOC, add_doc_start +from cut_cross_entropy.utils import ( + _build_flat_valids, + handle_reduction_none, + softcapping, +) + + +@torch.compile(fullgraph=True, dynamic=True) +def torch_compile_linear_cross_entropy_apply( + e: torch.Tensor, + c: torch.Tensor, + targets: torch.Tensor, + bias: torch.Tensor | None = None, + softcap: float | None = None, + *, + ignore_index: int = IGNORE_INDEX, + reduction: str = "mean", +) -> torch.Tensor: + logits = e @ c.T + + if bias is not None: + logits = logits + bias + + if softcap is not None: + logits = softcapping(logits, softcap) + + loss = F.cross_entropy(logits.float(), targets, ignore_index=ignore_index, reduction=reduction) + + return loss + + +@add_doc_start(LINEAR_CROSS_ENTROPY_DOC) +def torch_compile_linear_cross_entropy( + e: torch.Tensor, + c: torch.Tensor, + targets: torch.Tensor, + bias: torch.Tensor | None = None, + ignore_index: int = IGNORE_INDEX, + softcap: float | None = None, + reduction: str = "mean", + shift: bool | int = 0, +) -> torch.Tensor: + assert e.size()[0:-1] == targets.size() + assert e.size(-1) == c.size(1) + + orig_b_size = targets.size() + e = e.contiguous() + targets = targets.contiguous() + + shift = int(shift) + 
valids = _build_flat_valids(targets, ignore_index, shift) + + e = e.flatten(0, -2) + targets = targets.flatten() + + if valids is not None: + e = e[valids] + targets = targets[(valids + shift) if shift != 0 else valids] + + loss = torch_compile_linear_cross_entropy_apply( + e, + c, + targets, + bias, + softcap, + ignore_index=ignore_index, + reduction=reduction, + ) + + if reduction == "none": + loss = handle_reduction_none(orig_b_size, valids, shift, loss) + + if shift != 0: + loss = loss[..., shift:] + + return loss \ No newline at end of file diff --git a/kernels/cut_cross_entropy/utils.py b/kernels/cut_cross_entropy/utils.py new file mode 100644 index 000000000..4e95e1a0e --- /dev/null +++ b/kernels/cut_cross_entropy/utils.py @@ -0,0 +1,55 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +import torch + + +@torch.compile(fullgraph=True, dynamic=True) +def softcapping(logits: torch.Tensor, softcap: float) -> torch.Tensor: + return torch.tanh(logits / softcap) * softcap + + +def _handle_eps(filter_eps: float | str | None, dtype: torch.dtype) -> float | None: + match filter_eps: + case None: + return None + case float(): + return filter_eps + case "auto": + return torch.finfo(dtype).eps / 32 + case _: + raise RuntimeError(f"Unknown eps {filter_eps=}") + + +def _build_flat_valids( + targets: torch.Tensor, + ignore_index: int, + shift: int, +) -> torch.Tensor | None: + if shift != 0: + targets = targets[..., shift:] + else: + targets = targets.flatten() + + valids = (targets != ignore_index).nonzero().to(torch.int32) + + if shift == 0: + assert valids.size(1) == 1 + return valids.squeeze(1) if valids.numel() != targets.numel() else None + + for i in range(targets.ndim - 1): + valids[:, i] *= targets.stride(i) + + assert targets.stride(-1) == 1 + + return valids.sum(1) + + +def handle_reduction_none( + batch_shape: torch.Size, valids: torch.Tensor | None, shift: int, loss: torch.Tensor +) -> torch.Tensor: + if valids is None: + return loss.view(batch_shape) + 
+ full_loss = loss.new_zeros((batch_shape.numel(),)) + full_loss[(valids + shift) if shift != 0 else valids] = loss + + return full_loss.view(batch_shape) \ No newline at end of file diff --git a/kernels/fused_linear_cross_entropy/__init__.py b/kernels/fused_linear_cross_entropy/__init__.py new file mode 100644 index 000000000..917dfc305 --- /dev/null +++ b/kernels/fused_linear_cross_entropy/__init__.py @@ -0,0 +1 @@ +from fused_linear_cross_entropy.fused_linear_ce_loss import LigerFusedLinearCrossEntropyFunction \ No newline at end of file diff --git a/kernels/fused_linear_cross_entropy/fused_linear_ce_loss.py b/kernels/fused_linear_cross_entropy/fused_linear_ce_loss.py new file mode 100644 index 000000000..3c1784f8d --- /dev/null +++ b/kernels/fused_linear_cross_entropy/fused_linear_ce_loss.py @@ -0,0 +1,543 @@ +#Liger-Kernel/src/liger_kernel/ops/fused_linear_cross_entropy.py +import torch +import triton +import triton.language as tl +try: + # typical import path with dispatch available + from triton.language.extra.libdevice import tanh +except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import tanh + + + + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning + + +@triton.jit +def liger_cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + weight_ptr, + loss_ptr, + z_loss_ptr, + loss_stride, + n_cols, + n_non_ignore, + sum_non_ignore_weight, + weight_sum, + ignore_index, + lse_square_scale: tl.constexpr, + label_smoothing: tl.constexpr, + reduction: tl.constexpr, # set it as constexpr since reduction is always known 
at compile time + softcap, + RETURN_Z_LOSS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_SOFTCAPPING: tl.constexpr, +): + """ + This kernel computes both cross entropy loss and the gradient of the input. + We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + weight_ptr: Pointer to weight tensor. + loss_ptr: Pointer to tensor to store the loss. + z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. + loss_stride (int): The stride of the loss tensor. + n_cols (int): The number of columns in the input tensor. + n_non_ignore (flaot): The number of non-ignored elements in the batch. + sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch. + weight_sum (float): The sum of weight tensor. + ignore_index (int): The index to ignore in the target. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. + reduction (str): The string for the reduction to apply + softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). + RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. + BLOCK_SIZE (int): The block size for Triton operations. + HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. + HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. 
+ """ + + # https://github.com/triton-lang/triton/issues/1058 + # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64 + program_id = tl.program_id(0).to(tl.int64) + + # 1. Load Y_ptr first because if the target is ignore_index, we can return right away + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + # 2. locate the start index + X_ptr += program_id * X_stride + + if y == ignore_index: + # set all X_ptr as 0 + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) + return + + loss_ptr += program_id * loss_stride + if RETURN_Z_LOSS: + z_loss_ptr += program_id * loss_stride + + if HAS_WEIGHT: + weight_y = tl.load(weight_ptr + y).cast(tl.float32) + + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) + # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. use the notation from the paper + d = 0.0 # d is the sum. 
use the notation from the paper + ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation + if HAS_SOFTCAPPING: + ori_X_y = softcap * tanh(ori_X_y / softcap) + + # Label smoothing is a general case of normal cross entropy + # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 + scaled_x_sum = 0.0 + eps = label_smoothing / n_cols + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + X_block = softcap * tanh(X_block / softcap) + block_max = tl.max(X_block) + if label_smoothing > 0: + # scale X beforehand to avoid overflow + if HAS_WEIGHT: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)) + else: + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X))))) + # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) + # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d + lse = m + tl.log(d) + + # 4. 
[Online Softmax] Second pass: compute gradients + # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # For label smoothing: + # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + # With Z loss: + # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y + # dx_y = dx_i - (1 - label_smoothing) / N + # For 'sum' reduction, no normalization is applied: + # dx_y = softmax(x_y) - 1 + # dx_i = softmax(x_i), for i ≠ y + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + intermediate = tanh(X_block / softcap) + X_block = softcap * intermediate + + if not HAS_WEIGHT: + # softmax(x_i) + X_block = tl.exp(X_block - m) / d + # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) + X_block += 2 * lse_square_scale * lse * X_block + # smoothing term + X_block += -eps + # special handle dx_y + X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) + # reduction scale + if reduction == "mean": + X_block = X_block / n_non_ignore + else: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + softmax_X = tl.exp(X_block - m) / d + # derivative of original_loss + dloss_ori = (1 - label_smoothing) * softmax_X + # specially handle dx_y + dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)) + dloss_ori = dloss_ori * weight_y + # derivative of smooth_loss + dloss_smooth = eps * (-weight_block + softmax_X * weight_sum) + # derivative of z-loss + dz_loss = 2 * lse_square_scale * lse * softmax_X + # reduction scale + if 
reduction == "mean": + dloss_ori = dloss_ori / sum_non_ignore_weight + dloss_smooth = dloss_smooth / sum_non_ignore_weight + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. + dz_loss = dz_loss / n_non_ignore + # derivative of total_loss + X_block = dloss_ori + dloss_smooth + dz_loss + + # chain rule softcapping + # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) + if HAS_SOFTCAPPING: + X_block = X_block * (1 - intermediate * intermediate) + + tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in + # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34 + tl.debug_barrier() + + # 5. Calculate the loss + + # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) + # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) + # = X_y - m - log d = X_y - lse + # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 + # So we can safely calculate log (softmax(X_y)) without overflow + loss = lse - ori_X_y + if HAS_WEIGHT: + loss = weight_y * loss + + # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 + # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 + # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 + if label_smoothing > 0: + if HAS_WEIGHT: + smooth_loss = scaled_x_sum + 
eps * lse * weight_sum + else: + smooth_loss = scaled_x_sum + label_smoothing * lse + loss = loss * (1 - label_smoothing) + smooth_loss + + # An auxiliary loss, z_loss + # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html + z_loss = lse_square_scale * lse * lse + # Normalize the loss by the number of non-ignored elements if reduction is "mean" + if reduction == "mean": + if HAS_WEIGHT: + loss = loss / sum_non_ignore_weight + else: + loss = loss / n_non_ignore + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. + z_loss = z_loss / n_non_ignore + loss += z_loss + + tl.store(loss_ptr, loss) + if RETURN_Z_LOSS: + tl.store(z_loss_ptr, z_loss) + + +def fused_linear_cross_entropy_forward( + _input, + weight, + target, + ce_weight=None, + bias=None, + ignore_index=-100, + lse_square_scale=0.0, + label_smoothing=0.0, + reduction="mean", + softcap=None, + return_z_loss=False, + triton_backend=True +): + assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" + device = _input.device + + # inputs have shape: BT x H + # materialized activations will have shape: BT x V + # the increase in memory = BT x V + # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. 
+ # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: + # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor + # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 + BT, H = _input.shape + V = weight.shape[0] + # BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + # inc_factor = triton.cdiv(V, H) # (V + H - 1) // H + # chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor + # num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size + + # inc_factor = (V + H - 1) // H + # chunk_size = (BT + inc_factor - 1) // inc_factor + # num_chunks = (BT + chunk_size - 1) // chunk_size + + chunk_size = 1024 + if triton_backend: + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + num_chunks = triton.cdiv(BT, chunk_size) + else: + num_chunks = (BT + chunk_size - 1) // chunk_size + + grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None + grad_input = torch.zeros_like(_input, device=device) + grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None + # we use fp32 for loss accumulator + loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) + z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None + + # TODO: evaluate how CUDA synchronization caused by .item() affects the speed + target_mask = target != ignore_index + total_n_non_ignore = target_mask.sum().item() + total_sum_non_ignore_ce_weight = total_n_non_ignore + ce_weight_sum = 0.0 + if ce_weight is not None: + assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}" + assert torch.is_floating_point(ce_weight), ( + f"If given, weight has to be a Tensor of floating point dtype. 
Got: {ce_weight.dtype}" + ) + total_sum_non_ignore_ce_weight = ( + torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item() + ) + ce_weight_sum = ce_weight.sum().item() + if ce_weight.stride(-1) != 1: + ce_weight = ce_weight.contiguous() + + for chunk_id in range(num_chunks): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + _input_chunk = _input[start_idx:end_idx] # chunk_size x H + + # when doing matmul, use the original precision + logits_chunk = _input_chunk @ weight.t() # chunk_size x V + if bias is not None: + logits_chunk = logits_chunk + bias + + target_chunk = target[start_idx:end_idx] # chunk_size, + + n_rows = logits_chunk.shape[0] + + + + # ensure _input and target are contiguous + logits_chunk = logits_chunk.contiguous() + target_chunk = target_chunk.contiguous() + + if triton_backend: + # unreduced loss + loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, + z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None + + # Here we calculate the gradient of logits_chunk in place so we can save memory. 
+ liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=logits_chunk, + X_stride=logits_chunk.stride(-2), + Y_ptr=target_chunk, + Y_stride=target_chunk.stride(-1), # always 1 + weight_ptr=ce_weight, + loss_ptr=loss_1d_slice, + z_loss_ptr=z_loss_1d_slice, + loss_stride=loss_1d_slice.stride(-1), # always 1 + n_cols=V, + n_non_ignore=total_n_non_ignore, + sum_non_ignore_weight=total_sum_non_ignore_ce_weight, + weight_sum=ce_weight_sum, + ignore_index=ignore_index, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, + RETURN_Z_LOSS=return_z_loss, + HAS_WEIGHT=True if ce_weight is not None else False, + HAS_SOFTCAPPING=True if softcap is not None else False, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + grad_logits_chunk = logits_chunk # chunk_size x V + else: + y_chunk = torch.nn.functional.softmax(logits_chunk, dim=1) + loss_1d_slice = -torch.log(y_chunk).gather(1, target_chunk.view(-1, 1)) + loss_1d_slice = loss_1d_slice.squeeze(1) + logits_chunk = y_chunk - torch.nn.functional.one_hot(target_chunk, num_classes=V) + logits_chunk = (logits_chunk * (chunk_size / BT)) + grad_logits_chunk = logits_chunk + + loss_1d[start_idx:end_idx] = loss_1d_slice + if return_z_loss: + z_loss_1d[start_idx:end_idx] = z_loss_1d_slice + + grad_input[start_idx:end_idx] = grad_logits_chunk @ weight + + if grad_weight is not None: + torch.addmm( + input=grad_weight, + mat1=logits_chunk.t().to( + _input_chunk.dtype + ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error. 
+ mat2=_input_chunk, + out=grad_weight, + alpha=1.0, + beta=1.0, + ) + + if bias is not None: + torch.add( + input=grad_bias, + other=logits_chunk.sum(dim=0), + out=grad_bias, + alpha=1.0, + ) + + if reduction == "none": + loss = loss_1d + z_loss = z_loss_1d if return_z_loss else None + + else: + loss = torch.sum(loss_1d) if triton_backend else torch.mean(loss_1d) + z_loss = torch.sum(z_loss_1d) if return_z_loss else None + + return loss, z_loss, grad_input, grad_weight, grad_bias + +def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias): + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + BT, H = grad_input.shape + n_rows = BT + # BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + # element_mul_kernel[(n_rows,)]( + # grad_input, + # grad_input.stride(-2), + # grad_output, + # H, + # BLOCK_SIZE=BLOCK_SIZE, + # num_warps=32 if not is_hip() else 16, + # ) + + # handle grad_weight + if grad_weight is not None: + V, H = grad_weight.shape + n_rows = V + + # element_mul_kernel[(n_rows,)]( + # grad_weight, + # grad_weight.stride(-2), + # grad_output, + # H, + # BLOCK_SIZE=BLOCK_SIZE, + # num_warps=32 if not is_hip() else 16, + # ) + + if grad_bias is not None: + V = grad_bias.shape[0] + n_rows = V + + # element_mul_kernel[(n_rows,)]( + # grad_bias, + # grad_bias.stride(-1), + # grad_output, + # 1, + # BLOCK_SIZE=BLOCK_SIZE, + # num_warps=32 if not is_hip() else 16, + # ) + return grad_input, grad_weight, grad_bias + + +class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + _input, + weight, + target, + bias=None, + ce_weight=None, + ignore_index=-100, + 
lse_square_scale=0.0, + label_smoothing=0.0, + reduction="mean", + softcap=None, + return_z_loss: bool = False, + ): + """ + Fusing the last linear layer with cross-entropy loss + Reference: https://github.com/mgmalek/efficient_cross_entropy + + Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding + the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can + compute the gradient at the forward pass. By doing so, we don't have to store the _input and target + for the backward pass. + + _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension. + target: (B*T) where each value is in [0, V-1] + weight: (V, H) where V is the number of classes + bias: (V) where V is the number of classes + ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype + ignore_index: the index to ignore in the target + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. 
+ reduction: reduction to apply + """ + + loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( + _input=_input, + weight=weight, + target=target, + bias=bias, + ce_weight=ce_weight, + ignore_index=ignore_index, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, + return_z_loss=return_z_loss, + ) + # downcast to dtype and store for backward + ctx.save_for_backward( + grad_input.detach(), + grad_weight.detach() if grad_weight is not None else None, + grad_bias.detach() if bias is not None else None, + ) + ctx.return_z_loss = return_z_loss + # return loss, z_loss + return loss + + @staticmethod + # def backward(ctx, grad_output, grad_output2): + def backward(ctx, grad_output): + if ctx.return_z_loss: + del grad_output2 # z_loss is only for logging + (grad_input, grad_weight, grad_bias) = ctx.saved_tensors + grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( + grad_output, grad_input, grad_weight, grad_bias + ) + return ( + grad_input, + grad_weight, + None, + grad_bias, + None, + None, + None, + None, + None, + None, + None, + ) \ No newline at end of file diff --git a/replay/models/nn/sequential/bert4rec/lightning.py b/replay/models/nn/sequential/bert4rec/lightning.py index bad557292..0d57807e7 100644 --- a/replay/models/nn/sequential/bert4rec/lightning.py +++ b/replay/models/nn/sequential/bert4rec/lightning.py @@ -10,6 +10,17 @@ from .dataset import Bert4RecPredictionBatch, Bert4RecTrainingBatch, Bert4RecValidationBatch, _shift_features from .model import Bert4RecModel, CatFeatureEmbedding +import sys +sys.path.append("./kernels") + +try: + from kernels.cut_cross_entropy.cce import CCEParams, LinearCrossEntropyFunction + from kernels.cut_cross_entropy.utils import ( + _build_flat_valids, + _handle_eps, + ) +except ModuleNotFoundError: + print("cut_cross_entropy is not installed. 
CCE / CCE_minus loss cannot be used.") class Bert4Rec(lightning.LightningModule): """ @@ -197,6 +208,8 @@ def _compute_loss(self, batch: Bert4RecTrainingBatch) -> torch.Tensor: loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled elif self._loss_type == "CE": loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled + elif self._loss_type == "CCE": + loss_func = self._compute_loss_cce else: msg = f"Not supported loss type: {self._loss_type}" raise ValueError(msg) @@ -316,6 +329,108 @@ def _compute_loss_ce_sampled( loss = self._loss(logits, labels_flat) return loss + def _compute_loss_cce( + self, + feature_tensors: TensorMap, + positive_labels: torch.LongTensor, + padding_mask: torch.BoolTensor, + tokens_mask: torch.BoolTensor + ) -> torch.Tensor: + """ + Cut Cross-Entropy (CCE) and Cut Cross-Entropy with Negative Sampling (CCE-), + methods that computes the cross-entropy loss, + without materializing the logits for all tokens into global memory. + The method is implemented in custom Triton kernels. + + + Cut Cross Entropy for LLM is presented in + https://arxiv.org/abs/2411.09009 + https://github.com/apple/ml-cross-entropy + """ + + bias = None + ignore_index = -100 + softcap = None + reduction = "mean" + shift = False + filter_eps = "auto" + use_kahan = False + item_inds = None + + labels_mask = (~padding_mask) + tokens_mask + masked_tokens = ~labels_mask + + e = self._model.forward_step(feature_tensors, padding_mask, tokens_mask)[masked_tokens] + e = e.to(torch.float16) + targets = cast( + torch.LongTensor, torch.masked_select(positive_labels, masked_tokens) + ) + c = self._model._head.get_item_embeddings() + + e = e.contiguous() + padding_mask = padding_mask.contiguous() + + assert e.size()[0:-1] == targets.size() + assert e.size(-1) == c.size(1) + if not torch.cuda.is_bf16_supported(): + raise RuntimeError( + "Cut Cross Entropy requires an ampere GPU or newer. 
" + "Consider using torch_compile_linear_cross_entropy for scenarios where one is not available." + ) + + batch_shape = targets.size() + + shift = int(shift) + valids = _build_flat_valids(targets, ignore_index, shift) + + e = e.flatten(0, -2) + targets = targets.flatten() + + if (targets.data_ptr() % 16) != 0: + targets = torch.nn.functional.pad(targets, (0, 1))[:-1] + + assert (targets.data_ptr() % 16) == 0 + + if self._loss_sample_count is not None: + filter_eps = None + n_negative_samples = self._loss_sample_count + vocab_size = self._vocab_size + device = padding_mask.device + + masked_batch_seq_size = targets.size(0) + + if self._negative_sampling_strategy == "global_uniform": + negative_labels = torch.randint( + low=0, + high=vocab_size, + size=(masked_batch_seq_size, n_negative_samples), + dtype=torch.long, + device=device, + ) + + reject_labels_mask = targets.view(-1, 1) == negative_labels + negative_labels[reject_labels_mask] = vocab_size + + item_inds = torch.hstack([targets.view(-1, 1), negative_labels]) + + + params = CCEParams( + targets, + valids, + softcap, + reduction, + _handle_eps(filter_eps, e.dtype), + shift, + batch_shape, + use_kahan, + item_inds + ) + + loss = self._loss.apply(e, c.to(e.dtype), bias, params) + assert isinstance(loss, torch.Tensor) + + return loss + def _get_sampled_logits( self, feature_tensors: TensorMap, @@ -405,6 +520,9 @@ def _create_loss(self) -> Union[torch.nn.BCEWithLogitsLoss, torch.nn.CrossEntrop if self._loss_type == "CE": return torch.nn.CrossEntropyLoss() + if self._loss_type == "CCE": + return LinearCrossEntropyFunction() + msg = "Not supported loss_type" raise NotImplementedError(msg) diff --git a/replay/models/nn/sequential/sasrec/lightning.py b/replay/models/nn/sequential/sasrec/lightning.py index f82461f20..074d0109c 100644 --- a/replay/models/nn/sequential/sasrec/lightning.py +++ b/replay/models/nn/sequential/sasrec/lightning.py @@ -10,6 +10,24 @@ from .dataset import SasRecPredictionBatch, 
SasRecTrainingBatch, SasRecValidationBatch from .model import SasRecModel +import sys +sys.path.append("./kernels") + +try: + from kernels.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction + +except ModuleNotFoundError: + print("fused linear cross entropy is not installed. fused_linear_CE loss cannot be used.") + + +try: + from kernels.cut_cross_entropy.cce import CCEParams, LinearCrossEntropyFunction + from kernels.cut_cross_entropy.utils import ( + _build_flat_valids, + _handle_eps + ) +except ModuleNotFoundError: + print("cut_cross_entropy is not installed. CCE / CCE_minus loss cannot be used.") class SasRec(lightning.LightningModule): """ @@ -197,6 +215,8 @@ def _compute_loss(self, batch: SasRecTrainingBatch) -> torch.Tensor: loss_func = self._compute_loss_bce if self._loss_sample_count is None else self._compute_loss_bce_sampled elif self._loss_type == "CE": loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled + elif self._loss_type == "CCE": + loss_func = self._compute_loss_cce else: msg = f"Not supported loss type: {self._loss_type}" raise ValueError(msg) @@ -314,6 +334,107 @@ def _compute_loss_ce_sampled( loss = self._loss(logits, labels_flat) return loss + def _compute_loss_cce( + self, + feature_tensors: TensorMap, + positive_labels: torch.LongTensor, + padding_mask: torch.BoolTensor, + target_padding_mask: torch.BoolTensor + ) -> torch.Tensor: + """ + Cut Cross-Entropy (CCE) and Cut Cross-Entropy with Negative Sampling (CCE-), + methods that computes the cross-entropy loss, + without materializing the logits for all tokens into global memory. + The method is implemented in custom Triton kernels. 
+ + + Cut Cross Entropy for LLM is presented in + https://arxiv.org/abs/2411.09009 + https://github.com/apple/ml-cross-entropy + """ + + bias = None + ignore_index = -100 + softcap = None + reduction = "mean" + shift = False + filter_eps = "auto" + use_kahan = False + item_inds = None + + e = self._model.forward_step(feature_tensors, padding_mask) + e = e.to(torch.float16) + targets = cast(torch.LongTensor, positive_labels) + c = self._model._head._item_embedder.get_all_item_weights() + + if self._loss_sample_count is not None: + targets = targets[target_padding_mask] + e = e[target_padding_mask] + + e = e.contiguous() + padding_mask = padding_mask.contiguous() + + assert e.size()[0:-1] == targets.size() + assert e.size(-1) == c.size(1) + if not torch.cuda.is_bf16_supported(): + raise RuntimeError( + "Cut Cross Entropy requires an ampere GPU or newer. " + "Consider using torch_compile_linear_cross_entropy for scenarios where one is not available." + ) + + batch_shape = targets.size() + + shift = int(shift) + valids = _build_flat_valids(targets, ignore_index, shift) + + e = e.flatten(0, -2) + targets = targets.flatten() + + if (targets.data_ptr() % 16) != 0: + targets = torch.nn.functional.pad(targets, (0, 1))[:-1] + + assert (targets.data_ptr() % 16) == 0 + + if self._loss_sample_count is not None: + filter_eps = None + n_negative_samples = self._loss_sample_count + vocab_size = self._vocab_size + device = padding_mask.device + + masked_batch_seq_size = targets.size(0) + + if self._negative_sampling_strategy == "global_uniform": + negative_labels = torch.randint( + low=0, + high=vocab_size, + size=(masked_batch_seq_size, n_negative_samples), + dtype=torch.long, + device=device, + ) + + reject_labels_mask = targets.view(-1, 1) == negative_labels + negative_labels[reject_labels_mask] = vocab_size + + item_inds = torch.hstack([targets.view(-1, 1), negative_labels]) + + + params = CCEParams( + targets, + valids, + softcap, + reduction, + _handle_eps(filter_eps, 
e.dtype), + shift, + batch_shape, + use_kahan, + item_inds + ) + + loss = self._loss.apply(e, c.to(e.dtype), bias, params) + assert isinstance(loss, torch.Tensor) + + return loss + def _get_sampled_logits( self, feature_tensors: TensorMap, @@ -401,6 +522,9 @@ def _create_loss(self) -> Union[torch.nn.BCEWithLogitsLoss, torch.nn.CrossEntrop if self._loss_type == "CE": return torch.nn.CrossEntropyLoss() + if self._loss_type == "CCE": + return LinearCrossEntropyFunction() + msg = "Not supported loss_type" raise NotImplementedError(msg) From 4e85442f21d5d3192195dbaf5117bd9df6e27a7f Mon Sep 17 00:00:00 2001 From: Dmitry Date: Mon, 28 Apr 2025 11:22:26 +0000 Subject: [PATCH 2/2] del unused --- .../fused_linear_cross_entropy/__init__.py | 1 - .../fused_linear_ce_loss.py | 543 ------------------ 2 files changed, 544 deletions(-) delete mode 100644 kernels/fused_linear_cross_entropy/__init__.py delete mode 100644 kernels/fused_linear_cross_entropy/fused_linear_ce_loss.py diff --git a/kernels/fused_linear_cross_entropy/__init__.py b/kernels/fused_linear_cross_entropy/__init__.py deleted file mode 100644 index 917dfc305..000000000 --- a/kernels/fused_linear_cross_entropy/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from fused_linear_cross_entropy.fused_linear_ce_loss import LigerFusedLinearCrossEntropyFunction \ No newline at end of file diff --git a/kernels/fused_linear_cross_entropy/fused_linear_ce_loss.py b/kernels/fused_linear_cross_entropy/fused_linear_ce_loss.py deleted file mode 100644 index 3c1784f8d..000000000 --- a/kernels/fused_linear_cross_entropy/fused_linear_ce_loss.py +++ /dev/null @@ -1,543 +0,0 @@ -#Liger-Kernel/src/liger_kernel/ops/fused_linear_cross_entropy.py -import torch -import triton -import triton.language as tl -try: - # typical import path with dispatch available - from triton.language.extra.libdevice import tanh -except ModuleNotFoundError: - # for working with NGC containers - from triton.language.extra.cuda.libdevice import tanh - - - - -# The hard 
limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 -# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling -# The optimal maximum block size depends on your hardware, your kernel, and your dtype -MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning - - -@triton.jit -def liger_cross_entropy_kernel( - X_ptr, - X_stride, - Y_ptr, - Y_stride, - weight_ptr, - loss_ptr, - z_loss_ptr, - loss_stride, - n_cols, - n_non_ignore, - sum_non_ignore_weight, - weight_sum, - ignore_index, - lse_square_scale: tl.constexpr, - label_smoothing: tl.constexpr, - reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time - softcap, - RETURN_Z_LOSS: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - HAS_WEIGHT: tl.constexpr, - HAS_SOFTCAPPING: tl.constexpr, -): - """ - This kernel computes both cross entropy loss and the gradient of the input. - We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math. - - Parameters: - X_ptr: Pointer to input tensor. - X_stride (int): The stride of the input tensor. - Y_ptr: Pointer to target tensor. - Y_stride (int): The stride of the target tensor. - weight_ptr: Pointer to weight tensor. - loss_ptr: Pointer to tensor to store the loss. - z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. - loss_stride (int): The stride of the loss tensor. - n_cols (int): The number of columns in the input tensor. - n_non_ignore (flaot): The number of non-ignored elements in the batch. - sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch. - weight_sum (float): The sum of weight tensor. - ignore_index (int): The index to ignore in the target. 
- label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. - lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. - reduction (str): The string for the reduction to apply - softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). - RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. - BLOCK_SIZE (int): The block size for Triton operations. - HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. - HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. - """ - - # https://github.com/triton-lang/triton/issues/1058 - # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64 - program_id = tl.program_id(0).to(tl.int64) - - # 1. Load Y_ptr first because if the target is ignore_index, we can return right away - Y_ptr += program_id * Y_stride - y = tl.load(Y_ptr) - - # 2. locate the start index - X_ptr += program_id * X_stride - - if y == ignore_index: - # set all X_ptr as 0 - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) - return - - loss_ptr += program_id * loss_stride - if RETURN_Z_LOSS: - z_loss_ptr += program_id * loss_stride - - if HAS_WEIGHT: - weight_y = tl.load(weight_ptr + y).cast(tl.float32) - - # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) - # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 - - # 3. [Online softmax] first pass: find max + sum - m = float("-inf") # m is the max value. use the notation from the paper - d = 0.0 # d is the sum. 
use the notation from the paper - ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation - if HAS_SOFTCAPPING: - ori_X_y = softcap * tanh(ori_X_y / softcap) - - # Label smoothing is a general case of normal cross entropy - # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 - scaled_x_sum = 0.0 - eps = label_smoothing / n_cols - - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load( - X_ptr + X_offsets, - mask=X_offsets < n_cols, - other=float("-inf"), - # Ensure float32 precision for softmax calculation - ).cast(tl.float32) - if HAS_SOFTCAPPING: - X_block = softcap * tanh(X_block / softcap) - block_max = tl.max(X_block) - if label_smoothing > 0: - # scale X beforehand to avoid overflow - if HAS_WEIGHT: - weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) - scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)) - else: - scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) - m_new = tl.maximum(m, block_max) - d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) - m = m_new - - # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X))))) - # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) - # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d - lse = m + tl.log(d) - - # 4. 
[Online Softmax] Second pass: compute gradients - # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) - # dx_y = (softmax(x_y) - 1) / N - # dx_i = softmax(x_i) / N, i != y - # For label smoothing: - # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y - # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N - # = dx_i - (1 - label_smoothing) / N - # With Z loss: - # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y - # dx_y = dx_i - (1 - label_smoothing) / N - # For 'sum' reduction, no normalization is applied: - # dx_y = softmax(x_y) - 1 - # dx_i = softmax(x_i), for i ≠ y - - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load( - X_ptr + X_offsets, - mask=X_offsets < n_cols, - other=float("-inf"), - # Ensure float32 precision for softmax calculation - ).cast(tl.float32) - if HAS_SOFTCAPPING: - intermediate = tanh(X_block / softcap) - X_block = softcap * intermediate - - if not HAS_WEIGHT: - # softmax(x_i) - X_block = tl.exp(X_block - m) / d - # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) - X_block += 2 * lse_square_scale * lse * X_block - # smoothing term - X_block += -eps - # special handle dx_y - X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) - # reduction scale - if reduction == "mean": - X_block = X_block / n_non_ignore - else: - weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) - softmax_X = tl.exp(X_block - m) / d - # derivative of original_loss - dloss_ori = (1 - label_smoothing) * softmax_X - # specially handle dx_y - dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)) - dloss_ori = dloss_ori * weight_y - # derivative of smooth_loss - dloss_smooth = eps * (-weight_block + softmax_X * weight_sum) - # derivative of z-loss - dz_loss = 2 * lse_square_scale * lse * softmax_X - # reduction scale - if 
reduction == "mean": - dloss_ori = dloss_ori / sum_non_ignore_weight - dloss_smooth = dloss_smooth / sum_non_ignore_weight - # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. - dz_loss = dz_loss / n_non_ignore - # derivative of total_loss - X_block = dloss_ori + dloss_smooth + dz_loss - - # chain rule softcapping - # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) - if HAS_SOFTCAPPING: - X_block = X_block * (1 - intermediate * intermediate) - - tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) - - # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in - # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34 - tl.debug_barrier() - - # 5. Calculate the loss - - # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) - # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) - # = X_y - m - log d = X_y - lse - # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 - # So we can safely calculate log (softmax(X_y)) without overflow - loss = lse - ori_X_y - if HAS_WEIGHT: - loss = weight_y * loss - - # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps - # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) - # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) - # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: - # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd)) - # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 - # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 - # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 - if label_smoothing > 0: - if HAS_WEIGHT: - smooth_loss = scaled_x_sum + 
eps * lse * weight_sum - else: - smooth_loss = scaled_x_sum + label_smoothing * lse - loss = loss * (1 - label_smoothing) + smooth_loss - - # An auxiliary loss, z_loss - # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html - z_loss = lse_square_scale * lse * lse - # Normalize the loss by the number of non-ignored elements if reduction is "mean" - if reduction == "mean": - if HAS_WEIGHT: - loss = loss / sum_non_ignore_weight - else: - loss = loss / n_non_ignore - # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. - z_loss = z_loss / n_non_ignore - loss += z_loss - - tl.store(loss_ptr, loss) - if RETURN_Z_LOSS: - tl.store(z_loss_ptr, z_loss) - - -def fused_linear_cross_entropy_forward( - _input, - weight, - target, - ce_weight=None, - bias=None, - ignore_index=-100, - lse_square_scale=0.0, - label_smoothing=0.0, - reduction="mean", - softcap=None, - return_z_loss=False, - triton_backend=True -): - assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" - device = _input.device - - # inputs have shape: BT x H - # materialized activations will have shape: BT x V - # the increase in memory = BT x V - # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. 
- # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: - # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor - # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 - BT, H = _input.shape - V = weight.shape[0] - # BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) - - # inc_factor = triton.cdiv(V, H) # (V + H - 1) // H - # chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor - # num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size - - # inc_factor = (V + H - 1) // H - # chunk_size = (BT + inc_factor - 1) // inc_factor - # num_chunks = (BT + chunk_size - 1) // chunk_size - - chunk_size = 1024 - if triton_backend: - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) - num_chunks = triton.cdiv(BT, chunk_size) - else: - num_chunks = (BT + chunk_size - 1) // chunk_size - - grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None - grad_input = torch.zeros_like(_input, device=device) - grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None - # we use fp32 for loss accumulator - loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) - z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None - - # TODO: evaluate how CUDA synchronization caused by .item() affects the speed - target_mask = target != ignore_index - total_n_non_ignore = target_mask.sum().item() - total_sum_non_ignore_ce_weight = total_n_non_ignore - ce_weight_sum = 0.0 - if ce_weight is not None: - assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}" - assert torch.is_floating_point(ce_weight), ( - f"If given, weight has to be a Tensor of floating point dtype. 
Got: {ce_weight.dtype}" - ) - total_sum_non_ignore_ce_weight = ( - torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item() - ) - ce_weight_sum = ce_weight.sum().item() - if ce_weight.stride(-1) != 1: - ce_weight = ce_weight.contiguous() - - for chunk_id in range(num_chunks): - start_idx = chunk_id * chunk_size - end_idx = min((chunk_id + 1) * chunk_size, BT) - _input_chunk = _input[start_idx:end_idx] # chunk_size x H - - # when doing matmul, use the original precision - logits_chunk = _input_chunk @ weight.t() # chunk_size x V - if bias is not None: - logits_chunk = logits_chunk + bias - - target_chunk = target[start_idx:end_idx] # chunk_size, - - n_rows = logits_chunk.shape[0] - - - - # ensure _input and target are contiguous - logits_chunk = logits_chunk.contiguous() - target_chunk = target_chunk.contiguous() - - if triton_backend: - # unreduced loss - loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, - z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None - - # Here we calculate the gradient of logits_chunk in place so we can save memory. 
- liger_cross_entropy_kernel[(n_rows,)]( - X_ptr=logits_chunk, - X_stride=logits_chunk.stride(-2), - Y_ptr=target_chunk, - Y_stride=target_chunk.stride(-1), # always 1 - weight_ptr=ce_weight, - loss_ptr=loss_1d_slice, - z_loss_ptr=z_loss_1d_slice, - loss_stride=loss_1d_slice.stride(-1), # always 1 - n_cols=V, - n_non_ignore=total_n_non_ignore, - sum_non_ignore_weight=total_sum_non_ignore_ce_weight, - weight_sum=ce_weight_sum, - ignore_index=ignore_index, - lse_square_scale=lse_square_scale, - label_smoothing=label_smoothing, - reduction=reduction, - softcap=softcap, - RETURN_Z_LOSS=return_z_loss, - HAS_WEIGHT=True if ce_weight is not None else False, - HAS_SOFTCAPPING=True if softcap is not None else False, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=32, - ) - grad_logits_chunk = logits_chunk # chunk_size x V - else: - y_chunk = torch.nn.functional.softmax(logits_chunk, dim=1) - loss_1d_slice = -torch.log(y_chunk).gather(1, target_chunk.view(-1, 1)) - loss_1d_slice = loss_1d_slice.squeeze(1) - logits_chunk = y_chunk - torch.nn.functional.one_hot(target_chunk, num_classes=V) - logits_chunk = (logits_chunk * (chunk_size / BT)) - grad_logits_chunk = logits_chunk - - loss_1d[start_idx:end_idx] = loss_1d_slice - if return_z_loss: - z_loss_1d[start_idx:end_idx] = z_loss_1d_slice - - grad_input[start_idx:end_idx] = grad_logits_chunk @ weight - - if grad_weight is not None: - torch.addmm( - input=grad_weight, - mat1=logits_chunk.t().to( - _input_chunk.dtype - ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error. 
- mat2=_input_chunk, - out=grad_weight, - alpha=1.0, - beta=1.0, - ) - - if bias is not None: - torch.add( - input=grad_bias, - other=logits_chunk.sum(dim=0), - out=grad_bias, - alpha=1.0, - ) - - if reduction == "none": - loss = loss_1d - z_loss = z_loss_1d if return_z_loss else None - - else: - loss = torch.sum(loss_1d) if triton_backend else torch.mean(loss_1d) - z_loss = torch.sum(z_loss_1d) if return_z_loss else None - - return loss, z_loss, grad_input, grad_weight, grad_bias - -def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias): - # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time - if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): - # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place - # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. - BT, H = grad_input.shape - n_rows = BT - # BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) - - # element_mul_kernel[(n_rows,)]( - # grad_input, - # grad_input.stride(-2), - # grad_output, - # H, - # BLOCK_SIZE=BLOCK_SIZE, - # num_warps=32 if not is_hip() else 16, - # ) - - # handle grad_weight - if grad_weight is not None: - V, H = grad_weight.shape - n_rows = V - - # element_mul_kernel[(n_rows,)]( - # grad_weight, - # grad_weight.stride(-2), - # grad_output, - # H, - # BLOCK_SIZE=BLOCK_SIZE, - # num_warps=32 if not is_hip() else 16, - # ) - - if grad_bias is not None: - V = grad_bias.shape[0] - n_rows = V - - # element_mul_kernel[(n_rows,)]( - # grad_bias, - # grad_bias.stride(-1), - # grad_output, - # 1, - # BLOCK_SIZE=BLOCK_SIZE, - # num_warps=32 if not is_hip() else 16, - # ) - return grad_input, grad_weight, grad_bias - - -class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - _input, - weight, - target, - bias=None, - ce_weight=None, - ignore_index=-100, - 
lse_square_scale=0.0, - label_smoothing=0.0, - reduction="mean", - softcap=None, - return_z_loss: bool = False, - ): - """ - Fusing the last linear layer with cross-entropy loss - Reference: https://github.com/mgmalek/efficient_cross_entropy - - Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding - the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can - compute the gradient at the forward pass. By doing so, we don't have to store the _input and target - for the backward pass. - - _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension. - target: (B*T) where each value is in [0, V-1] - weight: (V, H) where V is the number of classes - bias: (V) where V is the number of classes - ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype - ignore_index: the index to ignore in the target - label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. 
- reduction: reduction to apply - """ - - loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( - _input=_input, - weight=weight, - target=target, - bias=bias, - ce_weight=ce_weight, - ignore_index=ignore_index, - lse_square_scale=lse_square_scale, - label_smoothing=label_smoothing, - reduction=reduction, - softcap=softcap, - return_z_loss=return_z_loss, - ) - # downcast to dtype and store for backward - ctx.save_for_backward( - grad_input.detach(), - grad_weight.detach() if grad_weight is not None else None, - grad_bias.detach() if bias is not None else None, - ) - ctx.return_z_loss = return_z_loss - # return loss, z_loss - return loss - - @staticmethod - # def backward(ctx, grad_output, grad_output2): - def backward(ctx, grad_output): - if ctx.return_z_loss: - del grad_output2 # z_loss is only for logging - (grad_input, grad_weight, grad_bias) = ctx.saved_tensors - grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( - grad_output, grad_input, grad_weight, grad_bias - ) - return ( - grad_input, - grad_weight, - None, - grad_bias, - None, - None, - None, - None, - None, - None, - None, - ) \ No newline at end of file