From 4b5e9e052262b5c80e6784158dd987d7adb179ed Mon Sep 17 00:00:00 2001
From: greenhandhand <781740145@qq.com>
Date: Sun, 14 Dec 2025 15:41:23 +0800
Subject: [PATCH 1/2] Infinicore competition: implement the log10, avg_pool3d,
 histc, dot, and log1p operators
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/ntops/kernels/__init__.py   |  10 +++
 src/ntops/kernels/avg_pool3d.py | 101 ++++++++++++++++++++++++
 src/ntops/kernels/dot.py        |  73 ++++++++++++++++++
 src/ntops/kernels/histc.py      | 133 ++++++++++++++++++++++++++++++++
 src/ntops/kernels/log10.py      |  21 +++++
 src/ntops/kernels/log1p.py      |  18 +++++
 src/ntops/torch/__init__.py     |  10 +++
 src/ntops/torch/avg_pool3d.py   |  54 +++++++++++++
 src/ntops/torch/dot.py          |  33 ++++++++
 src/ntops/torch/histc.py        |  31 ++++++++
 src/ntops/torch/log10.py        |  16 ++++
 src/ntops/torch/log1p.py        |  16 ++++
 12 files changed, 516 insertions(+)
 create mode 100644 src/ntops/kernels/avg_pool3d.py
 create mode 100644 src/ntops/kernels/dot.py
 create mode 100644 src/ntops/kernels/histc.py
 create mode 100644 src/ntops/kernels/log10.py
 create mode 100644 src/ntops/kernels/log1p.py
 create mode 100644 src/ntops/torch/avg_pool3d.py
 create mode 100644 src/ntops/torch/dot.py
 create mode 100644 src/ntops/torch/histc.py
 create mode 100644 src/ntops/torch/log10.py
 create mode 100644 src/ntops/torch/log1p.py

diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py
index 084e52c..8f79be9 100644
--- a/src/ntops/kernels/__init__.py
+++ b/src/ntops/kernels/__init__.py
@@ -36,6 +36,11 @@
     softmax,
     sub,
     tanh,
+    avg_pool3d,
+    dot,
+    histc,
+    log10,
+    log1p,
 )
 
 __all__ = [
@@ -76,4 +81,9 @@
     "softmax",
     "sub",
     "tanh",
+    "avg_pool3d",
+    "dot",
+    "histc",
+    "log10",
+    "log1p",
 ]
diff --git a/src/ntops/kernels/avg_pool3d.py b/src/ntops/kernels/avg_pool3d.py
new file mode 100644
index 0000000..77391d4
--- /dev/null
+++ b/src/ntops/kernels/avg_pool3d.py
@@ -0,0 +1,101 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+
+def arrangement(
+    *tensors,
+    kernel_size_d,
+    kernel_size_h,
+    kernel_size_w,
+    stride_d,
+    stride_h,
+    stride_w,
+    block_size,
+    ceil_mode,
+):
+    input, output, kernel_volume = tensors
+
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    # input: (N, C, D_in, H_in, W_in), output: (N, C, D_out, H_out, W_out)
+    input_arranged = input.tile(
+        (1, 1, kernel_size_d, kernel_size_h, kernel_size_w),
+        (1, 1, stride_d, stride_h, stride_w),
+        floor_mode=not ceil_mode,
+    )
+    # => (N, C, D_out, H_out, W_out), dtype=(1, 1, k_d, k_h, k_w)
+    input_arranged = input_arranged.ravel()
+    # => (N, C, D_out, H_out, W_out, 1, 1, k_d, k_h, k_w)
+    input_arranged = input_arranged.flatten(end_dim=5).flatten(start_dim=1)
+    # => (N*C*D_out*H_out*W_out, k_d*k_h*k_w)
+
+    # Round k_d*k_h*k_w up to the next power of two.
+    nearest_pow2 = 1 << (kernel_size_d * kernel_size_h * kernel_size_w - 1).bit_length()
+    input_arranged = input_arranged.tile((1, nearest_pow2))
+    # => (..., k_d*k_h*k_w // nearest_pow2 = 1), dtype=(1, nearest_pow2)
+    input_arranged.dtype = input_arranged.dtype.squeeze(0)
+    # => (..., 1), dtype=(nearest_pow2,)
+    input_arranged = input_arranged.tile((block_size, -1))
+    # => (..., 1), dtype=(block_size, 1), dtype=(nearest_pow2,)
+    input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1)
+    # => (..., 1), dtype=(block_size, nearest_pow2)
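+
+    # Worked example (hypothetical sizes): with kernel_size (2, 2, 3),
+    # k_d*k_h*k_w = 12 and nearest_pow2 = 16, so each program reduces a
+    # (block_size, 16) block whose 4 padded lanes are zero-filled
+    # (`other=0` on the input tensor below) and leave the sums unchanged.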
+
+    output_arranged = output.tile((1, 1, 1, 1, 1))
+    # => (N, C, D_out, H_out, W_out), dtype=(1, 1, 1, 1, 1)
+    output_arranged = output_arranged.ravel()
+    # => (N, C, D_out, H_out, W_out, 1, 1, 1, 1, 1)
+    output_arranged = output_arranged.flatten(end_dim=5).flatten(start_dim=1)
+    # => (N*C*D_out*H_out*W_out, 1)
+    output_arranged = output_arranged.tile((block_size, -1))
+    # => (..., 1), dtype=(block_size, 1)
+    output_arranged.dtype = output_arranged.dtype.squeeze(1)
+    # => (..., 1), dtype=(block_size,)
+
+    return input_arranged, output_arranged, kernel_volume
+
+
+def application(input, output, kernel_volume):
+    # input: (block_size, nearest_pow2) -- the actual pixel values, with
+    # out-of-bounds positions zero-filled so they do not affect the sum.
+    # output: (block_size,)
+    val_sum = ntl.sum(input, axis=1)  # (block_size,)
+    output = val_sum / kernel_volume  # (block_size,)
+
+
+def premake(
+    ndim,
+    kernel_size_d,
+    kernel_size_h,
+    kernel_size_w,
+    stride_d,
+    stride_h,
+    stride_w,
+    block_size=None,
+    ceil_mode=False,
+    dtype=None,
+):
+    arrangement_ = functools.partial(
+        arrangement,
+        kernel_size_d=kernel_size_d,
+        kernel_size_h=kernel_size_h,
+        kernel_size_w=kernel_size_w,
+        stride_d=stride_d,
+        stride_h=stride_h,
+        stride_w=stride_w,
+        block_size=block_size,
+        ceil_mode=ceil_mode,
+    )
+
+    tensors = (
+        Tensor(ndim, dtype=dtype, other=0),  # input (zero padding for partial tiles)
+        Tensor(ndim, dtype=dtype),  # output
+        Tensor(0, dtype=dtype),  # kernel_volume
+    )
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/dot.py b/src/ntops/kernels/dot.py
new file mode 100644
index 0000000..09e320f
--- /dev/null
+++ b/src/ntops/kernels/dot.py
@@ -0,0 +1,73 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+
+# ========= Single-block reduction =========
+
+def arrangement_dot_full(input, tensor, out, block_size):
+    # input/tensor: (N,), out: (1,)
+    input = input.tile((-1,))  # (1,), dtype=(block_size,)
+    tensor = tensor.tile((-1,))  # (1,), dtype=(block_size,)
+    out = out.tile((1,))  # (1,), dtype=(1,)
+    return input, tensor, out
+
+
+def application_dot_full(input, tensor, out):
+    out = ntl.sum(input * tensor)
+
+
+def premake_dot_full(dtype, block_size):
+    arrangement_ = functools.partial(arrangement_dot_full, block_size=block_size)
+
+    tensors = (
+        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),
+        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),
+        Tensor(1, dtype=dtype),
+    )
+
+    return arrangement_, application_dot_full, tensors
+
+
+# ========= Two-stage (divide and conquer) reduction =========
+
+def arrangement_dot_divide(input, tensor, out_temp, block_size):
+    # input/tensor: (N,), out_temp: (N // block_size,)
+    input = input.tile((block_size,))  # (N // block_size,), dtype=(block_size,)
+    tensor = tensor.tile((block_size,))  # (N // block_size,), dtype=(block_size,)
+    out_temp = out_temp.tile((1,))  # (N // block_size,), dtype=(1,)
+    return input, tensor, out_temp
+
+
+def application_dot_divide(input, tensor, out_temp):
+    out_temp = ntl.sum(input * tensor, 0)
+
+
+def arrangement_dot_conquer(input_block_wise, out, block_size):
+    # input_block_wise: (N // block_size,), out: (1,)
+    input_block_wise = input_block_wise.tile((-1,))  # (1,), dtype=(N // block_size,)
+    out = out.tile((1,))  # (1,), dtype=(1,)
+    return input_block_wise, out
+
+
+def application_dot_conquer(input_block_wise, out):
+    out = ntl.sum(input_block_wise)
+
+
+def premake_dot_divide(dtype, block_size):
+    arrangement_ = functools.partial(arrangement_dot_divide, block_size=block_size)
+
+    tensors = (
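+        # The two 1-D operands and the per-block partial-sum buffer; the
+        # constexpr shapes let ninetoothed specialize the compiled kernel
+        # for each input length (assumed shape_options semantics).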
+        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),
+        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),
+        Tensor(1, dtype=dtype),
+    )
+
+    return arrangement_, application_dot_divide, tensors
+
+
+def premake_dot_conquer(dtype, block_size):
+    arrangement_ = functools.partial(arrangement_dot_conquer, block_size=block_size)
+
+    tensors = (
+        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),
+        Tensor(1, dtype=dtype),
+    )
+
+    return arrangement_, application_dot_conquer, tensors
diff --git a/src/ntops/kernels/histc.py b/src/ntops/kernels/histc.py
new file mode 100644
index 0000000..b246303
--- /dev/null
+++ b/src/ntops/kernels/histc.py
@@ -0,0 +1,133 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+
+def arrangement(*tensors, block_size):
+    input, output, min_val, max_val, num_bins_pow2 = tensors
+
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    # input: (N,), output: (bins,)
+    input_tiled = input.flatten().tile((block_size,))  # (N // block_size,), dtype=(block_size,)
+
+    # Broadcast the output so every input block sees the whole histogram.
+    output_expand = output.unsqueeze(0).expand((input_tiled.shape[0], -1))  # (N // block_size, bins)
+    output_tiled = output_expand.tile((1, -1)).squeeze(1)  # (N // block_size,), dtype=(1, bins)
+    output_tiled.dtype = output_tiled.dtype.squeeze(0)  # dtype=(bins,)
+
+    return input_tiled, output_tiled, min_val, max_val, num_bins_pow2
+
+
+def application_manual_histogram(input, output, min_val, max_val, num_bins_pow2):
+    """Compute the histogram manually.
+
+    The built-in histogram function on Moore Threads GPUs does not produce
+    correct results, so this variant reimplements the counting with
+    ntl.arange and ntl.where.
+    """
+    # input: (block_size,), output: (bins,)
+    n_out_bins = output.shape[0]
+
+    # Only values in [min_val, max_val] are counted.
+    mask = (input >= min_val) & (input <= max_val)
+
+    # Normalize to [0, n_out_bins).
+    input_scaled = (input - min_val) / (max_val - min_val) * n_out_bins
+
+    # The histogram needs integer bin indices.
+    input_indices = ntl.cast(input_scaled, ntl.int32)
+
+    # max_val should fall into the last bin.
+    input_indices = ntl.minimum(input_indices, n_out_bins - 1)
+
+    # Send out-of-range indices to -1 so they are never counted.
+    input_indices = ntl.where(mask, input_indices, -1)
+
+    # Initialize the local histogram.
+    local_hist = ntl.zeros((num_bins_pow2,), dtype=output.dtype)
+
+    # Count bin by bin: for each bin, use where to count matching elements.
+    # Moore Threads GPUs do not support the built-in histogram, so the
+    # counting is done manually.
+    for bin_idx in range(num_bins_pow2):
+        bin_idx_tensor = ntl.cast(bin_idx, ntl.int32)
+        match_mask = input_indices == bin_idx_tensor
+        count = ntl.sum(match_mask.to(output.dtype))
+        idx = ntl.arange(0, num_bins_pow2)
+        update_mask = idx == bin_idx_tensor
+        local_hist = ntl.where(update_mask, count, local_hist)
+
+    # Only the first n_out_bins bins are valid.
+    valid_mask = ntl.arange(0, num_bins_pow2) < n_out_bins
+    local_hist = local_hist.to(output.dtype)
+    ntl.atomic_add(output.data_ptr() + output.offsets(), local_hist, mask=valid_mask)
+
+
+def application_builtin_histogram(input, output, min_val, max_val, num_bins_pow2):
+    # input: (block_size,), output: (bins,)
+    n_out_bins = output.shape[0]
+
+    # Only values in [min_val, max_val] are counted.
+    mask = (input >= min_val) & (input <= max_val)
+
+    # Normalize to [0, n_out_bins).
+    input_scaled = (input - min_val) / (max_val - min_val) * n_out_bins
+
+    # The histogram needs integer bin indices.
+    input_indices = ntl.cast(input_scaled, ntl.int32)
+
+    # max_val should fall into the last bin.
+    input_indices = ntl.minimum(input_indices, n_out_bins - 1)
+
+    # Send out-of-range indices to -1 so they are never counted
+    # (masked histogram was only introduced in Triton 3.5.0).
+    input_indices = ntl.where(mask, input_indices, -1)
+
+    local_hist = ntl.histogram(input_indices, num_bins=num_bins_pow2)  # (num_bins_pow2,)
+
+    # Only the first n_out_bins bins are valid.
+    valid_mask = ntl.arange(0, num_bins_pow2) < n_out_bins
+    local_hist = local_hist.to(output.dtype)
+    ntl.atomic_add(output.data_ptr() + output.offsets(), local_hist, mask=valid_mask)
+
+
+def premake_builtin(dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(1, dtype=dtype, other=float("inf"), shape_options={"constexpr": True}),  # input (inf padding is filtered out by the range mask)
+        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),  # output
+        Tensor(0, dtype=dtype),  # min
+        Tensor(0, dtype=dtype),  # max
+        Tensor(0, dtype=int, constexpr=True),  # num_bins_pow2
+    )
+
+    return arrangement_, application_builtin_histogram, tensors
+
+
+def premake_manual(dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(1, dtype=dtype, other=float("inf"), shape_options={"constexpr": True}),  # input (inf padding is filtered out by the range mask)
+        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),  # output
+        Tensor(0, dtype=dtype),  # min
+        Tensor(0, dtype=dtype),  # max
+        Tensor(0, dtype=int, constexpr=True),  # num_bins_pow2
+    )
+
+    return arrangement_, application_manual_histogram, tensors
diff --git a/src/ntops/kernels/log10.py b/src/ntops/kernels/log10.py
new file mode 100644
index 0000000..8a92d60
--- /dev/null
+++ b/src/ntops/kernels/log10.py
@@ -0,0 +1,21 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, output):
+    # log10(x) = ln(x) * (1 / ln(10)); float16 inputs are computed in
+    # float32 for accuracy.
+    if input.dtype == ntl.float16:
+        output = ntl.log(ntl.cast(input, ntl.float32)) * 0.4342944819032518
+    else:
+        output = ntl.log(input) * 0.4342944819032518
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/log1p.py b/src/ntops/kernels/log1p.py
new file mode 100644
index 0000000..727797b
--- /dev/null
+++ b/src/ntops/kernels/log1p.py
@@ -0,0 +1,18 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, output):
+    # Computed in float32; note that log(1 + x) loses precision for very
+    # small |x| compared with a dedicated log1p primitive.
+    output = ntl.log(ntl.cast(input, ntl.float32) + 1)
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py
index 702877e..6e25657 100644
--- a/src/ntops/torch/__init__.py
+++ b/src/ntops/torch/__init__.py
@@ -36,6 +36,11 @@
 from ntops.torch.softmax import softmax
 from ntops.torch.sub import sub
 from ntops.torch.tanh import tanh
+from ntops.torch.avg_pool3d import avg_pool3d
+from ntops.torch.dot import dot
+from ntops.torch.histc import histc
+from ntops.torch.log10 import log10
+from ntops.torch.log1p import log1p
 
 __all__ = [
     "abs",
@@ -76,4 +81,9 @@
     "softmax",
     "sub",
     "tanh",
+    "avg_pool3d",
+    "dot",
+    "histc",
+    "log10",
+    "log1p",
 ]
diff --git a/src/ntops/torch/avg_pool3d.py b/src/ntops/torch/avg_pool3d.py
new file mode 100644
index 0000000..ec10f77
--- /dev/null
+++ b/src/ntops/torch/avg_pool3d.py
@@ -0,0 +1,54 @@
+import math
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def avg_pool3d(input, kernel_size: int | tuple[int, int, int], stride: None | int | tuple[int, int, int] = None, ceil_mode=False):
+    assert input.ndim == 5 or input.ndim == 4, "Input tensor must be 5-dimensional (N, C, D_in, H_in, W_in) or 4-dimensional (C, D_in, H_in, W_in)"
+
+    if input.ndim == 4:
+        input = input.unsqueeze(0)  # add a batch dimension
+
+    if isinstance(kernel_size, int):
+        kernel_size = (kernel_size, kernel_size, kernel_size)
+
+    if isinstance(stride, int):
+        stride = (stride, stride, stride)
+
+    if stride is None:
+        stride = kernel_size
+
+    # Compute the output dimensions.
+    N, C, D_in, H_in, W_in = input.shape
+    if ceil_mode:
+        D_out = math.ceil((D_in - kernel_size[0]) / stride[0] + 1)
+        H_out = math.ceil((H_in - kernel_size[1]) / stride[1] + 1)
+        W_out = math.ceil((W_in - kernel_size[2]) / stride[2] + 1)
+    else:
+        D_out = math.floor((D_in - kernel_size[0]) / stride[0] + 1)
+        H_out = math.floor((H_in - kernel_size[1]) / stride[1] + 1)
+        W_out = math.floor((W_in - kernel_size[2]) / stride[2] + 1)
+
+    output_shape = (N, C, D_out, H_out, W_out)
+
+    output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
+
+    block_size = 256
+    kernel = _cached_make(
+        ntops.kernels.avg_pool3d.premake,
+        input.ndim,
+        kernel_size[0],
+        kernel_size[1],
+        kernel_size[2],
+        stride[0],
+        stride[1],
+        stride[2],
+        block_size=block_size,
+        ceil_mode=ceil_mode,
+        dtype=input.dtype
+    )
+    kernel(input, output, kernel_size[0] * kernel_size[1] * kernel_size[2])
+
+    return output
diff --git a/src/ntops/torch/dot.py b/src/ntops/torch/dot.py
new file mode 100644
index 0000000..ab7fbcb
--- /dev/null
+++ b/src/ntops/torch/dot.py
@@ -0,0 +1,33 @@
+import math
+
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def dot(input, other, *, out=None):
+    assert input.ndim == 1 and other.ndim == 1
+
+    if out is None:
+        out = torch.empty((1,), dtype=input.dtype, device=input.device)
+
+    input_numel = input.numel()
+    if input_numel <= 4096:
+        # Small inputs: reduce in a single block whose size is the next
+        # power of two.
+        block_size = 1 << (input_numel - 1).bit_length()
+        kernel = _cached_make(ntops.kernels.dot.premake_dot_full, dtype=input.dtype, block_size=block_size)
+        kernel(input, other, out)
+        out = out.view(())
+    else:
+        # Large inputs: two-stage reduction with a block size near sqrt(N).
+        # Note: this assumes input_numel is divisible by block_size;
+        # otherwise the tail elements would be dropped.
+        sqrt_n = math.isqrt(input_numel)
+        block_size = 1 << (sqrt_n - 1).bit_length()
+        temp_out = torch.empty((input_numel // block_size,), dtype=input.dtype, device=input.device)
+
+        kernel1 = _cached_make(ntops.kernels.dot.premake_dot_divide, dtype=input.dtype, block_size=block_size)
+        kernel1(input, other, temp_out)
+
+        kernel2 = _cached_make(ntops.kernels.dot.premake_dot_conquer, dtype=input.dtype, block_size=block_size)
+        kernel2(temp_out, out)
+
+        out = out.view(())
+
+    return out
diff --git a/src/ntops/torch/histc.py b/src/ntops/torch/histc.py
new file mode 100644
index 0000000..f259872
--- /dev/null
+++ b/src/ntops/torch/histc.py
@@ -0,0 +1,31 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def histc(input, bins=100, min=None, max=None, is_moore=False):
+    if min is None:
+        min = torch.min(input).item()
+
+    if max is None:
+        max = torch.max(input).item()
+
+    block_size = 256
+
+    # Smallest power of two >= bins; the kernel-side histogram length must
+    # be a power of two.
+    num_bins_pow2 = 1 << (bins - 1).bit_length()
+
+    # Initialize the output to zero because the kernel accumulates the
+    # histogram counts with atomic adds.
+    out = torch.zeros((bins,), dtype=input.dtype, device=input.device)
+
+    if is_moore:
+        kernel = _cached_make(ntops.kernels.histc.premake_manual, input.dtype, block_size=block_size)
+    else:
+        kernel = _cached_make(ntops.kernels.histc.premake_builtin, input.dtype, block_size=block_size)
+
+    kernel(input, out, min, max, num_bins_pow2)
+
+    return out
diff --git a/src/ntops/torch/log10.py b/src/ntops/torch/log10.py
new file mode 100644
index 0000000..f08fdd8
--- /dev/null
+++ b/src/ntops/torch/log10.py
@@ -0,0 +1,16 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def log10(input, *, out=None):
+    if out is None:
+        out = torch.empty_like(input)
+
+    block_size = 1024
+    kernel = _cached_make(ntops.kernels.log10.premake, input.ndim, block_size=block_size)
+
+    kernel(input, out)
+
+    return out
diff --git a/src/ntops/torch/log1p.py b/src/ntops/torch/log1p.py
new file mode 100644
index 0000000..2ceacf3
--- /dev/null
+++ b/src/ntops/torch/log1p.py
@@ -0,0 +1,16 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def log1p(input, *, out=None):
+    if out is None:
+        out = torch.empty_like(input)
+
+    block_size = 1024
+    kernel = _cached_make(ntops.kernels.log1p.premake, input.ndim, block_size=block_size)
+
+    kernel(input, out)
+
+    return out

From f0298030712a5407325a020d59cb9eea2bc76fc9 Mon Sep 17 00:00:00 2001
From: greenhandhand <781740145@qq.com>
Date: Sun, 14 Dec 2025 16:20:40 +0800
Subject: [PATCH 2/2] Add pytest tests; fix the ceil_mode issue in avg_pool3d
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/ntops/torch/avg_pool3d.py | 52 ++++++++++++++++++++++++++++++-----
 tests/test_avg_pool3d.py      | 36 ++++++++++++++++++++++++
 tests/test_dot.py             | 25 +++++++++++++++++
 tests/test_histc.py           | 22 +++++++++++++++
 tests/test_log10.py           | 17 ++++++++++++
 tests/test_log1p.py           | 17 ++++++++++++
 6 files changed, 162 insertions(+), 7 deletions(-)
 create mode 100644 tests/test_avg_pool3d.py
 create mode 100644 tests/test_dot.py
 create mode 100644 tests/test_histc.py
 create mode 100644 tests/test_log10.py
 create mode 100644 tests/test_log1p.py

diff --git a/src/ntops/torch/avg_pool3d.py b/src/ntops/torch/avg_pool3d.py
index ec10f77..1b06140 100644
--- a/src/ntops/torch/avg_pool3d.py
+++ b/src/ntops/torch/avg_pool3d.py
@@ -4,6 +4,36 @@
 import ntops
 from ntops.torch.utils import _cached_make
 
+
+def _effective_counts(input_shape, kernel_size, stride, device):
+    """Compute the number of valid elements per output position.
+
+    With ceil_mode=True the last window in each dimension may extend past
+    the input and cover fewer than kernel_size elements. We precompute the
+    effective element count so we can rescale the zero-padded average
+    produced by the kernel back to `count_include_pad=False` semantics to
+    match PyTorch.
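+
+    Worked example (hypothetical sizes): D_in = 5, kd = 3, sd = 2 with
+    ceil_mode=True gives D_out = ceil((5 - 3) / 2 + 1) = 2, and the second
+    window (starting at index 2) still covers 3 elements, so
+    kd_eff = min(5 - 1 * 2, 3) = 3; with D_in = 4 instead, the second
+    window is clipped to indices 2..3, so kd_eff = min(4 - 2, 3) = 2.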
+ """ + + N, C, D_in, H_in, W_in = input_shape + kd, kh, kw = kernel_size + sd, sh, sw = stride + + D_out = math.ceil((D_in - kd) / sd + 1) + H_out = math.ceil((H_in - kh) / sh + 1) + W_out = math.ceil((W_in - kw) / sw + 1) + + d_range = torch.arange(D_out, device="cuda") + h_range = torch.arange(H_out, device="cuda") + w_range = torch.arange(W_out, device="cuda") + + kd_eff = torch.clamp(D_in - d_range * sd, max=kd).clamp_min(0) + kh_eff = torch.clamp(H_in - h_range * sh, max=kh).clamp_min(0) + kw_eff = torch.clamp(W_in - w_range * sw, max=kw).clamp_min(0) + + counts = kd_eff[:, None, None] * kh_eff[None, :, None] * kw_eff[None, None, :] + counts = counts.view(1, 1, D_out, H_out, W_out).expand(N, C, -1, -1, -1) + + return counts + def avg_pool3d(input, kernel_size: int | tuple[int, int, int], stride: None | int | tuple[int, int, int] = None, ceil_mode=False): assert input.ndim == 5 or input.ndim == 4, "Input tensor must be 4-dimensional (N, C, D_in, H_in, W_in) or 3-dimensional (C, D_in, H_in, W_in)" @@ -30,16 +60,16 @@ def avg_pool3d(input, kernel_size: int | tuple[int, int, int], stride: None | in D_out = math.floor((D_in - kernel_size[0]) / stride[0] + 1) H_out = math.floor((H_in - kernel_size[1]) / stride[1] + 1) W_out = math.floor((W_in - kernel_size[2]) / stride[2] + 1) - + output_shape = (N, C, D_out, H_out, W_out) output = torch.empty(output_shape, dtype=input.dtype, device=input.device) - block_size = 256 + block_size = 1024 kernel = _cached_make( - ntops.kernels.avg_pool3d.premake, - input.ndim, - kernel_size[0], + ntops.kernels.avg_pool3d.premake, + input.ndim, + kernel_size[0], kernel_size[1], kernel_size[2], stride[0], @@ -49,6 +79,14 @@ def avg_pool3d(input, kernel_size: int | tuple[int, int, int], stride: None | in ceil_mode=ceil_mode, dtype=input.dtype ) - kernel(input, output, kernel_size[0] * kernel_size[1] * kernel_size[2]) - + kernel_volume = kernel_size[0] * kernel_size[1] * kernel_size[2] + kernel(input, output, kernel_volume) + + if ceil_mode: + counts = _effective_counts((N, C, D_in, H_in, W_in), kernel_size, stride) + counts = counts.to(dtype=output.dtype, device=output.device) + torch.mul(output, kernel_volume, out=output) + torch.div(output, counts, out=output) + # output.mul_(kernel_volume).div_(counts) + return output diff --git a/tests/test_avg_pool3d.py b/tests/test_avg_pool3d.py new file mode 100644 index 0000000..c2dff05 --- /dev/null +++ b/tests/test_avg_pool3d.py @@ -0,0 +1,36 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +@pytest.mark.parametrize( + "input_shape, kernel_size, stride, ceil_mode", + [ + ((3, 4, 5, 6), 2, 2, False), + ((2, 3, 5, 6, 7), (3, 2, 2), None, True), + ], +) +@pytest.mark.parametrize( + "dtype, rtol, atol", + [ + (torch.float16, 0.01, 0.01), + (torch.float32, 0.001, 0.001), + ], +) +def test_avg_pool3d(input_shape, kernel_size, stride, ceil_mode, dtype, rtol, atol): + device = "cuda" + input = torch.randn(input_shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.avg_pool3d( + input, kernel_size, stride=stride, ceil_mode=ceil_mode + ) + + reference_input = input if input.ndim == 5 else input.unsqueeze(0) + reference_output = torch.nn.functional.avg_pool3d( + reference_input, kernel_size, stride=stride, ceil_mode=ceil_mode + ) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_dot.py b/tests/test_dot.py new file mode 100644 index 0000000..dbbf883 --- /dev/null +++ 
@@ -0,0 +1,25 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("length", [1024, 8192])
+@pytest.mark.parametrize(
+    "dtype, rtol, atol",
+    [
+        (torch.float16, 0.01, 0.01),
+        (torch.float32, 0.001, 0.001),
+    ],
+)
+def test_dot(length, dtype, rtol, atol):
+    device = "cuda"
+    input = torch.randn((length,), device=device, dtype=dtype)
+    other = torch.randn((length,), device=device, dtype=dtype)
+
+    ninetoothed_output = ntops.torch.dot(input, other)
+    reference_output = torch.dot(input, other)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_histc.py b/tests/test_histc.py
new file mode 100644
index 0000000..9bd231a
--- /dev/null
+++ b/tests/test_histc.py
@@ -0,0 +1,22 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("bins", [10, 50])
+@pytest.mark.parametrize("is_moore", [False, True])
+def test_histc(bins, is_moore):
+    dtype = torch.float32
+    device = "cuda"
+    input = torch.randn((2048,), dtype=dtype, device=device) * 5 - 2
+    min_val, max_val = -5.0, 5.0
+
+    ninetoothed_output = ntops.torch.histc(
+        input, bins=bins, min=min_val, max=max_val, is_moore=is_moore
+    )
+    reference_output = torch.histc(input, bins=bins, min=min_val, max=max_val)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=0, atol=1e-3)
diff --git a/tests/test_log10.py b/tests/test_log10.py
new file mode 100644
index 0000000..2ca45bd
--- /dev/null
+++ b/tests/test_log10.py
@@ -0,0 +1,17 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_log10(shape, dtype, device, rtol, atol):
+    # Shift inputs into (0.1, 1.1) so the logarithm is well defined.
+    input = torch.rand(shape, dtype=dtype, device=device) + 0.1
+
+    ninetoothed_output = ntops.torch.log10(input)
+    reference_output = torch.log10(input)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
diff --git a/tests/test_log1p.py b/tests/test_log1p.py
new file mode 100644
index 0000000..85051d9
--- /dev/null
+++ b/tests/test_log1p.py
@@ -0,0 +1,17 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_log1p(shape, dtype, device, rtol, atol):
+    # Inputs in (-0.5, 0.5), safely inside log1p's domain (x > -1).
+    input = torch.rand(shape, dtype=dtype, device=device) - 0.5
+
+    ninetoothed_output = ntops.torch.log1p(input)
+    reference_output = torch.log1p(input)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)