Skip to content
This repository was archived by the owner on Oct 16, 2023. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tests/test_geglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def train(func):
assert util.equal(y, b, 3e-01)

input = input.permute(0, 2, 1).reshape(num_batches, m_size, k_size)
weight = weight.permute(1, 0).reshape(n_size, k_size)

(x, y) = train(geglu)
(a, b) = train(trident.function.geglu)
Expand Down
3 changes: 1 addition & 2 deletions tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_forward(num_batches, m_size, n_size, k_size, device):
assert util.equal(torch.nn.functional.linear(input, weight, bias), trident.function.linear(input, weight, bias))

input = input.permute(0, 2, 1)
weight = weight.permute(1, 0)
weight = torch.randn(k_size, n_size, device=device)

assert util.equal(torch.nn.functional.linear(input, weight), trident.function.linear(input, weight))

Expand All @@ -56,7 +56,6 @@ def train(func):
assert util.equal(y, b)

input = input.permute(0, 2, 1).reshape(num_batches, m_size, k_size)
weight = weight.permute(1, 0).reshape(n_size, k_size)

(x, y) = train(torch.nn.functional.linear)
(a, b) = train(trident.function.linear)
Expand Down
12 changes: 2 additions & 10 deletions trident/function/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@ def geglu(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None,

See GEGLU for details.
"""
if input.dim() == 2:
output = operation.GEGLU.apply(input.view(1, *input.shape), weight, bias, use_accelerator)
return output.view(output.shape[1:3])
else:
return operation.GEGLU.apply(input, weight, bias, use_accelerator)
return operation.GEGLU.apply(input, weight, bias, use_accelerator)


def gelu(input: torch.Tensor):
Expand Down Expand Up @@ -156,11 +152,7 @@ def linear(

See Linear for more details.
"""
if input.dim() == 2:
output = operation.Linear.apply(input.view(1, *input.shape), weight, bias, use_accelerator)
return output.view(output.shape[1:3])
else:
return operation.Linear.apply(input, weight, bias, use_accelerator)
return operation.Linear.apply(input, weight, bias, use_accelerator)


def max(input: torch.Tensor, dim: int):
Expand Down
66 changes: 40 additions & 26 deletions trident/kernel/geglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ class GEGLU:
@util.autotune(geglu_configs(), ["m_size", "k_size", "x_size"])
@triton.heuristics(
{
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"] == 0,
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"] == 0,
"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"] == 0,
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"],
"require_k_boundary_check": lambda args: args["k_size"] % args["k_block_size"],
"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"],
}
)
@triton.jit
Expand All @@ -71,7 +71,6 @@ def forward(
n_size: tl.int32,
k_size: tl.int32,
x_size: tl.int32,
input_batch_stride: tl.int32,
input_m_stride: tl.int32,
input_k_stride: tl.int32,
weight_n_stride: tl.int32,
Expand All @@ -89,31 +88,30 @@ def forward(
num_m_blocks = tl.cdiv(m_size, m_block_size)
num_x_blocks = tl.cdiv(x_size, x_block_size)
num_blocks = num_m_blocks * num_x_blocks
batch = pid // num_blocks
block = pid % num_blocks
m_block = block // num_x_blocks
x_block = block % num_x_blocks
m_offset = m_block * m_block_size
x_offset = x_block * x_block_size

output_block_ptr = tl.make_block_ptr(
output_ptr + batch * m_size * x_size,
output_ptr,
shape=(m_size, x_size),
strides=(x_size, 1),
offsets=(m_offset, x_offset),
block_shape=(m_block_size, x_block_size),
order=(1, 0),
)
state_block_ptr = tl.make_block_ptr(
state_gate_ptr + batch * m_size * n_size,
state_gate_ptr,
shape=(m_size, n_size),
strides=(n_size, 1),
offsets=(m_offset, x_offset),
block_shape=(m_block_size, x_block_size),
order=(1, 0),
)
gate_block_ptr = tl.make_block_ptr(
state_gate_ptr + batch * m_size * n_size,
state_gate_ptr,
shape=(m_size, n_size),
strides=(n_size, 1),
offsets=(m_offset, x_offset + x_size),
Expand All @@ -122,7 +120,7 @@ def forward(
)

state = language.Linear.forward(
input_ptr + batch * input_batch_stride,
input_ptr,
weight_ptr,
bias_ptr,
m_size,
Expand All @@ -144,7 +142,7 @@ def forward(
dtype,
)
gate = language.Linear.forward(
input_ptr + batch * input_batch_stride,
input_ptr,
weight_ptr,
bias_ptr,
m_size,
Expand All @@ -167,16 +165,21 @@ def forward(
)
output = state * language.math.GELU.forward(gate)

if require_m_boundary_check & require_x_boundary_check:
tl.store(output_block_ptr, output.to(dtype))
tl.store(state_block_ptr, state.to(dtype))
tl.store(gate_block_ptr, gate.to(dtype))
else:
if require_m_boundary_check | require_x_boundary_check:
tl.store(output_block_ptr, output.to(dtype), boundary_check=(0, 1))
tl.store(state_block_ptr, state.to(dtype), boundary_check=(0, 1))
tl.store(gate_block_ptr, gate.to(dtype), boundary_check=(0, 1))
else:
tl.store(output_block_ptr, output.to(dtype))
tl.store(state_block_ptr, state.to(dtype))
tl.store(gate_block_ptr, gate.to(dtype))

@staticmethod
@triton.heuristics(
{
"require_x_boundary_check": lambda args: args["x_size"] % args["x_block_size"],
}
)
@triton.jit
def backward(
grad_state_gate_ptr: tl.tensor,
Expand All @@ -187,56 +190,67 @@ def backward(
x_size: tl.int32,
dtype: tl.constexpr,
x_block_size: tl.constexpr,
require_x_boundary_check: tl.constexpr,
):
pid = tl.program_id(0)
batch = pid // m_size
m_offset = pid % m_size

grad_state_block_ptr = tl.make_block_ptr(
grad_state_gate_ptr + batch * m_size * n_size,
grad_state_gate_ptr,
shape=(m_size, n_size),
strides=(n_size, 1),
offsets=(m_offset, 0),
block_shape=(1, x_block_size),
order=(1, 0),
)
grad_gate_block_ptr = tl.make_block_ptr(
grad_state_gate_ptr + batch * m_size * n_size,
grad_state_gate_ptr,
shape=(m_size, n_size),
strides=(n_size, 1),
offsets=(m_offset, x_size),
block_shape=(1, x_block_size),
order=(1, 0),
)
grad_output_block_ptr = tl.make_block_ptr(
grad_output_ptr + batch * m_size * x_size,
grad_output_ptr,
shape=(m_size, x_size),
strides=(x_size, 1),
offsets=(m_offset, 0),
block_shape=(1, x_block_size),
order=(1, 0),
)
state_block_ptr = tl.make_block_ptr(
state_gate_ptr + batch * m_size * n_size,
state_gate_ptr,
shape=(m_size, n_size),
strides=(n_size, 1),
offsets=(m_offset, 0),
block_shape=(1, x_block_size),
order=(1, 0),
)
gate_block_ptr = tl.make_block_ptr(
state_gate_ptr + batch * m_size * n_size,
state_gate_ptr,
shape=(m_size, n_size),
strides=(n_size, 1),
offsets=(m_offset, x_size),
block_shape=(1, x_block_size),
order=(1, 0),
)

grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,))
state = tl.load(state_block_ptr, boundary_check=(1,))
gate = tl.load(gate_block_ptr, boundary_check=(1,))
if require_x_boundary_check:
grad_output = tl.load(grad_output_block_ptr, boundary_check=(1,))
state = tl.load(state_block_ptr, boundary_check=(1,))
gate = tl.load(gate_block_ptr, boundary_check=(1,))
else:
grad_output = tl.load(grad_output_block_ptr)
state = tl.load(state_block_ptr)
gate = tl.load(gate_block_ptr)

grad_state = grad_output * language.math.GELU.forward(gate)
grad_gate = language.math.GELU.backward(grad_output * state, gate)
tl.store(grad_state_block_ptr, grad_state.to(dtype), boundary_check=(1,))
tl.store(grad_gate_block_ptr, grad_gate.to(dtype), boundary_check=(1,))

if require_x_boundary_check:
tl.store(grad_state_block_ptr, grad_state.to(dtype), boundary_check=(1,))
tl.store(grad_gate_block_ptr, grad_gate.to(dtype), boundary_check=(1,))
else:
tl.store(grad_state_block_ptr, grad_state.to(dtype))
tl.store(grad_gate_block_ptr, grad_gate.to(dtype))
41 changes: 16 additions & 25 deletions trident/kernel/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def linear_configs_for_backward_bias():

class Linear:
@staticmethod
@util.autotune(linear_configs([16, 64, 128], [32, 64, 128], [32, 64]), ["m_size", "n_size", "k_size"])
@util.autotune(linear_configs([16, 64, 128, 256], [32, 64, 128], [32, 64]), ["m_size", "n_size", "k_size"])
@triton.heuristics(
{
"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"],
Expand All @@ -128,7 +128,6 @@ def forward(
m_size: tl.int32,
n_size: tl.int32,
k_size: tl.int32,
input_batch_stride: tl.int32,
input_m_stride: tl.int32,
input_k_stride: tl.int32,
weight_n_stride: tl.int32,
Expand All @@ -143,18 +142,14 @@ def forward(
require_k_boundary_check: tl.constexpr,
):
pid = tl.program_id(0)
num_m_blocks = tl.cdiv(m_size, m_block_size)
num_n_blocks = tl.cdiv(n_size, n_block_size)
num_blocks = num_m_blocks * num_n_blocks
batch = pid // num_blocks
block = pid % num_blocks
m_block = block // num_n_blocks
n_block = block % num_n_blocks
m_block = pid // num_n_blocks
n_block = pid % num_n_blocks
m_offset = m_block * m_block_size
n_offset = n_block * n_block_size

output = language.Linear.forward(
input_ptr + batch * input_batch_stride,
input_ptr,
weight_ptr,
bias_ptr,
m_size,
Expand All @@ -177,7 +172,7 @@ def forward(
)

output_block_ptr = tl.make_block_ptr(
output_ptr + batch * m_size * n_size,
output_ptr,
shape=(m_size, n_size),
strides=(n_size, 1),
offsets=(m_offset, n_offset),
Expand Down Expand Up @@ -223,15 +218,14 @@ def backward(
num_m_blocks = tl.cdiv(m_size, m_block_size)
num_k_blocks = tl.cdiv(k_size, k_block_size)
num_blocks = num_m_blocks * num_k_blocks
batch = pid // num_blocks
block = pid % num_blocks
m_block = block // num_k_blocks
k_block = block % num_k_blocks
m_offset = m_block * m_block_size
k_offset = k_block * k_block_size

grad_input = language.Linear.backward(
grad_output_ptr + batch * m_size * n_size,
grad_output_ptr,
weight_ptr,
m_size,
n_size,
Expand All @@ -251,7 +245,7 @@ def backward(
)

grad_input_block_ptr = tl.make_block_ptr(
grad_input_ptr + batch * m_size * k_size,
grad_input_ptr,
shape=(m_size, k_size),
strides=(input_m_stride, input_k_stride),
offsets=(m_offset, k_offset),
Expand All @@ -278,13 +272,12 @@ def backward(
)
@triton.jit
def backward_weight(
grad_weight_staging_ptr: tl.tensor,
grad_weight_ptr: tl.tensor,
grad_output_ptr: tl.tensor,
input_ptr: tl.tensor,
m_size: tl.int32,
n_size: tl.int32,
k_size: tl.int32,
input_batch_stride: tl.int32,
input_m_stride: tl.int32,
input_k_stride: tl.int32,
use_accelerator: tl.constexpr,
Expand All @@ -300,16 +293,15 @@ def backward_weight(
num_n_blocks = tl.cdiv(n_size, n_block_size)
num_k_blocks = tl.cdiv(k_size, k_block_size)
num_blocks = num_n_blocks * num_k_blocks
batch = pid // num_blocks
block = pid % num_blocks
n_block = block // num_k_blocks
k_block = block % num_k_blocks
n_offset = n_block * n_block_size
k_offset = k_block * k_block_size

grad_weight = language.Linear.backward_weight(
grad_output_ptr + batch * m_size * n_size,
input_ptr + batch * input_batch_stride,
grad_output_ptr,
input_ptr,
m_size,
n_size,
k_size,
Expand All @@ -328,7 +320,7 @@ def backward_weight(
)

grad_weight_staging_block_ptr = tl.make_block_ptr(
grad_weight_staging_ptr + batch * n_size * k_size,
grad_weight_ptr,
shape=(n_size, k_size),
strides=(k_size, 1),
offsets=(n_offset, k_offset),
Expand All @@ -346,7 +338,7 @@ def backward_weight(
@triton.heuristics({"require_m_boundary_check": lambda args: args["m_size"] % args["m_block_size"]})
@triton.jit
def backward_bias(
grad_bias_staging_ptr: tl.tensor,
grad_bias_ptr: tl.tensor,
grad_output_ptr: tl.tensor,
m_size: tl.int32,
n_size: tl.int32,
Expand All @@ -355,10 +347,9 @@ def backward_bias(
require_m_boundary_check: tl.constexpr,
):
pid = tl.program_id(0)
batch = pid // n_size
n_offset = pid % n_size
grad_bias = language.Linear.backward_bias(
grad_output_ptr + batch * m_size * n_size,
grad_output_ptr,
m_size,
n_size,
n_offset,
Expand All @@ -367,12 +358,12 @@ def backward_bias(
dtype,
)

grad_bias_staging_block_ptr = tl.make_block_ptr(
grad_bias_staging_ptr + batch * n_size,
grad_bias_block_ptr = tl.make_block_ptr(
grad_bias_ptr,
shape=(n_size,),
strides=(1,),
offsets=(n_offset,),
block_shape=(1,),
order=(0,),
)
tl.store(grad_bias_staging_block_ptr, grad_bias)
tl.store(grad_bias_block_ptr, grad_bias)
Loading