diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py
index 084e52c..6cef360 100644
--- a/src/ntops/kernels/__init__.py
+++ b/src/ntops/kernels/__init__.py
@@ -36,6 +36,11 @@
     softmax,
     sub,
     tanh,
+    bitwise_left_shift,
+    index_select,
+    fold,
+    mish,
+    log2,
 )
 
 __all__ = [
@@ -76,4 +81,9 @@
     "softmax",
     "sub",
     "tanh",
+    "bitwise_left_shift",
+    "index_select",
+    "fold",
+    "mish",
+    "log2",
 ]
diff --git a/src/ntops/kernels/bitwise_left_shift.py b/src/ntops/kernels/bitwise_left_shift.py
new file mode 100644
index 0000000..9104e59
--- /dev/null
+++ b/src/ntops/kernels/bitwise_left_shift.py
@@ -0,0 +1,33 @@
+import functools
+
+from ninetoothed import Tensor
+import ninetoothed.language as ntl
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, other, output):
+    if input.dtype == ntl.int32:
+        mask = (other > 31) | (other < 0)
+    elif input.dtype == ntl.int64:
+        mask = (other > 63) | (other < 0)
+    elif input.dtype == ntl.uint8:
+        mask = (other > 7) | (other < 0)
+    else:
+        mask = ntl.zeros_like(other)
+
+    shift = ntl.where(mask, ntl.zeros_like(other), other)
+    input = ntl.where(mask, ntl.zeros_like(input), input)
+    output = input << shift
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+    )
+
+    return arrangement_, application, tensors
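Note (not part of the patch): the `application` above clamps out-of-range shift amounts instead of relying on hardware modulo behaviour, so for the integer dtypes it special-cases, any shift outside [0, bit_width - 1] produces 0. A minimal PyTorch sketch of that masking logic, written here only as an illustration for int32 operands:

    import torch

    def emulated_left_shift_int32(value, shift):
        # Mirror the kernel's masking: shift amounts outside [0, 31] yield 0 for int32.
        invalid = (shift > 31) | (shift < 0)
        safe_shift = torch.where(invalid, torch.zeros_like(shift), shift)
        safe_value = torch.where(invalid, torch.zeros_like(value), value)
        return safe_value << safe_shift

    x = torch.tensor([1, 2, 3], dtype=torch.int32)
    s = torch.tensor([1, 40, -1], dtype=torch.int32)
    print(emulated_left_shift_int32(x, s))  # tensor([2, 0, 0], dtype=torch.int32)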
diff --git a/src/ntops/kernels/fold.py b/src/ntops/kernels/fold.py
new file mode 100644
index 0000000..30a3a5b
--- /dev/null
+++ b/src/ntops/kernels/fold.py
@@ -0,0 +1,42 @@
+import functools
+
+from ninetoothed import Tensor
+import ninetoothed.language as ntl
+
+def arrangement(*tensors, L_pow2, kernel_size_h, kernel_size_w, stride_h, stride_w, dilation_h, dilation_w, padding_h, padding_w, block_size=None):
+    # input: (N, C * k_h * k_w, H_out * W_out)
+    # output: (N, C, H_padded, W_padded)
+    input, output, L_val = tensors
+
+    # Arrange output so that it lines up with input.
+    output = output.tile((1, 1, kernel_size_h, kernel_size_w), (1, 1, stride_h, stride_w), (1, 1, dilation_h, dilation_w))
+    # => output: (N, C, H_out, W_out), dtype=(1, 1, k_h, k_w)
+    output = output.ravel()  # => output: (N, C, H_out, W_out, 1, 1, k_h, k_w)
+    output = output.permute((0, 1, 4, 5, 6, 7, 2, 3))
+    # => output: (N, C, 1, 1, k_h, k_w, H_out, W_out)
+    output = output.flatten(start_dim=0, end_dim=6).flatten(start_dim=1)
+    # => output: (N * C * k_h * k_w, H_out * W_out)
+    output = output.tile((block_size, L_pow2)).squeeze(1)
+    # => output: (... // block_size, ), dtype=(block_size, L_pow2)
+
+    input = input.flatten(end_dim=2)  # => input: (N * C * k_h * k_w, H_out * W_out)
+    input = input.tile((block_size, L_pow2)).squeeze(1)
+    # => input: (... // block_size, ), dtype=(block_size, L_pow2)
+
+    return input, output, L_val
+
+def application(input, output, L):
+    # input: (block_size, L_pow2)
+    # output: (block_size, L_pow2)
+    ntl.atomic_add(output.data_ptr() + output.offsets(), input)
+
+def premake(L_pow2, kernel_size_h, kernel_size_w, stride_h, stride_w, dilation_h, dilation_w, padding_h, padding_w, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, L_pow2=L_pow2, kernel_size_h=kernel_size_h, kernel_size_w=kernel_size_w, stride_h=stride_h, stride_w=stride_w, dilation_h=dilation_h, dilation_w=dilation_w, padding_h=padding_h, padding_w=padding_w, block_size=block_size)
+
+    tensors = (
+        Tensor(3, dtype=dtype, other=0, shape_options={'constexpr': True}),
+        Tensor(4, dtype=dtype, other=0, shape_options={'constexpr': True}),
+        Tensor(0, dtype=int, constexpr=True),  # L
+    )
+
+    return arrangement_, application, tensors
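For context (not part of the patch): L_pow2 is the next power of two above L, the number of sliding blocks, which the torch wrapper later derives from the output size. A small sketch of that relationship, checked against torch.nn.functional.unfold; the shapes are illustrative only:

    import torch
    import torch.nn.functional as F

    H_out, W_out = 8, 8
    K_h, K_w, S_h, S_w, D_h, D_w, P_h, P_w = 3, 3, 1, 1, 1, 1, 0, 0

    L_h = (H_out + 2 * P_h - (D_h * (K_h - 1) + 1)) // S_h + 1
    L_w = (W_out + 2 * P_w - (D_w * (K_w - 1) + 1)) // S_w + 1
    L = L_h * L_w  # 6 * 6 = 36 sliding blocks

    image = torch.randn(2, 3, H_out, W_out)
    columns = F.unfold(image, kernel_size=(K_h, K_w))
    assert columns.shape == (2, 3 * K_h * K_w, L)

    L_pow2 = 1 << (L - 1).bit_length()  # 64, the padded block count handed to premake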
diff --git a/src/ntops/kernels/index_select.py b/src/ntops/kernels/index_select.py
new file mode 100644
index 0000000..f7f836b
--- /dev/null
+++ b/src/ntops/kernels/index_select.py
@@ -0,0 +1,83 @@
+
+import functools
+
+from ninetoothed import Tensor
+import ninetoothed.language as ntl
+
+def arrangement(input, output, index, T, S, T_pow2, S_pow2, dim, block_size=None):
+    non_target_dim = tuple(i for i in range(input.ndim) if i != dim)
+    input = input.permute(non_target_dim + (dim,))
+    input = input.flatten(end_dim=-1)  # shape: (..., T)
+
+    output = output.permute(non_target_dim + (dim,))
+    output = output.flatten(end_dim=-1)  # shape: (..., S)
+
+    # input: (..., T)
+    # output: (..., S)
+    # index: (S,)
+    input_tiled = input.tile((block_size, T_pow2)).squeeze(1)  # shape: (..., ), dtype=(block_size, T_pow2)
+    output_tiled = output.tile((block_size, S_pow2)).squeeze(1)  # shape: (..., ), dtype=(block_size, S_pow2)
+
+    index_expand = index.unsqueeze(0).expand((input_tiled.shape[0], -1))  # shape: (..., S)
+    index_expand = index_expand.tile((1, S_pow2)).squeeze(1)  # shape: (..., ), dtype=(1, S_pow2)
+
+    return input_tiled, output_tiled, index_expand, T, S
+
+# def application(input, output, index):
+#     # input: (block_size, T)
+#     # output: (block_size, S)
+#     # index: (1, S)
+#     # Implement index_select with gather.
+#     # Triton 3.0.0 does not support the gather op, so this is unusable on Moore Threads.
+#     # Kept here for reference only.
+#     index_expand = ntl.broadcast_to(index, (input.shape[0], index.shape[1]))
+#     # index_expand: (block_size, S)
+#     output = ntl.gather(input, index, axis=1)
+
+def application(input, output, index, T, S):
+    # input: (block_size, T_pow2)
+    # output: (block_size, S_pow2)
+    # index: (1, S_pow2)
+
+    # Use T_pow2 to satisfy arange's power-of-two requirement.
+    col_indices = ntl.arange(0, input.shape[1])  # shape: (T_pow2,)
+
+    # Add dimensions and broadcast to (block_size, S, T_pow2).
+    col_indices = ntl.expand_dims(col_indices, 0)  # shape: (1, T_pow2)
+    col_indices = ntl.expand_dims(col_indices, 0)  # shape: (1, 1, T_pow2)
+    col_indices = ntl.broadcast_to(col_indices, (input.shape[0], output.shape[1], input.shape[1]))
+
+    # Expand input to (block_size, S, T_pow2).
+    input_expanded = ntl.expand_dims(input, 1)  # shape: (block_size, 1, T_pow2)
+    input_expanded = ntl.broadcast_to(input_expanded, (input.shape[0], output.shape[1], input.shape[1]))
+
+    # Expand index to (block_size, S, T_pow2).
+    index_expanded = ntl.expand_dims(index, 2)  # shape: (block_size, S, 1)
+    index_expanded = ntl.broadcast_to(index_expanded, (input.shape[0], output.shape[1], input.shape[1]))
+
+    # Only match within the valid column range; columns beyond the original T are masked out.
+    col_valid = col_indices < input.shape[1]
+    match_mask = (col_indices == index_expanded)
+    mask = ntl.where(col_valid, match_mask, False)
+
+    # Use where to pick the matching values.
+    selected = ntl.where(mask, input_expanded, 0.0)  # shape: (block_size, S, T_pow2)
+
+    # Sum over the last dimension to obtain the result.
+    result = ntl.sum(selected, axis=2)  # shape: (block_size, S)
+
+    # Write back to the output.
+    output = result
+
+def premake(ndim, dim, T_pow2, S_pow2, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, dim=dim, T_pow2=T_pow2, S_pow2=S_pow2, block_size=block_size)
+
+    tensors = (
+        Tensor(ndim, dtype=dtype, other=0, shape_options={'constexpr': True}),
+        Tensor(ndim, dtype=dtype, other=0, shape_options={'constexpr': True}),
+        Tensor(1, dtype=int, shape_options={'constexpr': True}),
+        Tensor(0, dtype=int, constexpr=True),  # T
+        Tensor(0, dtype=int, constexpr=True),  # S
+    )
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/log2.py b/src/ntops/kernels/log2.py
new file mode 100644
index 0000000..e54255d
--- /dev/null
+++ b/src/ntops/kernels/log2.py
@@ -0,0 +1,20 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, output):
+    dtype = input.dtype
+    log2_dtype = dtype if dtype != ntl.float16 else ntl.float32
+    output = ntl.cast(ntl.log2(ntl.cast(input, log2_dtype)), dtype)
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/mish.py b/src/ntops/kernels/mish.py
new file mode 100644
index 0000000..15e10d2
--- /dev/null
+++ b/src/ntops/kernels/mish.py
@@ -0,0 +1,37 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+def _softplus(x):
+    return ntl.log(ntl.exp(-ntl.abs(x)) + 1.0) + ntl.maximum(x, 0.0)
+
+
+def _tanh(x):
+    return (1 - ntl.exp(-2 * x)) / (1 + ntl.exp(-2 * x))  # Overflow-safe for x >= 0, which holds for softplus outputs.
+
+
+def application(input, output):
+    dtype = input.dtype
+    if dtype == ntl.float16:
+        mish_dtype = ntl.float32
+    elif dtype == ntl.bfloat16:
+        mish_dtype = ntl.float32
+    else:
+        mish_dtype = dtype
+
+    input_f32 = ntl.cast(input, mish_dtype)
+    output_softplus_f32 = _softplus(input_f32)
+    output_f32 = _tanh(output_softplus_f32)
+    output = ntl.cast(output_f32 * input_f32, dtype)
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
+
+    return arrangement_, application, tensors
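Since ntl.gather is unavailable (see the commented-out variant in index_select.py above), the active application emulates it with a broadcast compare plus a masked sum. The same idea in plain PyTorch, purely to make the kernel's reduction easier to follow; the tensor names and sizes are illustrative:

    import torch

    rows = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # (block_size=3, T=4)
    index = torch.tensor([2, 0, 3])                             # (S=3,)

    col = torch.arange(rows.shape[1]).view(1, 1, -1)            # (1, 1, T)
    match = col == index.view(1, -1, 1)                         # (1, S, T) one-hot rows
    selected = torch.where(match, rows.unsqueeze(1), torch.zeros(()))
    result = selected.sum(dim=2)                                # (block_size, S)

    assert torch.equal(result, rows[:, index])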
diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py
index 702877e..fee7b03 100644
--- a/src/ntops/torch/__init__.py
+++ b/src/ntops/torch/__init__.py
@@ -36,6 +36,11 @@
 from ntops.torch.softmax import softmax
 from ntops.torch.sub import sub
 from ntops.torch.tanh import tanh
+from ntops.torch.bitwise_left_shift import bitwise_left_shift
+from ntops.torch.index_select import index_select
+from ntops.torch.fold import fold
+from ntops.torch.mish import mish
+from ntops.torch.log2 import log2
 
 __all__ = [
     "abs",
@@ -76,4 +81,9 @@
     "softmax",
     "sub",
     "tanh",
+    "bitwise_left_shift",
+    "index_select",
+    "fold",
+    "mish",
+    "log2",
 ]
diff --git a/src/ntops/torch/bitwise_left_shift.py b/src/ntops/torch/bitwise_left_shift.py
new file mode 100644
index 0000000..7997d6f
--- /dev/null
+++ b/src/ntops/torch/bitwise_left_shift.py
@@ -0,0 +1,33 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def bitwise_left_shift(input, other, *, out=None):
+    # Check if we need to handle a non-contiguous in-place operation.
+    is_inplace_input = out is not None and out.data_ptr() == input.data_ptr()
+
+    if out is None:
+        out = torch.empty_like(input)
+
+    # Special case: in-place operation on a non-contiguous tensor.
+    # When out and input are the same tensor (in-place) and input has non-standard strides (non-contiguous),
+    # the flatten() used by element_wise.arrangement in the ninetoothed framework loses the memory-layout
+    # information, so the GPU kernel cannot write the results back into the original tensor with its special strides.
+    # The workaround is to compute on contiguous copies of the inputs and then copy the result back with copy_();
+    # copy_() honors the strides of the destination tensor, so the data lands at the correct memory locations.
+    if is_inplace_input and not input.is_contiguous():
+        input_contig = input.contiguous()
+        other_contig = other.contiguous() if not other.is_contiguous() else other
+        out_contig = torch.empty_like(input_contig)
+
+        kernel = _cached_make(ntops.kernels.bitwise_left_shift.premake, input.ndim)
+        kernel(input_contig, other_contig, out_contig)
+
+        out.copy_(out_contig)
+    else:
+        kernel = _cached_make(ntops.kernels.bitwise_left_shift.premake, input.ndim)
+        kernel(input, other, out)
+
+    return out
diff --git a/src/ntops/torch/fold.py b/src/ntops/torch/fold.py
new file mode 100644
index 0000000..93dad63
--- /dev/null
+++ b/src/ntops/torch/fold.py
@@ -0,0 +1,88 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+def fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1):
+    if isinstance(kernel_size, int):
+        kernel_size = (kernel_size, kernel_size)
+    if isinstance(dilation, int):
+        dilation = (dilation, dilation)
+    if isinstance(padding, int):
+        padding = (padding, padding)
+    if isinstance(stride, int):
+        stride = (stride, stride)
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+
+    # Record whether the input was 2D (unbatched).
+    input_was_2d = input.ndim == 2
+    if input_was_2d:
+        input = input.view((1, input.shape[0], input.shape[1]))
+
+    N, Ckk, L = input.shape
+    H_out, W_out = output_size
+    K_h, K_w = kernel_size
+    D_h, D_w = dilation
+    P_h, P_w = padding
+    S_h, S_w = stride
+
+    # Validate the channel dimension and the number of sliding blocks L.
+    C = Ckk // (K_h * K_w)
+    if C * K_h * K_w != Ckk:
+        raise ValueError(f"Input channel dimension {Ckk} is not divisible by kernel size product {K_h * K_w}")
+
+    L_h = (H_out + 2 * P_h - (D_h * (K_h - 1) + 1)) // S_h + 1
+    L_w = (W_out + 2 * P_w - (D_w * (K_w - 1) + 1)) // S_w + 1
+    if L != L_h * L_w:
+        raise ValueError(f"Input L {L} != computed L_h*L_w {L_h * L_w}")
+
+    # Allocate the output tensor including padding.
+    out_padded_h = H_out + 2 * P_h
+    out_padded_w = W_out + 2 * P_w
+    out = torch.empty(
+        (N, C, out_padded_h, out_padded_w),
+        dtype=input.dtype,
+        device=input.device
+    )
+    torch.nn.init.zeros_(out)
+
+    # Build and launch the kernel.
+    block_size = 128
+    L_pow2 = 1 << (L - 1).bit_length()
+    kernel = _cached_make(
+        ntops.kernels.fold.premake,
+        L_pow2,
+        kernel_size[0],
+        kernel_size[1],
+        stride[0],
+        stride[1],
+        dilation[0],
+        dilation[1],
+        padding[0],
+        padding[1],
+        dtype=input.dtype,
+        block_size=block_size
+    )
+    kernel(input, out, L)
+
+    # Remove the padding.
+    result = out
+    if P_h > 0 or P_w > 0:
+        # Direct slicing is not supported yet, so use narrow instead.
+        result = torch.narrow(result, 2, P_h, H_out)
+        result = torch.narrow(result, 3, P_w, W_out)
+
+    # In-place padding is hard to express under the ninetoothed framework, so a new tensor is created here.
+    # The new tensor receives the result and guarantees contiguous memory.
+    output = torch.empty(
+        (N, C, H_out, W_out),
+        dtype=input.dtype,
+        device=input.device)
+    torch.nn.init.zeros_(output)
+    torch.add(output, result, out=output)
+
+    if input_was_2d:
+        output = output.view((output.shape[1], output.shape[2], output.shape[3]))
+
+    return output
\ No newline at end of file
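A usage sketch for the wrapper above: fold the columns produced by unfold and compare against torch.nn.functional.fold. This assumes a CUDA device and an installed ntops build; it mirrors the cases exercised in tests/test_fold.py further down.

    import torch
    import torch.nn.functional as F

    import ntops

    x = torch.randn(2, 3, 8, 8, device="cuda")
    columns = F.unfold(x, kernel_size=(3, 3))  # (2, 27, 36)

    ours = ntops.torch.fold(columns, output_size=(8, 8), kernel_size=(3, 3))
    reference = F.fold(columns, output_size=(8, 8), kernel_size=(3, 3))

    assert torch.allclose(ours, reference, rtol=1e-3, atol=1e-3)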
diff --git a/src/ntops/torch/index_select.py b/src/ntops/torch/index_select.py
new file mode 100644
index 0000000..240acac
--- /dev/null
+++ b/src/ntops/torch/index_select.py
@@ -0,0 +1,27 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def index_select(input, dim, index, *, out=None):
+    assert index.ndim == 1, "Index tensor must be 1-dimensional."
+
+    T = input.shape[dim]
+    T_pow2 = 1 << (T - 1).bit_length()
+    S = index.shape[0]
+    S_pow2 = 1 << (S - 1).bit_length()
+
+    if dim < 0:
+        dim += input.ndim
+
+    if out is None:
+        output_shape = list(input.shape)
+        output_shape[dim] = index.shape[0]
+        out = torch.empty(output_shape, dtype=input.dtype, device=input.device)
+
+    block_size = 256
+    kernel = _cached_make(ntops.kernels.index_select.premake, input.ndim, dim, T_pow2=T_pow2, S_pow2=S_pow2, block_size=block_size)
+    kernel(input, out, index, T, S)
+
+    return out
diff --git a/src/ntops/torch/log2.py b/src/ntops/torch/log2.py
new file mode 100644
index 0000000..c978b07
--- /dev/null
+++ b/src/ntops/torch/log2.py
@@ -0,0 +1,15 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def log2(input, *, out=None):
+    if out is None:
+        out = torch.empty_like(input)
+
+    kernel = _cached_make(ntops.kernels.log2.premake, input.ndim)
+
+    kernel(input, out)
+
+    return out
diff --git a/src/ntops/torch/mish.py b/src/ntops/torch/mish.py
new file mode 100644
index 0000000..f17d5fd
--- /dev/null
+++ b/src/ntops/torch/mish.py
@@ -0,0 +1,17 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def mish(input, inplace=False):
+    if not inplace:
+        out = torch.empty_like(input)
+    else:
+        out = input
+
+    kernel = _cached_make(ntops.kernels.mish.premake, input.ndim)
+
+    kernel(input, out)
+
+    return out
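A quick numeric check of the formulation used by the mish kernel (softplus written in its overflow-safe form, with the intermediate math done in float32 for half-precision inputs). Plain PyTorch, illustration only:

    import torch
    import torch.nn.functional as F

    def mish_reference(x):
        x32 = x.float()
        softplus = torch.log1p(torch.exp(-x32.abs())) + torch.clamp(x32, min=0.0)
        return (torch.tanh(softplus) * x32).to(x.dtype)

    x = torch.randn(1024, dtype=torch.float16)
    assert torch.allclose(mish_reference(x), F.mish(x.float()).half(), rtol=1e-2, atol=1e-2)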
diff --git a/tests/test_bitwise_left_shift.py b/tests/test_bitwise_left_shift.py
new file mode 100644
index 0000000..69854ef
--- /dev/null
+++ b/tests/test_bitwise_left_shift.py
@@ -0,0 +1,32 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments(False))
+def test_bitwise_left_shift(shape, dtype, device, rtol, atol):
+    if dtype == torch.bool:
+        return
+    elif dtype == torch.int8:
+        # Only uint8 is supported here.
+        dtype = torch.uint8
+        upper_bound = 10
+        input = torch.randint(0, upper_bound, size=shape, dtype=dtype, device=device)
+        other = torch.randint(0, upper_bound, size=shape, dtype=dtype, device=device)
+    else:
+        upper_bound = 10
+        input = torch.randint(
+            -upper_bound, upper_bound, size=shape, dtype=dtype, device=device
+        )
+        other = torch.randint(
+            -upper_bound, upper_bound, size=shape, dtype=dtype, device=device
+        )
+
+    ninetoothed_output = ntops.torch.bitwise_left_shift(input, other)
+    reference_output = torch.bitwise_left_shift(input, other)
+
+    assert torch.equal(ninetoothed_output, reference_output)
diff --git a/tests/test_fold.py b/tests/test_fold.py
new file mode 100644
index 0000000..3200f49
--- /dev/null
+++ b/tests/test_fold.py
@@ -0,0 +1,61 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+
+
+# Test cases: (in_shape, output_size, kernel_size, dilation, stride, padding)
+_TEST_CASES = [
+    ((2, 27, 36), (8, 8), (3, 3), (1, 1), (1, 1), (0, 0)),
+    ((2, 32, 16), (16, 16), (4, 4), (1, 1), (4, 4), (0, 0)),
+    ((3, 36, 40), (7, 9), (3, 2), (1, 1), (1, 1), (0, 0)),
+    ((2, 45, 20), (12, 6), (3, 3), (1, 1), (2, 1), (0, 0)),
+
+    # Padding is handled at the infinicore level.
+    # Originally corresponded to ((1,4,10,12), None, (5,3), 1, 1, (2,1))
+    # L = 4 * 12 = 48, channels = 4*5*3 = 60
+    # ((1, 60, 48), (10, 12), (5, 3), (1, 1), (2, 1), (1, 1)),
+    # Originally corresponded to ((1,8,9,11), None, (2,3), 1, 1, (1,2))
+    # L = 10 * 6 = 60, channels = 8*2*3 = 48
+    # ((1, 48, 60), (9, 11), (2, 3), (1, 1), (1, 2), (1, 1)),
+]
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+@pytest.mark.parametrize(
+    "in_shape, output_size, kernel_size, dilation, stride, padding",
+    _TEST_CASES,
+)
+def test_fold(in_shape, output_size, kernel_size, dilation, stride, padding, dtype):
+    device = "cuda"
+
+    x = torch.randn(*in_shape, dtype=dtype, device=device)
+
+    reference_output = torch.nn.functional.fold(
+        x,
+        output_size=output_size,
+        kernel_size=kernel_size,
+        dilation=dilation,
+        padding=padding,
+        stride=stride,
+    )
+
+    ninetoothed_output = ntops.torch.fold(
+        x,
+        output_size=output_size,
+        kernel_size=kernel_size,
+        dilation=dilation,
+        padding=padding,
+        stride=stride,
+    )
+
+    if dtype is torch.float32:
+        atol = 0.001
+        rtol = 0.001
+    else:
+        atol = 0.01
+        rtol = 0.01
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_index_select.py b/tests/test_index_select.py
new file mode 100644
index 0000000..8ec545d
--- /dev/null
+++ b/tests/test_index_select.py
@@ -0,0 +1,59 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (8, 8),
+        (4, 5, 6),
+        (2, 3, 4, 5),
+        (10, 20),
+    ],
+)
+def test_index_select(shape):
+    """Test index_select with float32 input"""
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Create input tensor (float32 only)
+    input = torch.randn(shape, dtype=torch.float32, device=device)
+
+    # Test index_select on different dimensions
+    for dim in range(len(shape)):
+        # Create random indices for selection
+        num_indices = torch.randint(1, shape[dim] + 1, (1,)).item()
+        indices = torch.randperm(shape[dim], device=device)[:num_indices]
+
+        # Call ntops implementation
+        ninetoothed_output = ntops.torch.index_select(input, dim, indices)
+
+        # Call reference implementation
+        reference_output = torch.index_select(input, dim, indices)
+
+        # Compare results
+        assert torch.allclose(ninetoothed_output, reference_output), (
+            f"Mismatch for shape={shape}, dim={dim}, num_indices={num_indices}"
+        )
+
+
+# @skip_if_cuda_not_available
+# def test_index_select_all_indices():
+#     """Test index_select with all indices (should equal identity, float32)"""
+#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#     shape = (3, 4, 5)
+#     input = torch.randn(shape, dtype=torch.float32, device=device)
+#
+#     # Test with all indices in order
+#     for dim in range(len(shape)):
+#         indices = torch.arange(shape[dim], device=device, dtype=torch.int64)
+#
+#         ninetoothed_output = ntops.torch.index_select(input, dim, indices)
+#         reference_output = torch.index_select(input, dim, indices)
+#
+#         assert torch.allclose(ninetoothed_output, reference_output), (
+#             f"Mismatch with all indices for dim={dim}"
+#         )
diff --git a/tests/test_log2.py b/tests/test_log2.py
new file mode 100644
index 0000000..1fba4b7
--- /dev/null
+++ b/tests/test_log2.py
@@ -0,0 +1,21 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_log2(shape, dtype, device, rtol, atol):
+    # TODO: Test for `float16` later.
+    if dtype is torch.float16:
+        return
+    # Use positive values to avoid log of negative numbers
+    input = torch.abs(torch.randn(shape, dtype=dtype, device=device)) + 1e-6
+
+    ninetoothed_output = ntops.torch.log2(input)
+    reference_output = torch.log2(input)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_mish.py b/tests/test_mish.py
new file mode 100644
index 0000000..83393fc
--- /dev/null
+++ b/tests/test_mish.py
@@ -0,0 +1,18 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_mish(shape, dtype, device, rtol, atol):
+    input = torch.randn(shape, dtype=dtype, device=device)
+
+    ninetoothed_output = ntops.torch.mish(input)
+    reference_output = F.mish(input)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
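Taken together, the patch exposes five new ops under ntops.torch. A short smoke sketch (not part of the patch; assumes a CUDA device and an installed ntops build) of the element-wise and selection entry points; fold is shown in the sketch after its wrapper above:

    import torch

    import ntops

    x = torch.randn(4, 4, device="cuda")
    i = torch.randint(0, 8, (4, 4), dtype=torch.int32, device="cuda")

    print(ntops.torch.log2(x.abs() + 1e-6))
    print(ntops.torch.mish(x))
    print(ntops.torch.index_select(x, 0, torch.tensor([2, 0], device="cuda")))
    print(ntops.torch.bitwise_left_shift(i, i % 4))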