diff --git a/.github/scripts/setup-env.sh b/.github/scripts/setup-env.sh
index f4cebc5..72628b9 100755
--- a/.github/scripts/setup-env.sh
+++ b/.github/scripts/setup-env.sh
@@ -93,6 +93,7 @@ echo '::group::Install third party dependencies prior to extension-cpp install'
 # - It happily pulls in pre-releases, which can lead to more problems down the line.
 #   `pip` does not unless explicitly told to do so.
 # Thus, we use `easy_install` to extract the third-party dependencies here and install them upfront with `pip`.
+pushd extension_cpp
 python setup.py egg_info
 # The requires.txt cannot be used with `pip install -r` directly. The requirements are listed at the top and the
 # optional dependencies come in non-standard syntax after a blank line. Thus, we just extract the header.
@@ -100,8 +101,15 @@ sed -e '/^$/,$d' *.egg-info/requires.txt | tee requirements.txt
 pip install --progress-bar=off -r requirements.txt
 echo '::endgroup::'
 
-echo '::group::Install extension-cpp'
-python setup.py develop
+echo '::group::Install extension_cpp (standard ATen API)'
+pip install -e . --no-build-isolation
+popd
+echo '::endgroup::'
+
+echo '::group::Install extension_cpp_stable (stable ABI)'
+pushd extension_cpp_stable
+pip install -e . --no-build-isolation
+popd
 echo '::endgroup::'
 
 echo '::group::Collect environment information'
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 1565937..5372bf9 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -19,7 +19,7 @@ jobs:
         - python-version: 3.13
           runner: linux.g5.4xlarge.nvidia.gpu
           gpu-arch-type: cuda
-          gpu-arch-version: "12.4"
+          gpu-arch-version: "12.9"
       fail-fast: false
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     permissions:
diff --git a/README.md b/README.md
index d523814..1aac5d5 100644
--- a/README.md
+++ b/README.md
@@ -1,18 +1,42 @@
 # C++/CUDA Extensions in PyTorch
 
-An example of writing a C++/CUDA extension for PyTorch. See
-[here](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for the accompanying tutorial.
-This repo demonstrates how to write an example `extension_cpp.ops.mymuladd`
-custom op that has both custom CPU and CUDA kernels.
+This repository contains two example C++/CUDA extensions for PyTorch:
 
-The examples in this repo work with PyTorch 2.4+.
+1. **extension_cpp** - Uses the standard ATen/LibTorch API
+2. **extension_cpp_stable** - Uses the [LibTorch Stable ABI](https://pytorch.org/docs/main/notes/libtorch_stable_abi.html)
 
-To build:
+Both extensions demonstrate how to write an example `mymuladd` custom op that has both
+custom CPU and CUDA kernels.
+
+## extension_cpp (Standard ATen API)
+
+Uses the full ATen/LibTorch API. This is the traditional way of writing PyTorch extensions.
+See [this tutorial](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html) for more details.
+
+## extension_cpp_stable (Stable ABI)
+
+Uses the LibTorch Stable ABI to ensure that the extension built can be run with any version
+of PyTorch >= 2.10.0, without needing to recompile for each PyTorch version.
+
+The `extension_cpp_stable` examples require PyTorch 2.10+.
+
+## Building
+
+To build extension_cpp (standard API):
+```
+cd extension_cpp
+pip install --no-build-isolation -e .
 ```
+
+To build extension_cpp_stable (stable ABI):
+```
+cd extension_cpp_stable
 pip install --no-build-isolation -e .
 ```
 
-To test:
+## Testing
+
+To test both extensions:
 ```
 python test/test_extension.py
 ```
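After either install, both packages expose the custom op under their `ops` module. A minimal usage sketch (tensor shapes and values here are purely illustrative):

```
import torch
import extension_cpp  # or: import extension_cpp_stable

a = torch.randn(3, 4)
b = torch.randn(3, 4)

# mymuladd computes a * b + c in a single fused custom kernel; it dispatches
# to the CPU or CUDA implementation depending on where the inputs live.
out = extension_cpp.ops.mymuladd(a, b, 1.0)
torch.testing.assert_close(out, a * b + 1.0)
```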
diff --git a/extension_cpp/__init__.py b/extension_cpp/extension_cpp/__init__.py
similarity index 100%
rename from extension_cpp/__init__.py
rename to extension_cpp/extension_cpp/__init__.py
diff --git a/extension_cpp/csrc/cuda/muladd.cu b/extension_cpp/extension_cpp/csrc/cuda/muladd.cu
similarity index 100%
rename from extension_cpp/csrc/cuda/muladd.cu
rename to extension_cpp/extension_cpp/csrc/cuda/muladd.cu
diff --git a/extension_cpp/csrc/muladd.cpp b/extension_cpp/extension_cpp/csrc/muladd.cpp
similarity index 100%
rename from extension_cpp/csrc/muladd.cpp
rename to extension_cpp/extension_cpp/csrc/muladd.cpp
diff --git a/extension_cpp/ops.py b/extension_cpp/extension_cpp/ops.py
similarity index 100%
rename from extension_cpp/ops.py
rename to extension_cpp/extension_cpp/ops.py
diff --git a/pyproject.toml b/extension_cpp/pyproject.toml
similarity index 100%
rename from pyproject.toml
rename to extension_cpp/pyproject.toml
diff --git a/requirements.txt b/extension_cpp/requirements.txt
similarity index 100%
rename from requirements.txt
rename to extension_cpp/requirements.txt
diff --git a/setup.py b/extension_cpp/setup.py
similarity index 95%
rename from setup.py
rename to extension_cpp/setup.py
index 0dde1e4..33a2a99 100644
--- a/setup.py
+++ b/extension_cpp/setup.py
@@ -79,7 +79,9 @@ def get_extensions():
     ext_modules=get_extensions(),
     install_requires=["torch"],
     description="Example of PyTorch C++ and CUDA extensions",
-    long_description=open("README.md").read(),
+    long_description=open(
+        os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "README.md")
+    ).read(),
     long_description_content_type="text/markdown",
     url="https://github.com/pytorch/extension-cpp",
     cmdclass={"build_ext": BuildExtension},
diff --git a/extension_cpp_stable/extension_cpp_stable/__init__.py b/extension_cpp_stable/extension_cpp_stable/__init__.py
new file mode 100644
index 0000000..fb8aad8
--- /dev/null
+++ b/extension_cpp_stable/extension_cpp_stable/__init__.py
@@ -0,0 +1 @@
+from . import _C, ops  # noqa: F401
diff --git a/extension_cpp_stable/extension_cpp_stable/csrc/cuda/muladd.cu b/extension_cpp_stable/extension_cpp_stable/csrc/cuda/muladd.cu
new file mode 100644
index 0000000..5b2a961
--- /dev/null
+++ b/extension_cpp_stable/extension_cpp_stable/csrc/cuda/muladd.cu
@@ -0,0 +1,143 @@
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+
+namespace extension_cpp_stable {
+
+__global__ void muladd_kernel(int numel, const float *a, const float *b,
+                              float c, float *result) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < numel)
+    result[idx] = a[idx] * b[idx] + c;
+}
+
+torch::stable::Tensor mymuladd_cuda(const torch::stable::Tensor &a,
+                                    const torch::stable::Tensor &b, double c) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float *a_ptr = a_contig.const_data_ptr<float>();
+  const float *b_ptr = b_contig.const_data_ptr<float>();
+  float *result_ptr = result.mutable_data_ptr<float>();
+
+  int numel = a_contig.numel();
+
+  // For now, we rely on the raw shim API to get the current CUDA stream.
+  // This will be improved in a future release.
+  // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
+  // check the error code and throw an appropriate runtime_error otherwise.
+  void *stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
+  muladd_kernel<<<(numel + 255) / 256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c,
+                                                         result_ptr);
+  return result;
+}
+
+__global__ void mul_kernel(int numel, const float *a, const float *b,
+                           float *result) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < numel)
+    result[idx] = a[idx] * b[idx];
+}
+
+torch::stable::Tensor mymul_cuda(const torch::stable::Tensor &a,
+                                 const torch::stable::Tensor &b) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float *a_ptr = a_contig.const_data_ptr<float>();
+  const float *b_ptr = b_contig.const_data_ptr<float>();
+  float *result_ptr = result.mutable_data_ptr<float>();
+
+  int numel = a_contig.numel();
+
+  // For now, we rely on the raw shim API to get the current CUDA stream.
+  // This will be improved in a future release.
+  // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
+  // check the error code and throw an appropriate runtime_error otherwise.
+  void *stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
+  mul_kernel<<<(numel + 255) / 256, 256, 0, stream>>>(numel, a_ptr, b_ptr,
+                                                      result_ptr);
+  return result;
+}
+
+__global__ void add_kernel(int numel, const float *a, const float *b,
+                           float *result) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < numel)
+    result[idx] = a[idx] + b[idx];
+}
+
+// An example of an operator that mutates one of its inputs.
+void myadd_out_cuda(const torch::stable::Tensor &a,
+                    const torch::stable::Tensor &b,
+                    torch::stable::Tensor &out) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.is_contiguous());
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CUDA);
+  STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CUDA);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+
+  const float *a_ptr = a_contig.const_data_ptr<float>();
+  const float *b_ptr = b_contig.const_data_ptr<float>();
+  float *result_ptr = out.mutable_data_ptr<float>();
+
+  int numel = a_contig.numel();
+
+  // For now, we rely on the raw shim API to get the current CUDA stream.
+  // This will be improved in a future release.
+  // When using a raw shim API, we need to use TORCH_ERROR_CODE_CHECK to
+  // check the error code and throw an appropriate runtime_error otherwise.
+  void *stream_ptr = nullptr;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_current_cuda_stream(a.get_device_index(), &stream_ptr));
+  cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
+
+  add_kernel<<<(numel + 255) / 256, 256, 0, stream>>>(numel, a_ptr, b_ptr,
+                                                      result_ptr);
+}
+
+// Registers CUDA implementations for mymuladd, mymul, myadd_out
+STABLE_TORCH_LIBRARY_IMPL(extension_cpp_stable, CUDA, m) {
+  m.impl("mymuladd", TORCH_BOX(&mymuladd_cuda));
+  m.impl("mymul", TORCH_BOX(&mymul_cuda));
+  m.impl("myadd_out", TORCH_BOX(&myadd_out_cuda));
+}
+
+} // namespace extension_cpp_stable
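These kernels are enqueued on whatever CUDA stream is current when the op is called, which is exactly what the `aoti_torch_get_current_cuda_stream` shim call above retrieves. A rough sketch of exercising that from Python, assuming the stable extension is built and installed with CUDA support:

```
import torch
import extension_cpp_stable  # assumes the extension above was built with USE_CUDA=1

a = torch.rand(1024, device="cuda")
b = torch.rand(1024, device="cuda")

# The custom kernel launches on the stream that is current at call time,
# so it stays ordered with other work on a user-managed side stream.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())  # order after the allocations above
with torch.cuda.stream(side):
    out = extension_cpp_stable.ops.mymuladd(a, b, 2.0)
torch.cuda.current_stream().wait_stream(side)
torch.testing.assert_close(out, a * b + 2.0)
```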
diff --git a/extension_cpp_stable/extension_cpp_stable/csrc/muladd.cpp b/extension_cpp_stable/extension_cpp_stable/csrc/muladd.cpp
new file mode 100644
index 0000000..c843c58
--- /dev/null
+++ b/extension_cpp_stable/extension_cpp_stable/csrc/muladd.cpp
@@ -0,0 +1,118 @@
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+extern "C" {
+  /* Creates a dummy empty _C module that can be imported from Python.
+     The import from Python will load the .so consisting of this file
+     in this extension, so that the STABLE_TORCH_LIBRARY static initializers
+     below are run. */
+  PyObject* PyInit__C(void)
+  {
+      static struct PyModuleDef module_def = {
+          PyModuleDef_HEAD_INIT,
+          "_C",   /* name of module */
+          NULL,   /* module documentation, may be NULL */
+          -1,     /* size of per-interpreter state of the module,
+                     or -1 if the module keeps state in global variables. */
+          NULL,   /* methods */
+      };
+      return PyModule_Create(&module_def);
+  }
+}
+
+namespace extension_cpp_stable {
+
+torch::stable::Tensor mymuladd_cpu(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    double c) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
+
+  for (int64_t i = 0; i < result.numel(); i++) {
+    result_ptr[i] = a_ptr[i] * b_ptr[i] + c;
+  }
+  return result;
+}
+
+torch::stable::Tensor mymul_cpu(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+  torch::stable::Tensor result = torch::stable::empty_like(a_contig);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = result.mutable_data_ptr<float>();
+
+  for (int64_t i = 0; i < result.numel(); i++) {
+    result_ptr[i] = a_ptr[i] * b_ptr[i];
+  }
+  return result;
+}
+
+// An example of an operator that mutates one of its inputs.
+void myadd_out_cpu(
+    const torch::stable::Tensor& a,
+    const torch::stable::Tensor& b,
+    torch::stable::Tensor& out) {
+  STD_TORCH_CHECK(a.sizes().equals(b.sizes()));
+  STD_TORCH_CHECK(b.sizes().equals(out.sizes()));
+  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Float);
+  STD_TORCH_CHECK(out.is_contiguous());
+  STD_TORCH_CHECK(a.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(b.device().type() == torch::headeronly::DeviceType::CPU);
+  STD_TORCH_CHECK(out.device().type() == torch::headeronly::DeviceType::CPU);
+
+  torch::stable::Tensor a_contig = torch::stable::contiguous(a);
+  torch::stable::Tensor b_contig = torch::stable::contiguous(b);
+
+  const float* a_ptr = a_contig.const_data_ptr<float>();
+  const float* b_ptr = b_contig.const_data_ptr<float>();
+  float* result_ptr = out.mutable_data_ptr<float>();
+
+  for (int64_t i = 0; i < out.numel(); i++) {
+    result_ptr[i] = a_ptr[i] + b_ptr[i];
+  }
+}
+
+// Defines the operators
+STABLE_TORCH_LIBRARY(extension_cpp_stable, m) {
+  m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
+  m.def("mymul(Tensor a, Tensor b) -> Tensor");
+  m.def("myadd_out(Tensor a, Tensor b, Tensor(a!) out) -> ()");
+}
+
+// Registers CPU implementations for mymuladd, mymul, myadd_out
+STABLE_TORCH_LIBRARY_IMPL(extension_cpp_stable, CPU, m) {
+  m.impl("mymuladd", TORCH_BOX(&mymuladd_cpu));
+  m.impl("mymul", TORCH_BOX(&mymul_cpu));
+  m.impl("myadd_out", TORCH_BOX(&myadd_out_cpu));
+}
+
+}
diff --git a/extension_cpp_stable/extension_cpp_stable/ops.py b/extension_cpp_stable/extension_cpp_stable/ops.py
new file mode 100644
index 0000000..3280fc4
--- /dev/null
+++ b/extension_cpp_stable/extension_cpp_stable/ops.py
@@ -0,0 +1,63 @@
+import torch
+from torch import Tensor
+
+__all__ = ["mymuladd", "myadd_out"]
+
+
+def mymuladd(a: Tensor, b: Tensor, c: float) -> Tensor:
+    """Performs a * b + c in an efficient fused kernel"""
+    return torch.ops.extension_cpp_stable.mymuladd.default(a, b, c)
+
+
+# Registers a FakeTensor kernel (aka "meta kernel", "abstract impl")
+# that describes what the properties of the output Tensor are given
+# the properties of the input Tensor. The FakeTensor kernel is necessary
+# for the op to work performantly with torch.compile.
+@torch.library.register_fake("extension_cpp_stable::mymuladd")
+def _(a, b, c):
+    torch._check(a.shape == b.shape)
+    torch._check(a.dtype == torch.float)
+    torch._check(b.dtype == torch.float)
+    torch._check(a.device == b.device)
+    return torch.empty_like(a)
+
+
+def _backward(ctx, grad):
+    a, b = ctx.saved_tensors
+    grad_a, grad_b = None, None
+    if ctx.needs_input_grad[0]:
+        grad_a = torch.ops.extension_cpp_stable.mymul.default(grad, b)
+    if ctx.needs_input_grad[1]:
+        grad_b = torch.ops.extension_cpp_stable.mymul.default(grad, a)
+    return grad_a, grad_b, None
+
+
+def _setup_context(ctx, inputs, output):
+    a, b, c = inputs
+    saved_a, saved_b = None, None
+    if ctx.needs_input_grad[0]:
+        saved_b = b
+    if ctx.needs_input_grad[1]:
+        saved_a = a
+    ctx.save_for_backward(saved_a, saved_b)
+
+
+# This adds training support for the operator. You must provide us
+# the backward formula for the operator and a `setup_context` function
+# to save values to be used in the backward.
+torch.library.register_autograd(
+    "extension_cpp_stable::mymuladd", _backward, setup_context=_setup_context)
+
+
+@torch.library.register_fake("extension_cpp_stable::mymul")
+def _(a, b):
+    torch._check(a.shape == b.shape)
+    torch._check(a.dtype == torch.float)
+    torch._check(b.dtype == torch.float)
+    torch._check(a.device == b.device)
+    return torch.empty_like(a)
+
+
+def myadd_out(a: Tensor, b: Tensor, out: Tensor) -> None:
+    """Writes a + b into out"""
+    torch.ops.extension_cpp_stable.myadd_out.default(a, b, out)
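The fake-kernel and autograd registrations in `ops.py` are what make the op traceable with `torch.compile` and differentiable. A rough usage sketch (assumes the package above is installed; shapes and values are illustrative):

```
import torch
import extension_cpp_stable

a = torch.randn(8, requires_grad=True)
b = torch.randn(8, requires_grad=True)

# The registered FakeTensor kernel lets torch.compile trace mymuladd.
compiled_muladd = torch.compile(extension_cpp_stable.ops.mymuladd, fullgraph=True)
out = compiled_muladd(a, b, 0.5)

# The registered backward formula provides gradients for both tensor inputs.
out.sum().backward()
print(a.grad.shape, b.grad.shape)
```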
diff --git a/extension_cpp_stable/pyproject.toml b/extension_cpp_stable/pyproject.toml
new file mode 100644
index 0000000..ffef670
--- /dev/null
+++ b/extension_cpp_stable/pyproject.toml
@@ -0,0 +1,6 @@
+[build-system]
+requires = [
+    "setuptools",
+    "torch>=2.10.0",
+]
+build-backend = "setuptools.build_meta"
diff --git a/extension_cpp_stable/requirements.txt b/extension_cpp_stable/requirements.txt
new file mode 100644
index 0000000..af3149e
--- /dev/null
+++ b/extension_cpp_stable/requirements.txt
@@ -0,0 +1,2 @@
+torch
+numpy
diff --git a/extension_cpp_stable/setup.py b/extension_cpp_stable/setup.py
new file mode 100644
index 0000000..ca19488
--- /dev/null
+++ b/extension_cpp_stable/setup.py
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import glob
+
+from setuptools import find_packages, setup
+
+from torch.utils.cpp_extension import (
+    CppExtension,
+    CUDAExtension,
+    BuildExtension,
+    CUDA_HOME,
+)
+
+library_name = "extension_cpp_stable"
+
+
+if torch.__version__ >= "2.6.0":
+    py_limited_api = True
+else:
+    py_limited_api = False
+
+
+def get_extensions():
+    debug_mode = os.getenv("DEBUG", "0") == "1"
+    use_cuda = os.getenv("USE_CUDA", "1") == "1"
+    if debug_mode:
+        print("Compiling in debug mode")
+
+    use_cuda = use_cuda and torch.cuda.is_available() and CUDA_HOME is not None
+    extension = CUDAExtension if use_cuda else CppExtension
+
+    extra_link_args = []
+    extra_compile_args = {
+        "cxx": [
+            "-O3" if not debug_mode else "-O0",
+            "-fdiagnostics-color=always",
+            "-DPy_LIMITED_API=0x03090000",
+            # define TORCH_TARGET_VERSION with min version 2.10 to expose only the
+            # stable API subset from torch
+            # Format: [MAJ 1 byte][MIN 1 byte][PATCH 1 byte][ABI TAG 5 bytes]
+            # 2.10.0 = 0x020A000000000000
+            "-DTORCH_TARGET_VERSION=0x020a000000000000",
+        ],
+        "nvcc": [
+            "-O3" if not debug_mode else "-O0",
+            # NVCC also needs TORCH_TARGET_VERSION for stable ABI in CUDA code
+            "-DTORCH_TARGET_VERSION=0x020a000000000000",
+            # USE_CUDA is currently needed for aoti_torch_get_current_cuda_stream
+            # declaration in shim.h. This will be improved in a future release.
+            "-DUSE_CUDA",
+        ],
+    }
+    if debug_mode:
+        extra_compile_args["cxx"].append("-g")
+        extra_compile_args["nvcc"].append("-g")
+        extra_link_args.extend(["-O0", "-g"])
+
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    extensions_dir = os.path.join(this_dir, library_name, "csrc")
+    sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp")))
+
+    extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
+    cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu")))
+
+    if use_cuda:
+        sources += cuda_sources
+
+    ext_modules = [
+        extension(
+            f"{library_name}._C",
+            sources,
+            extra_compile_args=extra_compile_args,
+            extra_link_args=extra_link_args,
+            py_limited_api=py_limited_api,
+        )
+    ]
+
+    return ext_modules
+
+
+setup(
+    name=library_name,
+    version="0.0.1",
+    packages=find_packages(),
+    ext_modules=get_extensions(),
+    install_requires=["torch>=2.10.0"],
+    description="Example of PyTorch C++ and CUDA extensions using Stable ABI",
+    long_description=open(
+        os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "README.md")
+    ).read(),
+    long_description_content_type="text/markdown",
+    url="https://github.com/pytorch/extension-cpp",
+    cmdclass={"build_ext": BuildExtension},
+    options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
+)
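The `TORCH_TARGET_VERSION` value passed to both compilers encodes the minimum supported PyTorch version using the byte layout described in the comments above. A quick arithmetic sanity check (the helper name is illustrative and not part of the build):

```
# Pack (major, minor, patch) into one byte each, followed by five ABI-tag bytes (zero here).
def pack_torch_target_version(major: int, minor: int, patch: int = 0) -> int:
    return (major << 56) | (minor << 48) | (patch << 40)

assert pack_torch_target_version(2, 10) == 0x020A000000000000  # matches the flag above
```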
diff --git a/test/test_extension.py b/test/test_extension.py
index f17d7da..96cfdd5 100644
--- a/test/test_extension.py
+++ b/test/test_extension.py
@@ -1,8 +1,13 @@
 import torch
-from torch.testing._internal.common_utils import TestCase
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+    parametrize,
+    instantiate_parametrized_tests,
+)
 from torch.testing._internal.optests import opcheck
 import unittest
-import extension_cpp
+
 from torch import Tensor
 from typing import Tuple
 import torch.nn.functional as F
@@ -13,6 +18,15 @@ def reference_muladd(a, b, c):
     return a * b + c
 
 
+def get_extension(ext_name):
+    if ext_name == "extension_cpp":
+        import extension_cpp
+        return extension_cpp
+    else:
+        import extension_cpp_stable
+        return extension_cpp_stable
+
+
 class TestMyMulAdd(TestCase):
     def sample_inputs(self, device, *, requires_grad=False):
         def make_tensor(*size):
@@ -28,25 +42,27 @@ def make_nondiff_tensor(*size):
             [make_nondiff_tensor(2, 3), make_tensor(2, 3), -0.3],
         ]
 
-    def _test_correctness(self, device):
+    def _test_correctness(self, device, ext):
         samples = self.sample_inputs(device)
         for args in samples:
-            result = extension_cpp.ops.mymuladd(*args)
+            result = ext.ops.mymuladd(*args)
             expected = reference_muladd(*args)
             torch.testing.assert_close(result, expected)
 
-    def test_correctness_cpu(self):
-        self._test_correctness("cpu")
+    @parametrize("ext_name", ["extension_cpp", "extension_cpp_stable"])
+    def test_correctness_cpu(self, ext_name):
+        self._test_correctness("cpu", get_extension(ext_name))
 
+    @parametrize("ext_name", ["extension_cpp", "extension_cpp_stable"])
     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
-    def test_correctness_cuda(self):
-        self._test_correctness("cuda")
+    def test_correctness_cuda(self, ext_name):
+        self._test_correctness("cuda", get_extension(ext_name))
 
-    def _test_gradients(self, device):
+    def _test_gradients(self, device, ext):
         samples = self.sample_inputs(device, requires_grad=True)
         for args in samples:
             diff_tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
-            out = extension_cpp.ops.mymuladd(*args)
+            out = ext.ops.mymuladd(*args)
             grad_out = torch.randn_like(out)
             result = torch.autograd.grad(out, diff_tensors, grad_out)
@@ -55,26 +71,30 @@ def _test_gradients(self, device):
 
         torch.testing.assert_close(result, expected)
 
-    def test_gradients_cpu(self):
-        self._test_gradients("cpu")
+    @parametrize("ext_name", ["extension_cpp", "extension_cpp_stable"])
+    def test_gradients_cpu(self, ext_name):
+        self._test_gradients("cpu", get_extension(ext_name))
 
+    @parametrize("ext_name", ["extension_cpp", "extension_cpp_stable"])
     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
-    def test_gradients_cuda(self):
-        self._test_gradients("cuda")
+    def test_gradients_cuda(self, ext_name):
+        self._test_gradients("cuda", get_extension(ext_name))
 
-    def _opcheck(self, device):
-        # Use opcheck to check for incorrect usage of operator registration APIs
+    def _opcheck(self, device, ext_name):
         samples = self.sample_inputs(device, requires_grad=True)
         samples.extend(self.sample_inputs(device, requires_grad=False))
+        op = getattr(torch.ops, ext_name).mymuladd.default
         for args in samples:
-            opcheck(torch.ops.extension_cpp.mymuladd.default, args)
+            opcheck(op, args)
 
-    def test_opcheck_cpu(self):
-        self._opcheck("cpu")
+    @parametrize("ext_name", ["extension_cpp", "extension_cpp_stable"])
+    def test_opcheck_cpu(self, ext_name):
+        self._opcheck("cpu", ext_name)
 
+    @parametrize("ext_name", ["extension_cpp", "extension_cpp_stable"])
     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
-    def test_opcheck_cuda(self):
-        self._opcheck("cuda")
+    def test_opcheck_cuda(self, ext_name):
+        self._opcheck("cuda", ext_name)
 
 
 class TestMyAddOut(TestCase):
@@ -90,84 +110,100 @@ def make_nondiff_tensor(*size):
             [make_tensor(20), make_tensor(20), make_tensor(20)],
         ]
 
-    def _test_correctness(self, device):
+    def _test_correctness(self, device, ext):
         samples = self.sample_inputs(device)
         for args in samples:
             result = args[-1]
-            extension_cpp.ops.myadd_out(*args)
+            ext.ops.myadd_out(*args)
             expected = torch.add(*args[:2])
             torch.testing.assert_close(result, expected)
 
-    def test_correctness_cpu(self):
-        self._test_correctness("cpu")
+    @parametrize("ext_name", ["extension_cpp", "extension_cpp_stable"])
+    def test_correctness_cpu(self, ext_name):
+        self._test_correctness("cpu", get_extension(ext_name))
 
+    @parametrize("ext_name", ["extension_cpp", "extension_cpp_stable"])
     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
-    def test_correctness_cuda(self):
-        self._test_correctness("cuda")
+    def test_correctness_cuda(self, ext_name):
+        self._test_correctness("cuda", get_extension(ext_name))
 
-    def _opcheck(self, device):
-        # Use opcheck to check for incorrect usage of operator registration APIs
+    def _opcheck(self, device, ext_name):
         samples = self.sample_inputs(device, requires_grad=True)
         samples.extend(self.sample_inputs(device, requires_grad=False))
+        op = getattr(torch.ops, ext_name).myadd_out.default
         for args in samples:
-            opcheck(torch.ops.extension_cpp.myadd_out.default, args)
+            opcheck(op, args)
 
-    def test_opcheck_cpu(self):
-        self._opcheck("cpu")
+    @parametrize("ext_name", ["extension_cpp", "extension_cpp_stable"])
+    def test_opcheck_cpu(self, ext_name):
+        self._opcheck("cpu", ext_name)
 
+    @parametrize("ext_name", ["extension_cpp", "extension_cpp_stable"])
     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
-    def test_opcheck_cuda(self):
-        self._opcheck("cuda")
+    def test_opcheck_cuda(self, ext_name):
+        self._opcheck("cuda", ext_name)
 
 
 class TestTorchCompileStreamSync(TestCase):
     """Test for GitHub issue pytorch/pytorch#157363 - stream synchronization with torch.compile"""
-
+
+    @parametrize("ext_name", ["extension_cpp", "extension_cpp_stable"])
     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
-    def test_compile_with_linear_layer(self):
+    def test_compile_with_linear_layer(self, ext_name):
         """Test custom CUDA kernels with nn.Linear + torch.compile (the original failing case)"""
-
+        ext = get_extension(ext_name)
+
         class Model(nn.Module):
-            def __init__(self, size):
+            def __init__(self, size, extension):
                 super().__init__()
-                self.linear = nn.Linear(size, size, device="cuda", dtype=torch.float32)
-
+                self.linear = nn.Linear(
+                    size, size, device="cuda", dtype=torch.float32
+                )
+                self.ext = extension
+
             def forward(self, x):
-                return extension_cpp.ops.mymuladd(self.linear(x), self.linear(x), 0.0)
-
+                return self.ext.ops.mymuladd(self.linear(x), self.linear(x), 0.0)
+
         # Test sizes that previously failed
         for size in [1000, 5000, 10000]:
             with self.subTest(size=size):
                 torch.manual_seed(42)
-                model = Model(size)
+                model = Model(size, ext)
                 x = torch.randn((1, size), device="cuda", dtype=torch.float32)
-
+
                 with torch.no_grad():
                     expected = model(x)
                     compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
                     actual = compiled_model(x)
-
+
                 self.assertEqual(actual, expected)
 
+    @parametrize("ext_name", ["extension_cpp", "extension_cpp_stable"])
     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
-    def test_compile_custom_only(self):
+    def test_compile_custom_only(self, ext_name):
         """Test custom operations alone with torch.compile"""
-
+        ext = get_extension(ext_name)
+
         def model(x):
-            return extension_cpp.ops.mymuladd(x, x, 1.0)
-
+            return ext.ops.mymuladd(x, x, 1.0)
+
         for size in [1000, 5000, 10000]:
             with self.subTest(size=size):
                 torch.manual_seed(42)
                 x = torch.randn((size,), device="cuda", dtype=torch.float32)
-
+
                 with torch.no_grad():
                     expected = model(x)
                     compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
                     actual = compiled_model(x)
-
+
                 self.assertEqual(actual, expected)
 
 
+instantiate_parametrized_tests(TestMyMulAdd)
+instantiate_parametrized_tests(TestMyAddOut)
+instantiate_parametrized_tests(TestTorchCompileStreamSync)
+
+
 if __name__ == "__main__":
-    unittest.main()
+    run_tests()