From 8c5e5828ef06086f785a632f83bfb10b15f8e52a Mon Sep 17 00:00:00 2001 From: "mejai.p" Date: Wed, 20 Sep 2023 23:24:38 +0900 Subject: [PATCH] Optimize Linear and GEGLU --- tests/test_geglu.py | 1 - tests/test_linear.py | 3 +- trident/function/function.py | 12 ++----- trident/kernel/geglu.py | 66 ++++++++++++++++++++-------------- trident/kernel/linear.py | 41 +++++++++------------- trident/operation/geglu.py | 68 +++++++++++++++++++++--------------- trident/operation/linear.py | 67 ++++++++++++++++++++--------------- 7 files changed, 137 insertions(+), 121 deletions(-) diff --git a/tests/test_geglu.py b/tests/test_geglu.py index 0232abc2..60c0f17c 100644 --- a/tests/test_geglu.py +++ b/tests/test_geglu.py @@ -58,7 +58,6 @@ def train(func): assert util.equal(y, b, 3e-01) input = input.permute(0, 2, 1).reshape(num_batches, m_size, k_size) - weight = weight.permute(1, 0).reshape(n_size, k_size) (x, y) = train(geglu) (a, b) = train(trident.function.geglu) diff --git a/tests/test_linear.py b/tests/test_linear.py index fa99d612..7db5e297 100644 --- a/tests/test_linear.py +++ b/tests/test_linear.py @@ -31,7 +31,7 @@ def test_forward(num_batches, m_size, n_size, k_size, device): assert util.equal(torch.nn.functional.linear(input, weight, bias), trident.function.linear(input, weight, bias)) input = input.permute(0, 2, 1) - weight = weight.permute(1, 0) + weight = torch.randn(k_size, n_size, device=device) assert util.equal(torch.nn.functional.linear(input, weight), trident.function.linear(input, weight)) @@ -56,7 +56,6 @@ def train(func): assert util.equal(y, b) input = input.permute(0, 2, 1).reshape(num_batches, m_size, k_size) - weight = weight.permute(1, 0).reshape(n_size, k_size) (x, y) = train(torch.nn.functional.linear) (a, b) = train(trident.function.linear) diff --git a/trident/function/function.py b/trident/function/function.py index 6a265f79..59de4085 100644 --- a/trident/function/function.py +++ b/trident/function/function.py @@ -73,11 +73,7 @@ def geglu(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None, See GEGLU for details. """ - if input.dim() == 2: - output = operation.GEGLU.apply(input.view(1, *input.shape), weight, bias, use_accelerator) - return output.view(output.shape[1:3]) - else: - return operation.GEGLU.apply(input, weight, bias, use_accelerator) + return operation.GEGLU.apply(input, weight, bias, use_accelerator) def gelu(input: torch.Tensor): @@ -156,11 +152,7 @@ def linear( See Linear for more details. """ - if input.dim() == 2: - output = operation.Linear.apply(input.view(1, *input.shape), weight, bias, use_accelerator) - return output.view(output.shape[1:3]) - else: - return operation.Linear.apply(input, weight, bias, use_accelerator) + return operation.Linear.apply(input, weight, bias, use_accelerator) def max(input: torch.Tensor, dim: int): diff --git a/trident/kernel/geglu.py b/trident/kernel/geglu.py index ba82a1d6..9acaba56 100644 --- a/trident/kernel/geglu.py +++ b/trident/kernel/geglu.py @@ -55,9 +55,9 @@ class GEGLU: @util.autotune(geglu_configs(), ["m_size", "k_size", "x_size"]) @triton.heuristics( { - "require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0, - "require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0, - "require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"] == 0, + "require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"], + "require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"], + "require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"], } ) @triton.jit @@ -71,7 +71,6 @@ def forward( n_size: tl.int32, k_size: tl.int32, x_size: tl.int32, - input_batch_stride: tl.int32, input_m_stride: tl.int32, input_k_stride: tl.int32, weight_n_stride: tl.int32, @@ -89,7 +88,6 @@ def forward( num_m_blocks = tl.cdiv(m_size, m_block_size) num_x_blocks = tl.cdiv(x_size, x_block_size) num_blocks = num_m_blocks * num_x_blocks - batch = pid // num_blocks block = pid % num_blocks m_block = block // num_x_blocks x_block = block % num_x_blocks @@ -97,7 +95,7 @@ def forward( x_offset = x_block * x_block_size output_block_ptr = tl.make_block_ptr( - output_ptr + batch * m_size * x_size, + output_ptr, shape=(m_size, x_size), strides=(x_size, 1), offsets=(m_offset, x_offset), @@ -105,7 +103,7 @@ def forward( order=(1, 0), ) state_block_ptr = tl.make_block_ptr( - state_gate_ptr + batch * m_size * n_size, + state_gate_ptr, shape=(m_size, n_size), strides=(n_size, 1), offsets=(m_offset, x_offset), @@ -113,7 +111,7 @@ def forward( order=(1, 0), ) gate_block_ptr = tl.make_block_ptr( - state_gate_ptr + batch * m_size * n_size, + state_gate_ptr, shape=(m_size, n_size), strides=(n_size, 1), offsets=(m_offset, x_offset + x_size), @@ -122,7 +120,7 @@ def forward( ) state = language.Linear.forward( - input_ptr + batch * input_batch_stride, + input_ptr, weight_ptr, bias_ptr, m_size, @@ -144,7 +142,7 @@ def forward( dtype, ) gate = language.Linear.forward( - input_ptr + batch * input_batch_stride, + input_ptr, weight_ptr, bias_ptr, m_size, @@ -167,16 +165,21 @@ def forward( ) output = state * language.math.GELU.forward(gate) - if require_m_boundary_check & require_x_boundary_check: - tl.store(output_block_ptr, output.to(dtype)) - tl.store(state_block_ptr, state.to(dtype)) - tl.store(gate_block_ptr, gate.to(dtype)) - else: + if require_m_boundary_check | require_x_boundary_check: tl.store(output_block_ptr, output.to(dtype), boundary_check=(0, 1)) tl.store(state_block_ptr, state.to(dtype), boundary_check=(0, 1)) tl.store(gate_block_ptr, gate.to(dtype), boundary_check=(0, 1)) + else: + tl.store(output_block_ptr, output.to(dtype)) + tl.store(state_block_ptr, state.to(dtype)) + tl.store(gate_block_ptr, gate.to(dtype)) @staticmethod + @triton.heuristics( + { + "require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"], + } + ) @triton.jit def backward( grad_state_gate_ptr: tl.tensor, @@ -187,13 +190,13 @@ def backward( x_size: tl.int32, dtype: tl.constexpr, x_block_size: tl.constexpr, + require_x_boundary_check: tl.constexpr, ): pid = tl.program_id(0) - batch = pid // m_size m_offset = pid % m_size grad_state_block_ptr = tl.make_block_ptr( - grad_state_gate_ptr + batch * m_size * n_size, + grad_state_gate_ptr, shape=(m_size, n_size), strides=(n_size, 1), offsets=(m_offset, 0), @@ -201,7 +204,7 @@ def backward( order=(1, 0), ) grad_gate_block_ptr = tl.make_block_ptr( - grad_state_gate_ptr + batch * m_size * n_size, + grad_state_gate_ptr, shape=(m_size, n_size), strides=(n_size, 1), offsets=(m_offset, x_size), @@ -209,7 +212,7 @@ def backward( order=(1, 0), ) grad_output_block_ptr = tl.make_block_ptr( - grad_output_ptr + batch * m_size * x_size, + grad_output_ptr, shape=(m_size, x_size), strides=(x_size, 1), offsets=(m_offset, 0), @@ -217,7 +220,7 @@ def backward( order=(1, 0), ) state_block_ptr = tl.make_block_ptr( - state_gate_ptr + batch * m_size * n_size, + state_gate_ptr, shape=(m_size, n_size), strides=(n_size, 1), offsets=(m_offset, 0), @@ -225,7 +228,7 @@ def backward( order=(1, 0), ) gate_block_ptr = tl.make_block_ptr( - state_gate_ptr + batch * m_size * n_size, + state_gate_ptr, shape=(m_size, n_size), strides=(n_size, 1), offsets=(m_offset, x_size), @@ -233,10 +236,21 @@ def backward( order=(1, 0), ) - grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,)) - state = tl.load(state_block_ptr, boundary_check=(1,)) - gate = tl.load(gate_block_ptr, boundary_check=(1,)) + if require_x_boundary_check: + grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,)) + state = tl.load(state_block_ptr, boundary_check=(1,)) + gate = tl.load(gate_block_ptr, boundary_check=(1,)) + else: + grad_output = tl.load(grad_output_block_ptr) + state = tl.load(state_block_ptr) + gate = tl.load(gate_block_ptr) + grad_state = grad_output * language.math.GELU.forward(gate) grad_gate = language.math.GELU.backward(grad_output * state, gate) - tl.store(grad_state_block_ptr, grad_state.to(dtype), boundary_check=(1,)) - tl.store(grad_gate_block_ptr, grad_gate.to(dtype), boundary_check=(1,)) + + if require_x_boundary_check: + tl.store(grad_state_block_ptr, grad_state.to(dtype), boundary_check=(1,)) + tl.store(grad_gate_block_ptr, grad_gate.to(dtype), boundary_check=(1,)) + else: + tl.store(grad_state_block_ptr, grad_state.to(dtype)) + tl.store(grad_gate_block_ptr, grad_gate.to(dtype)) diff --git a/trident/kernel/linear.py b/trident/kernel/linear.py index 39ea3a79..c722f22f 100644 --- a/trident/kernel/linear.py +++ b/trident/kernel/linear.py @@ -111,7 +111,7 @@ def linear_configs_for_backward_bias(): class Linear: @staticmethod - @util.autotune(linear_configs([16, 64, 128], [32, 64, 128], [32, 64]), ["m_size", "n_size", "k_size"]) + @util.autotune(linear_configs([16, 64, 128, 256], [32, 64, 128], [32, 64]), ["m_size", "n_size", "k_size"]) @triton.heuristics( { "require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"], @@ -128,7 +128,6 @@ def forward( m_size: tl.int32, n_size: tl.int32, k_size: tl.int32, - input_batch_stride: tl.int32, input_m_stride: tl.int32, input_k_stride: tl.int32, weight_n_stride: tl.int32, @@ -143,18 +142,14 @@ def forward( require_k_boundary_check: tl.constexpr, ): pid = tl.program_id(0) - num_m_blocks = tl.cdiv(m_size, m_block_size) num_n_blocks = tl.cdiv(n_size, n_block_size) - num_blocks = num_m_blocks * num_n_blocks - batch = pid // num_blocks - block = pid % num_blocks - m_block = block // num_n_blocks - n_block = block % num_n_blocks + m_block = pid // num_n_blocks + n_block = pid % num_n_blocks m_offset = m_block * m_block_size n_offset = n_block * n_block_size output = language.Linear.forward( - input_ptr + batch * input_batch_stride, + input_ptr, weight_ptr, bias_ptr, m_size, @@ -177,7 +172,7 @@ def forward( ) output_block_ptr = tl.make_block_ptr( - output_ptr + batch * m_size * n_size, + output_ptr, shape=(m_size, n_size), strides=(n_size, 1), offsets=(m_offset, n_offset), @@ -223,7 +218,6 @@ def backward( num_m_blocks = tl.cdiv(m_size, m_block_size) num_k_blocks = tl.cdiv(k_size, k_block_size) num_blocks = num_m_blocks * num_k_blocks - batch = pid // num_blocks block = pid % num_blocks m_block = block // num_k_blocks k_block = block % num_k_blocks @@ -231,7 +225,7 @@ def backward( k_offset = k_block * k_block_size grad_input = language.Linear.backward( - grad_output_ptr + batch * m_size * n_size, + grad_output_ptr, weight_ptr, m_size, n_size, @@ -251,7 +245,7 @@ def backward( ) grad_input_block_ptr = tl.make_block_ptr( - grad_input_ptr + batch * m_size * k_size, + grad_input_ptr, shape=(m_size, k_size), strides=(input_m_stride, input_k_stride), offsets=(m_offset, k_offset), @@ -278,13 +272,12 @@ def backward( ) @triton.jit def backward_weight( - grad_weight_staging_ptr: tl.tensor, + grad_weight_ptr: tl.tensor, grad_output_ptr: tl.tensor, input_ptr: tl.tensor, m_size: tl.int32, n_size: tl.int32, k_size: tl.int32, - input_batch_stride: tl.int32, input_m_stride: tl.int32, input_k_stride: tl.int32, use_accelerator: tl.constexpr, @@ -300,7 +293,6 @@ def backward_weight( num_n_blocks = tl.cdiv(n_size, n_block_size) num_k_blocks = tl.cdiv(k_size, k_block_size) num_blocks = num_n_blocks * num_k_blocks - batch = pid // num_blocks block = pid % num_blocks n_block = block // num_k_blocks k_block = block % num_k_blocks @@ -308,8 +300,8 @@ def backward_weight( k_offset = k_block * k_block_size grad_weight = language.Linear.backward_weight( - grad_output_ptr + batch * m_size * n_size, - input_ptr + batch * input_batch_stride, + grad_output_ptr, + input_ptr, m_size, n_size, k_size, @@ -328,7 +320,7 @@ def backward_weight( ) grad_weight_staging_block_ptr = tl.make_block_ptr( - grad_weight_staging_ptr + batch * n_size * k_size, + grad_weight_ptr, shape=(n_size, k_size), strides=(k_size, 1), offsets=(n_offset, k_offset), @@ -346,7 +338,7 @@ def backward_weight( @triton.heuristics({"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"]}) @triton.jit def backward_bias( - grad_bias_staging_ptr: tl.tensor, + grad_bias_ptr: tl.tensor, grad_output_ptr: tl.tensor, m_size: tl.int32, n_size: tl.int32, @@ -355,10 +347,9 @@ def backward_bias( require_m_boundary_check: tl.constexpr, ): pid = tl.program_id(0) - batch = pid // n_size n_offset = pid % n_size grad_bias = language.Linear.backward_bias( - grad_output_ptr + batch * m_size * n_size, + grad_output_ptr, m_size, n_size, n_offset, @@ -367,12 +358,12 @@ def backward_bias( dtype, ) - grad_bias_staging_block_ptr = tl.make_block_ptr( - grad_bias_staging_ptr + batch * n_size, + grad_bias_block_ptr = tl.make_block_ptr( + grad_bias_ptr, shape=(n_size,), strides=(1,), offsets=(n_offset,), block_shape=(1,), order=(0,), ) - tl.store(grad_bias_staging_block_ptr, grad_bias) + tl.store(grad_bias_block_ptr, grad_bias) diff --git a/trident/operation/geglu.py b/trident/operation/geglu.py index b7cae541..419586aa 100644 --- a/trident/operation/geglu.py +++ b/trident/operation/geglu.py @@ -26,13 +26,19 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): input, weight, bias, use_accelerator = args util.push_trace("GEGLU.__forward") - output, state_gate = GEGLU.__forward(input, weight, bias, use_accelerator) + input_shape = GEGLU.__input_shape(input) + output, state_gate = GEGLU.__forward( + input.view(input_shape) if input.is_contiguous() else input.reshape(input_shape), + weight, + bias, + use_accelerator, + ) util.pop_trace() ctx.save_for_backward(input, weight, bias, state_gate) ctx.use_accelerator = False - return output + return output.view(GEGLU.__output_shape(input, weight)) @staticmethod def backward(ctx: Any, *grad_outputs: Any): @@ -41,25 +47,25 @@ def backward(ctx: Any, *grad_outputs: Any): util.push_trace("GEGLU.__backward") grad_input, grad_weight, grad_bias = GEGLU.__backward( - grad_output, input, weight, bias, state_gate, ctx.use_accelerator + grad_output, input.view(GEGLU.__input_shape(input)), weight, bias, state_gate, ctx.use_accelerator ) util.pop_trace() - return grad_input, grad_weight, grad_bias, None + return grad_input.view(input.shape), grad_weight, grad_bias, None @staticmethod def __forward(input, weight, bias, use_accelerator): factory_kwargs = {"device": input.device, "dtype": input.dtype} - num_batches, m_size, k_size = input.shape + m_size, k_size = input.shape n_size, _ = weight.shape x_size = n_size // 2 - output = torch.empty(num_batches, m_size, x_size, **factory_kwargs) - state_gate = torch.empty(num_batches, m_size, n_size, **factory_kwargs) + output = torch.empty(m_size, x_size, **factory_kwargs) + state_gate = torch.empty(m_size, n_size, **factory_kwargs) def grid(meta): num_m_blocks = triton.cdiv(m_size, meta["m_block_size"]) num_x_blocks = triton.cdiv(x_size, meta["x_block_size"]) - return (num_batches * num_m_blocks * num_x_blocks,) + return (num_m_blocks * num_x_blocks,) util.push_trace("kernel.GEGLU.forward") kernel.GEGLU.forward[grid]( @@ -74,7 +80,6 @@ def grid(meta): x_size, input.stride(0), input.stride(1), - input.stride(2), weight.stride(0), weight.stride(1), use_accelerator, @@ -87,15 +92,15 @@ def grid(meta): @staticmethod def __backward(grad_output, input, weight, bias, state_gate, use_accelerator): factory_kwargs = {"device": input.device, "dtype": input.dtype} - num_batches, m_size, k_size = input.shape + m_size, k_size = input.shape n_size, _ = weight.shape x_size = n_size // 2 grad_state_gate = torch.empty_like(state_gate) grad_input = torch.empty_like(input) - grad_weight_staging = torch.empty(num_batches, n_size, k_size, **factory_kwargs) + grad_weight = torch.empty(n_size, k_size, **factory_kwargs) def grid(meta): - return (num_batches * m_size,) + return (m_size,) util.push_trace("kernel.GEGLU.backward") kernel.GEGLU.backward[grid]( @@ -113,7 +118,7 @@ def grid(meta): def grid(meta): num_m_blocks = triton.cdiv(m_size, meta["m_block_size"]) num_k_blocks = triton.cdiv(k_size, meta["k_block_size"]) - return (num_batches * num_m_blocks * num_k_blocks,) + return (num_m_blocks * num_k_blocks,) util.push_trace("kernel.Linear.backward") kernel.Linear.backward[grid]( @@ -123,8 +128,8 @@ def grid(meta): m_size, n_size, k_size, + input.stride(0), input.stride(1), - input.stride(2), weight.stride(0), weight.stride(1), use_accelerator, @@ -135,11 +140,11 @@ def grid(meta): def grid(meta): num_n_blocks = triton.cdiv(n_size, meta["n_block_size"]) num_k_blocks = triton.cdiv(k_size, meta["k_block_size"]) - return (num_batches * num_n_blocks * num_k_blocks,) + return (num_n_blocks * num_k_blocks,) util.push_trace("kernel.Linear.backward_weight") kernel.Linear.backward_weight[grid]( - grad_weight_staging, + grad_weight, grad_state_gate, input, m_size, @@ -147,36 +152,41 @@ def grid(meta): k_size, input.stride(0), input.stride(1), - input.stride(2), use_accelerator, - util.dtype(grad_weight_staging.dtype), + util.dtype(grad_weight.dtype), ) util.pop_trace() - util.push_trace("torch.sum") - grad_weight = torch.sum(grad_weight_staging, 0) - util.pop_trace() - if bias is not None: - grad_bias_staging = torch.empty(num_batches, n_size, **factory_kwargs) + grad_bias = torch.empty(n_size, **factory_kwargs) def grid(meta): - return (num_batches * n_size,) + return (n_size,) util.push_trace("kernel.Linear.backward_bias") kernel.Linear.backward_bias[grid]( - grad_bias_staging, + grad_bias, grad_state_gate, m_size, n_size, - util.dtype(grad_bias_staging.dtype), + util.dtype(grad_bias.dtype), ) util.pop_trace() - util.push_trace("torch.sum") - grad_bias = torch.sum(grad_bias_staging, 0) - util.pop_trace() else: grad_bias = None return grad_input, grad_weight, grad_bias + + @staticmethod + def __input_shape(input: torch.Tensor): + return (-1, input.shape[-1]) + + @staticmethod + def __output_shape(input: torch.Tensor, weight: torch.Tensor): + if input.dim() == 2: + return (input.shape[0], weight.shape[0] // 2) + elif input.dim() == 3: + return (*input.shape[0:2], weight.shape[0] // 2) + else: + raise ValueError(f"Unable to convert the given input: '{input}'.") diff --git a/trident/operation/linear.py b/trident/operation/linear.py index 411f3562..8c202fbf 100644 --- a/trident/operation/linear.py +++ b/trident/operation/linear.py @@ -26,13 +26,19 @@ def forward(ctx: Any, *args: Any, **kwargs: Any): input, weight, bias, use_accelerator = args util.push_trace("Linear.__forward") - output = Linear.__forward(input, weight, bias, use_accelerator) + input_shape = Linear.__input_shape(input) + output = Linear.__forward( + input.view(input_shape) if input.is_contiguous() else input.reshape(input_shape), + weight, + bias, + use_accelerator, + ) util.pop_trace() ctx.save_for_backward(input, weight, bias) ctx.use_accelerator = use_accelerator - return output + return output.view(Linear.__output_shape(input, weight)) @staticmethod def backward(ctx: Any, *grad_outputs: Any): @@ -40,22 +46,24 @@ def backward(ctx: Any, *grad_outputs: Any): input, weight, bias = ctx.saved_tensors util.push_trace("Linear.__backward") - grad_input, grad_weight, grad_bias = Linear.__backward(grad_output, input, weight, bias, ctx.use_accelerator) + grad_input, grad_weight, grad_bias = Linear.__backward( + grad_output, input.view(Linear.__input_shape(input)), weight, bias, ctx.use_accelerator + ) util.pop_trace() - return grad_input, grad_weight, grad_bias, None, None + return grad_input.view(input.shape), grad_weight, grad_bias, None, None @staticmethod def __forward(input, weight, bias, use_accelerator): factory_kwargs = {"device": input.device, "dtype": input.dtype} - num_batches, m_size, k_size = input.shape + m_size, k_size = input.shape n_size, _ = weight.shape - output = torch.empty(num_batches, m_size, n_size, **factory_kwargs) + output = torch.empty(m_size, n_size, **factory_kwargs) def grid(meta): num_m_blocks = triton.cdiv(m_size, meta["m_block_size"]) num_n_blocks = triton.cdiv(n_size, meta["n_block_size"]) - return (num_batches * num_m_blocks * num_n_blocks,) + return (num_m_blocks * num_n_blocks,) util.push_trace("kernel.Linear.forward") kernel.Linear.forward[grid]( @@ -68,7 +76,6 @@ def grid(meta): k_size, input.stride(0), input.stride(1), - input.stride(2), weight.stride(0), weight.stride(1), use_accelerator, @@ -81,15 +88,15 @@ def grid(meta): @staticmethod def __backward(grad_output, input, weight, bias, use_accelerator): factory_kwargs = {"device": input.device, "dtype": input.dtype} - num_batches, m_size, k_size = input.shape + m_size, k_size = input.shape n_size, _ = weight.shape grad_input = torch.empty_like(input) - grad_weight_staging = torch.empty(num_batches, n_size, k_size, **factory_kwargs) + grad_weight = torch.empty(n_size, k_size, **factory_kwargs) def grid(meta): num_m_blocks = triton.cdiv(m_size, meta["m_block_size"]) num_k_blocks = triton.cdiv(k_size, meta["k_block_size"]) - return (num_batches * num_m_blocks * num_k_blocks,) + return (num_m_blocks * num_k_blocks,) util.push_trace("kernel.Linear.backward") kernel.Linear.backward[grid]( @@ -99,8 +106,8 @@ def grid(meta): m_size, n_size, k_size, + input.stride(0), input.stride(1), - input.stride(2), weight.stride(0), weight.stride(1), use_accelerator, @@ -111,11 +118,11 @@ def grid(meta): def grid(meta): num_n_blocks = triton.cdiv(n_size, meta["n_block_size"]) num_k_blocks = triton.cdiv(k_size, meta["k_block_size"]) - return (num_batches * num_n_blocks * num_k_blocks,) + return (num_n_blocks * num_k_blocks,) util.push_trace("kernel.Linear.backward_weight") kernel.Linear.backward_weight[grid]( - grad_weight_staging, + grad_weight, grad_output, input, m_size, @@ -123,36 +130,40 @@ def grid(meta): k_size, input.stride(0), input.stride(1), - input.stride(2), use_accelerator, - util.dtype(grad_weight_staging.dtype), + util.dtype(grad_weight.dtype), ) util.pop_trace() - util.push_trace("torch.sum") - grad_weight = torch.sum(grad_weight_staging, 0) - util.pop_trace() - if bias is not None: - grad_bias_staging = torch.empty(num_batches, n_size, **factory_kwargs) + grad_bias = torch.empty(n_size, **factory_kwargs) def grid(meta): - return (num_batches * n_size,) + return (n_size,) util.push_trace("kernel.Linear.backward_bias") kernel.Linear.backward_bias[grid]( - grad_bias_staging, + grad_bias, grad_output, m_size, n_size, - util.dtype(grad_bias_staging.dtype), + util.dtype(grad_bias.dtype), ) util.pop_trace() - - util.push_trace("torch.sum") - grad_bias = torch.sum(grad_bias_staging, 0) - util.pop_trace() else: grad_bias = None return grad_input, grad_weight, grad_bias + + @staticmethod + def __input_shape(input: torch.Tensor): + return (-1, input.shape[-1]) + + @staticmethod + def __output_shape(input: torch.Tensor, weight: torch.Tensor): + if input.dim() == 2: + return (input.shape[0], weight.shape[0]) + elif input.dim() == 3: + return (*input.shape[0:2], weight.shape[0]) + else: + raise ValueError(f"Unable to convert the given input: '{input}'.")