diff --git a/.github/workflows/test_py_inductor.yml b/.github/workflows/test_py_inductor.yml new file mode 100644 index 00000000..adc4c20b --- /dev/null +++ b/.github/workflows/test_py_inductor.yml @@ -0,0 +1,26 @@ +name: Test Python Inductor + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.8' + - name: Prepare PyTorch 2.0 nightly + run: pip3 install --pre torch==2.0.0.dev20230205 --index-url https://download.pytorch.org/whl/nightly/cpu + - name: Echo GCC version + run: gcc --version + - name: Install MATXScript Requirements + run: pip3 install -r python/requirements.txt + - name: Python Inductor Test + run: bash ci/run_py_inductor_test.sh diff --git a/ci/run_py_inductor_test.sh b/ci/run_py_inductor_test.sh new file mode 100644 index 00000000..14a68c0f --- /dev/null +++ b/ci/run_py_inductor_test.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash +# Copyright 2022 ByteDance Ltd. and/or its affiliates. +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +set -xue +set -o pipefail + +THIS_PATH=$(cd $(dirname "$0"); pwd) +ROOT_PATH=${THIS_PATH}/../ + +############################################################################### +# build all shared target +############################################################################### +cd "${ROOT_PATH}" || exit 1 +BUILD_TESTING=OFF BUILD_BENCHMARK=OFF CPPFLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" bash ci/build_lib.sh + +############################################################################### +# install requirements +############################################################################### +PYTHON_MODULE_PATH=${ROOT_PATH}/python +cd "${PYTHON_MODULE_PATH}" +pip3 install -r requirements.txt + +############################################################################### +# find all test script +############################################################################### +PYTHONPATH=${PYTHONPATH:-} +TEST_SCRIPT_PATH=${ROOT_PATH}/test/inductor +cd "${TEST_SCRIPT_PATH}" +# shellcheck disable=SC2045 +for script_file in $(ls test_*.py); do + echo "test script: ${script_file}" + PYTHONPATH="${ROOT_PATH}/python:${PYTHONPATH}" python3 "${script_file}" +done diff --git a/python/matx/__init__.py b/python/matx/__init__.py index 496710c1..63dddd4a 100644 --- a/python/matx/__init__.py +++ b/python/matx/__init__.py @@ -30,7 +30,6 @@ from . import vision from . 
import tools - # APIs __all__ = [ # functions @@ -41,6 +40,7 @@ "trace", "script", "script_embedded_class", + "inductor", "save", "load", "get_cflags", @@ -352,6 +352,23 @@ def script(compiling_obj, *args, backend=None, **kwargs): return toolchain.script(compiling_obj, *args, **kwargs) +def inductor(example_inputs, **kwargs): + """ + + Args: + example_inputs: any nested structure of torch.Tensor that passed into the kernel + **kwargs: other keyword arguments passed into toolchain.inductor + + Returns: a wrapper that compiles the compiling_obj into a JIT FUNCTION + + """ + + def inner_inductor(compiling_obj): + return toolchain.inductor(compiling_obj, example_inputs, **kwargs) + + return inner_inductor + + def script_embedded_class(code, is_path=False): return toolchain.script_embedded_class(code, is_path) diff --git a/python/matx/contrib/cc.py b/python/matx/contrib/cc.py index bd45d695..4bed4eeb 100644 --- a/python/matx/contrib/cc.py +++ b/python/matx/contrib/cc.py @@ -93,9 +93,16 @@ def find_sys_cc_path(): raise RuntimeError("win32 is not supported") elif sys.platform.startswith('darwin'): # maybe we can use clang++ - cc_bin = "g++" + # prioritized compiler defined in CXX + if 'CXX' in os.environ: + cc_bin = os.environ['CXX'] + else: + cc_bin = "g++" else: - cc_bin = "g++" + if 'CXX' in os.environ: + cc_bin = os.environ['CXX'] + else: + cc_bin = "g++" return cc_bin diff --git a/python/matx/pipeline/_register_conveter.py b/python/matx/pipeline/_register_conveter.py index 10df789f..7fca3493 100644 --- a/python/matx/pipeline/_register_conveter.py +++ b/python/matx/pipeline/_register_conveter.py @@ -17,6 +17,16 @@ # specific language governing permissions and limitations # under the License. 
+ +try: + import torch + import torch.utils.dlpack + + HAS_TORCH = True +except: + HAS_TORCH = False + +import matx from .._ffi._selector import _set_fast_pipeline_object_converter from .._ffi._selector import _set_class_symbol from .symbol import BaseSymbol @@ -29,9 +39,13 @@ def _pipeline_object_converter(value): return value.native_op if isinstance(value, OpKernel): return value.native_op + if HAS_TORCH and isinstance(value, torch.Tensor): + return matx.array.from_dlpack(torch.utils.dlpack.to_dlpack(value)) return value _PipelineClasses = (JitObject, OpKernel,) +if HAS_TORCH: + _PipelineClasses += (torch.Tensor,) _set_fast_pipeline_object_converter(_PipelineClasses, _pipeline_object_converter) _set_class_symbol(BaseSymbol) diff --git a/python/matx/script/analysis/build_type_analysis.py b/python/matx/script/analysis/build_type_analysis.py index 6cfd0257..7d43cb84 100644 --- a/python/matx/script/analysis/build_type_analysis.py +++ b/python/matx/script/analysis/build_type_analysis.py @@ -30,7 +30,7 @@ def run(self, sc_ctx: context.ScriptContext): node_ctx = sc_ctx.main_node.context if isinstance(node_ctx, context.ClassContext): build_type = context.BuildType.JIT_OBJECT - elif isinstance(node_ctx, context.FunctionContext): + elif isinstance(node_ctx, (context.FunctionContext, context.InductorContext)): build_type = context.BuildType.FUNCTION else: raise RuntimeError("Only one-function, one-class source code is allowed") diff --git a/python/matx/script/context/__init__.py b/python/matx/script/context/__init__.py index 342af971..896630ad 100644 --- a/python/matx/script/context/__init__.py +++ b/python/matx/script/context/__init__.py @@ -23,3 +23,4 @@ from .class_context import ClassContext, GetClassAttr from .function_context import FunctionContext, FunctionType from .scope_context import ScopeContext +from .inductor_context import InductorContext diff --git a/python/matx/script/context/ast_node.py b/python/matx/script/context/ast_node.py index 70ee5c57..55b2c418 
100644 --- a/python/matx/script/context/ast_node.py +++ b/python/matx/script/context/ast_node.py @@ -22,6 +22,7 @@ from matx._typed_ast import ast from .class_context import ClassContext from .function_context import FunctionContext +from .inductor_context import InductorContext from ... import ir as _ir @@ -49,7 +50,7 @@ def __init__(self, ): self.raw: Optional[type] = None self.span: Span = Span() self.ast: Optional[ast.AST] = None - self.context: Union[ClassContext, FunctionContext, None] = None + self.context: Union[ClassContext, FunctionContext, InductorContext, None] = None self.module: Optional[ModuleInfo] = None self.deps: Optional[List[ASTNode]] = None self.ir_schema = None diff --git a/python/matx/script/context/inductor_context.py b/python/matx/script/context/inductor_context.py new file mode 100644 index 00000000..cc3d1c99 --- /dev/null +++ b/python/matx/script/context/inductor_context.py @@ -0,0 +1,33 @@ +# Copyright 2022 ByteDance Ltd. and/or its affiliates. +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +class InductorContext(object): + def __init__(self, + fn_name: str = '', + example_inputs_spec=None): + self.fn_name = fn_name + self.unbound_name = fn_name + self.return_type = None + self.arg_types = {} # Deferred? + self.example_inputs_spec = example_inputs_spec + + @property + def name(self): + return self.fn_name diff --git a/python/matx/script/inductor/__init__.py b/python/matx/script/inductor/__init__.py new file mode 100644 index 00000000..9c75d904 --- /dev/null +++ b/python/matx/script/inductor/__init__.py @@ -0,0 +1,78 @@ +# Copyright 2022 ByteDance Ltd. and/or its affiliates. +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import inspect +from typing import List + +import torch + +from matx.torch_compiler.codegen import extract_inductor_code, matx_cpp_code_format +from .tensor_spec import TensorSpec +from .. import context, analysis +from ... import _ffi +from ... 
import ir +from ...env import MATX_DEV_MODE + + +def _embedded_inductor_ctx(compiling_obj, example_inputs): + code = _obtain_inductor_code(compiling_obj, example_inputs) + build_module = _ffi.get_global_func("embedded.build.c") + sc_ctx = context.ScriptContext() + sc_ctx.main_node.raw = compiling_obj + if isinstance(code, str): + code = code.encode() + sc_ctx.rt_module = build_module(code) + example_inputs_spec = [TensorSpec.from_tensor(inputs) for inputs in example_inputs] + sc_ctx.main_node.context = context.InductorContext(fn_name=compiling_obj.__name__, + example_inputs_spec=example_inputs_spec) + return sc_ctx + + +def _pass(sc_ctx: context.ScriptContext): + src_anls = analysis.SourceAnalysis() + src_anls.run(sc_ctx) + + +def _obtain_inductor_code(compiling_obj, example_inputs): + # compile the kernel and set the code + code, kernel_name, fake_output = extract_inductor_code(compiling_obj, example_inputs) + code = matx_cpp_code_format(code, kernel_name, example_inputs, fake_output) + return code + + +def from_source(compiling_obj: type, example_inputs: List[torch.Tensor]) -> context.ScriptContext: + try: + # TODO: allow generalized way to specify example_inputs + sc_ctx = _embedded_inductor_ctx(compiling_obj, example_inputs) + # set filename. + _pass(sc_ctx) + analysis.BuildTypeAnalysis().run(sc_ctx) + + # set args types. + # TODO: currently, we only support argument as NDArray. We may support nested inputs later + signature = inspect.signature(compiling_obj) + for param in signature.parameters.values(): + sc_ctx.main_node.context.arg_types[param.name] = ir.type.NDArrayType() + + return sc_ctx + except BaseException as e: + if MATX_DEV_MODE: + raise + else: + raise Exception(str(e)) from None diff --git a/python/matx/script/inductor/tensor_spec.py b/python/matx/script/inductor/tensor_spec.py new file mode 100644 index 00000000..40a229fc --- /dev/null +++ b/python/matx/script/inductor/tensor_spec.py @@ -0,0 +1,49 @@ +# Copyright 2022 ByteDance Ltd. 
and/or its affiliates. +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +def convert_torch_dtype(dtype): + import torch + table = { + torch.int32: 'int32', + torch.int64: 'int64', + torch.float32: 'float32', + torch.float64: 'float64' + } + if dtype not in table: + raise NotImplementedError(f'Unsupport torch.Tensor dtype {dtype}') + + return table[dtype] + + +class TensorSpec(object): + def __init__(self, shape, dtype): + self.shape = tuple(shape) + self.dtype = dtype + + @classmethod + def from_tensor(cls, tensor): + import torch + assert isinstance(tensor, torch.Tensor) + return cls(shape=tuple(tensor.shape), dtype=convert_torch_dtype(tensor.dtype)) + + def __str__(self): + return str(self.shape) + ', ' + self.dtype + + def __repr__(self): + return f'TensorSpec({str(self)})' diff --git a/python/matx/toolchain.py b/python/matx/toolchain.py index 8758f557..b4185c06 100644 --- a/python/matx/toolchain.py +++ b/python/matx/toolchain.py @@ -40,6 +40,7 @@ USE_SO_CACHE = os.environ.get('MATX_USE_SO_CACHE', '').lower() != 'false' DISABLE_SCRIPT = os.environ.get('MATX_DISABLE_SCRIPT', '').lower() == 'true' +DISABLE_INDUCTOR = os.environ.get('MATX_DISABLE_INDUCTOR', '').lower() == 'true' DISABLE_GENERATE_CC = 
os.environ.get('MATX_DISABLE_GENERATE_CC', '').lower() == 'true' FLAG_COMPILED_OBJECT = object() @@ -251,6 +252,27 @@ def path_prefix(sc_ctx: context.ScriptContext): cache_md5)) +def path_prefix_inductor(sc_ctx: context.ScriptContext): + """inductor path_prefix encodes meta info from example_inputs""" + # mkdir LIB_PATH + from .__init__ import __version__ + _mk_lib_dir() + # code + sha1(libmatx.so) + commit_id(__version__) + dep_source_codes = "".join(dep_node.span.source_code for dep_node in sc_ctx.deps_node) + assert isinstance(sc_ctx.main_node.context, context.InductorContext) + example_inputs = sc_ctx.main_node.context.example_inputs_spec + example_inputs_str = ''.join([str(inputs) for inputs in example_inputs]) + cache_str = sc_ctx.main_node.span.source_code + dep_source_codes + cache_str += example_inputs_str + _LIB_SHA1 + __version__ + cache_md5 = hashlib.md5(cache_str.encode()).hexdigest()[:16] + file_name = os.path.splitext(os.path.basename(sc_ctx.main_node.span.file_name))[0] + return os.path.abspath('{}/lib{}_{}_{}_plugin_{}'.format(LIB_PATH, + file_name, + sc_ctx.main_node.span.lineno, + sc_ctx.main_node.context.name, + cache_md5)) + + def toolchain_path_prefix(sc_ctx: context.ScriptContext, toolchain_str: str): from .__init__ import __version__ # mkdir LIB_PATH @@ -296,21 +318,31 @@ def toolchain_build(sc_ctx: context.ScriptContext, toolchain: ToolChain): sc_ctx.dso_path = (sc_ctx.dso_path[0], so_path) -def build_dso(sc_ctx: context.ScriptContext, use_toolchain=False): +def build_dso(sc_ctx: context.ScriptContext, + use_toolchain=False, + compile_options=None, + make_path_prefix=None): rt_mod = sc_ctx.rt_module main_node_name = sc_ctx.main_node.context.name - base_path = path_prefix(sc_ctx) + if make_path_prefix is None: + make_path_prefix = path_prefix + + base_path = make_path_prefix(sc_ctx) with contrib.util.filelock(base_path): sopath = base_path + '.so' sopath_cxx11 = base_path + '_cxx11.so' + # TODO: need to unify the compile options base_options 
= [ "-std=c++14", "-O3", "-g", "-fdiagnostics-color=always", "-Werror=return-type"] + if compile_options is not None: + assert isinstance(compile_options, List) + base_options.extend(compile_options) cxx11_with_abi_options = base_options + ["-D_GLIBCXX_USE_CXX11_ABI=1"] cxx11_no_abi_options = base_options + ["-D_GLIBCXX_USE_CXX11_ABI=0"] sys_cc_path = contrib.cc.find_sys_cc_path() @@ -380,6 +412,41 @@ def script(compiling_obj, *, share=True, toolchain=None, bundle_args=None): raise ValueError('Unsupported build_type: {}'.format(result.build_type)) +def inductor(compiling_obj, example_inputs, *, share=True, toolchain=None, bundle_args=None): + if DISABLE_SCRIPT: + return compiling_obj + + from .script.inductor import from_source + + result: context.ScriptContext = from_source(compiling_obj, example_inputs) + + from torch._inductor import codecache + ipaths, lpaths, libs, macros = codecache.get_include_and_linking_paths( + include_pytorch=False, vec_isa=codecache.pick_vec_isa()) + + # TODO: check whether the following flags are handled by common flags + # codecache.get_shared() + optimization_flag = codecache.optimization_flags() + # codecache.cpp_flags() + # codecache.get_warning_all_flag() + # codecache.use_custom_generated_macros() + + torch_compiler_options = [] + flag_str_lst = [ipaths, lpaths, libs, macros, optimization_flag] + for flag_str in flag_str_lst: + torch_compiler_options.extend(flag_str.split()) + + build_dso(result, toolchain is not None, compile_options=torch_compiler_options, + make_path_prefix=path_prefix_inductor) + if toolchain is not None: + toolchain_build(result, toolchain) + + if result.build_type is context.BuildType.FUNCTION: + return make_jit_op_creator(result, share, bundle_args=bundle_args)() + else: + raise ValueError('Unsupported build_type: {}'.format(result.build_type)) + + def make_session(compiling_obj, method='__call__'): from . 
import pipeline diff --git a/python/matx/torch_compiler/__init__.py b/python/matx/torch_compiler/__init__.py new file mode 100644 index 00000000..97fcd247 --- /dev/null +++ b/python/matx/torch_compiler/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2022 ByteDance Ltd. and/or its affiliates. +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +minimum_torch_version = '2.0.0.dev' + +try: + import torch + + assert torch.__version__ >= minimum_torch_version + +except ModuleNotFoundError: + print(f'torch is not installed. matx.inductor requires torch >= {minimum_torch_version}') + raise +except AssertionError: + print(f'matx.inductor requires torch >= {minimum_torch_version}') + raise diff --git a/python/matx/torch_compiler/codegen/__init__.py b/python/matx/torch_compiler/codegen/__init__.py new file mode 100644 index 00000000..9ad89473 --- /dev/null +++ b/python/matx/torch_compiler/codegen/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2022 ByteDance Ltd. and/or its affiliates. +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .inductor import extract_inductor_code +from .matx_formatter import matx_cpp_code_format diff --git a/python/matx/torch_compiler/codegen/inductor/__init__.py b/python/matx/torch_compiler/codegen/inductor/__init__.py new file mode 100644 index 00000000..cc06f675 --- /dev/null +++ b/python/matx/torch_compiler/codegen/inductor/__init__.py @@ -0,0 +1,120 @@ +# Copyright 2022 ByteDance Ltd. and/or its affiliates. +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import List, Tuple + +import torch +import torch._inductor.compile_fx as compile_fx +from torch import fx +from torch._inductor.debug import DebugContext +from torch._inductor.virtualized import V + +""" +Use a global variable to hack the compile_fx_inner and record the compiled code. +This works in single process problem, but requires careful review in multi-processing +""" + + +class FakeCallableWithCode(): + code = None + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + def set_code(self, code): + self.code = code + + +fake_callable = FakeCallableWithCode() + + +@DebugContext.wrap +@torch.utils._python_dispatch._disable_current_modes() +def compile_fx_inner_cpu( + gm: torch.fx.GraphModule, + example_inputs: List[torch.Tensor], + cudagraphs=None, + num_fixed=0, + is_backward=False, + graph_id=None, +): + # lift the maximum depth of the Python interpreter stack + # to adapt large/deep models + compile_fx.sys.setrecursionlimit(max(compile_fx.sys.getrecursionlimit(), 2000)) + + V.debug.fx_graph(gm, example_inputs) + + shape_env = compile_fx._shape_env_from_inputs(example_inputs) + fake_mode = compile_fx.fake_mode_from_tensors( + example_inputs + ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + + with V.set_fake_mode(fake_mode): + compile_fx.pattern_matcher.fx_passes(gm) + V.debug.fx_graph_transformed(gm, example_inputs) + + graph = compile_fx.GraphLowering( + gm, + shape_env=shape_env, + num_static_inputs=num_fixed, + graph_id=graph_id, + ) + with V.set_graph_handler(graph): + graph.run(*example_inputs) + code = graph.codegen() + fake_callable.set_code(code) + + return fake_callable + + +def assert_tuple_of_tensors(tensors): + assert isinstance(tensors, Tuple) + for tensor in tensors: + assert isinstance(tensor, torch.Tensor), 'Each element in tensors must be a torch.Tensor' + + +from torch._subclasses import FakeTensor, FakeTensorMode + + +def extract_inductor_code(kernel, example_inputs): + # check kernel 
input and output. All the input must be a Tensor. The output must be a tuple of Tensor + # TODO: remove this constraints (long term) + assert isinstance(example_inputs, (List, Tuple)) + example_inputs = tuple(example_inputs) + assert_tuple_of_tensors(example_inputs) + fake_mode = FakeTensorMode() + fake_example_inputs = [FakeTensor.from_tensor(t, fake_mode=fake_mode) for t in example_inputs] + fake_output = kernel(*fake_example_inputs) + assert_tuple_of_tensors(fake_output) + + model = fx.symbolic_trace(kernel) + compile_fx.compile_fx( + model, + example_inputs_=fake_example_inputs, + inner_compile=compile_fx_inner_cpu) + + code = fake_callable.code + + # By default, Pytorch compiles a Python module with all the C++ kernel with unified name kernel. + # The actual kernel name should be kernel.__name__. + # TODO: fix this after rewriting inductor codegen to all C++ instead of a Python module + kernel_name = kernel.__name__ + + # fake_output is used + return code, kernel_name, fake_output diff --git a/python/matx/torch_compiler/codegen/matx_formatter.py b/python/matx/torch_compiler/codegen/matx_formatter.py new file mode 100644 index 00000000..c30fb2ea --- /dev/null +++ b/python/matx/torch_compiler/codegen/matx_formatter.py @@ -0,0 +1,351 @@ +# Copyright 2022 ByteDance Ltd. and/or its affiliates. +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Utilities to format kernel code generated by inductor to a JITOp +""" + +import copy +import logging +from typing import List + +import torch + +from .utils import cpp_parse + +log = logging.getLogger(__name__) + +MAGIC_NUMBER = '2_71828182846' + +MATX_INCLUDE = ''' +#include "matxscript/runtime/codegen_all_includes.h" +#include + +using namespace ::matxscript::runtime; +extern "C" void* __matxscript_module_ctx = NULL; + +extern "C" MATX_DLL MATXScriptFuncRegistry __matxscript_func_registry__; + + + +''' + +SESSION_HANLDER = cpp_parse.CPPArg(name=f'handle_{MAGIC_NUMBER}', + type=cpp_parse.CPPType(name='void', is_pointer=True)) +SESSION_HANLDER_WITH_DEAFULT = cpp_parse.CPPArg( + name=f'handle_{MAGIC_NUMBER}', type=cpp_parse.CPPType( + name='void', is_pointer=True), default_val='((void*)(int64_t)0)') + +CREATE_NDARRAY_IMPLEMENTATION = ''' +NDArray createNDArray(const std::string& dtype, const std::string& device, const List& arg_shape) { + Unicode dtype_str(UTF8Decode(dtype)); + Unicode ctx_str(UTF8Decode(device)); + + auto a = Kernel_NDArray::make(0., arg_shape, dtype_str, ctx_str); + // set impl to torch.Tensor + a.SetImpl(NDArray::Impl::torchTensor); + return a; +} +''' + + +def generate_ndarray_arg_cast(arg_name, arg_index, message='TODO'): + return f'internal::TypeAsHelper::run(({arg_name}[{arg_index}]), __FILE__, __LINE__, "{message}", "{message}")' + + +def get_c_api(kernel_name: str, args: List[cpp_parse.CPPArg], has_return_value) -> str: + template_with_return = ''' +int {}__c_api(MATXScriptAny* args, int num_args, MATXScriptAny* 
out_ret_value, void* resource_handle = nullptr) +{{ + TArgs args_t(args, num_args); + + if (num_args > 0 && args[num_args - 1].code == TypeIndex::kRuntimeKwargs) {{ + string_view arg_names[{}] {{{}}}; + KwargsUnpackHelper helper("{}", arg_names, {}, nullptr, 0); + RTView pos_args[{}]; + helper.unpack(pos_args, args, num_args); // /Users/bytedance/Developer/open_source_library/matxscript/examples/simple_function.py:5 + + auto ret = {}({}, + {}resource_handle); + RTValue(std::move(ret)).MoveToCHost(out_ret_value); + }} else {{ + switch(num_args) {{ + case {}: {{ + auto ret = {}({}, + {}resource_handle); // /Users/bytedance/Developer/open_source_library/matxscript/examples/simple_function.py:5 + RTValue(std::move(ret)).MoveToCHost(out_ret_value); + }} break; + default: {{THROW_PY_TypeError("TODO");}} break; // /Users/bytedance/Developer/open_source_library/matxscript/examples/simple_function.py:5 + }} + }} + + return 0; +}} +''' + assert has_return_value + template = template_with_return + + num_args = len(args) + arg_names_concat_str = ', '.join([f'"{arg.name}"' for arg in args]) + args_dtype = [arg.type.name for arg in args] + + pos_arg_cast_lst = [] + args_t_cast_lst = [] + for arg_index in range(num_args): + pos_arg_cast_lst.append(generate_ndarray_arg_cast('pos_args', arg_index)) + args_t_cast_lst.append(generate_ndarray_arg_cast('args_t', arg_index)) + + kernel_name_indentation = len(kernel_name) * ' ' + if has_return_value: + return_name_indentation = ' ' * 11 + else: + return_name_indentation = '' + pos_arg_cast_indentation = '\n ' + kernel_name_indentation + return_name_indentation + args_t_cast_indentation = '\n ' + kernel_name_indentation + return_name_indentation + pos_arg_cast = (',' + pos_arg_cast_indentation).join(pos_arg_cast_lst) + args_t_cast = (',' + args_t_cast_indentation).join(args_t_cast_lst) + + return template.format( + kernel_name, + num_args, + arg_names_concat_str, + kernel_name, + num_args, + num_args, + kernel_name, + pos_arg_cast, + 
kernel_name_indentation, + num_args, + kernel_name, + args_t_cast, + kernel_name_indentation) + + +def get_registration_str(kernel_name): + # TODO: currently, only 1 function is here. + template = ''' +extern "C" {{ + +MATX_DLL MATXScriptBackendPackedCFunc __matxscript_func_array__[] = {{ + (MATXScriptBackendPackedCFunc){}__c_api, +}}; +MATX_DLL MATXScriptFuncRegistry __matxscript_func_registry__ = {{ + "1\\000{}\\000", __matxscript_func_array__, +}}; + +}} // extern C + +extern "C" {{ + +MATX_DLL const char* __matxscript_closures_names__ = "1\\000{}\\000"; + +}} // extern C + + ''' + return template.format(kernel_name, kernel_name, kernel_name) + + +def get_c_api_declare(kernel_name): + return f'int {kernel_name}__c_api(MATXScriptAny*, int, MATXScriptAny*, void*);' + + +def extract_cpp_code(code: str): + return code.split("'''")[1][1:-1] + + +def split_include_kernel(code): + first_newline_idx = code.find('\n') + include_code_str = code[:first_newline_idx] + kernel_code_str = code[first_newline_idx + 1:] + return include_code_str, kernel_code_str + + +def split_declaration_body(kernel_code_str): + first_open_bracket = kernel_code_str.find('{') + kernel_declaration_str = kernel_code_str[:first_open_bracket] + kernel_body_str = kernel_code_str[first_open_bracket:] + return kernel_declaration_str, kernel_body_str + + +def generate_kernel_wrapper_declaration(kernel_name, example_inputs): + return_type = cpp_parse.CPPType(name='Tuple', is_pointer=False) + args = [] + for i in range(len(example_inputs)): + arg = cpp_parse.CPPArg( + name=f'in_ptr{i}', + type=cpp_parse.CPPType( + name='NDArray', + is_pointer=False), + is_const=False, + is_restricted=False) + args.append(arg) + kernel_wrapper_declaration = cpp_parse.CPPDeclaration(func_name=kernel_name, + return_type=return_type, + args=args, + is_extern_c=False) + return kernel_wrapper_declaration + + +def generate_ndarray_allocate_statement( + output_name: str, + dtype: str, + device: str, + shape: List[int]): + assert 
dtype in ['int32', 'int64', 'float32', 'float64'] + assert device == 'cpu' + assert isinstance(shape, List) + for shape_int in shape: + assert isinstance(shape_int, int) + + shape = [str(shape_int) for shape_int in shape] + shape_str = ', '.join(shape) + + return f'NDArray {output_name} = createNDArray("{dtype}", "{device}", {{{shape_str}}});' + + +def generate_ndarray_cast(var_name, dtype): + return f'({dtype}*){var_name}.Data<{dtype}>()' + + +def generate_kernel_wrapper_return(fake_output): + output_str = [f'out_ptr{i}' for i in range(len(fake_output))] + output_str = ','.join(output_str) + return f'return Kernel_Tuple::make(std::initializer_list{{{output_str}}});' + + +TORCH_DTYPE_TO_NDARRAY_DTYPE = { + torch.float32: 'float32', + torch.float64: 'float64', + torch.int32: 'int32', + torch.int64: 'int64' +} + + +def generate_kernel_wrapper_body(kernel_declaration: cpp_parse.CPPDeclaration, + fake_output: List[torch.Tensor]): + # step 0: obtain output args from kernel_declaration + + # step 1: allocate output NDArray + ndarray_allocate_statements = [] + for i, output in enumerate(fake_output): + assert output.dtype in TORCH_DTYPE_TO_NDARRAY_DTYPE + dtype = TORCH_DTYPE_TO_NDARRAY_DTYPE[output.dtype] + + ndarray_allocate_statement = generate_ndarray_allocate_statement(output_name=f'out_ptr{i}', + dtype=dtype, + device=str(output.device), + shape=list(output.shape)) + ndarray_allocate_statements.append(ndarray_allocate_statement) + + ndarray_allocate_statements = '\n'.join(ndarray_allocate_statements) + '\n\n' + + # step 2: invoke kernel + kernel_invoke_param = [] + for arg in kernel_declaration.args: + kernel_invoke_param.append(generate_ndarray_cast(var_name=arg.name, dtype=arg.type.name)) + + num_space = 10 + delimiter = ',\n' + ' ' * 10 + kernel_invoke_param_str = delimiter.join(kernel_invoke_param) + kernel_invoke_str = kernel_declaration.func_name + '(' + '\n' + ' ' * num_space + \ + kernel_invoke_param_str + '\n' + ');' + '\n' + + # step 3: return output as a 
Tuple + return_str = generate_kernel_wrapper_return(fake_output) + + # step 4: add bracket + final_result = '\n{\n' + ndarray_allocate_statements + kernel_invoke_str + return_str + '\n}' + + return final_result + + +def matx_cpp_code_format(code: str, kernel_name: str, + example_inputs: List[torch.Tensor], + fake_output: List[torch.Tensor]) -> str: + code = extract_cpp_code(code) + # split include and kernel code + + include_code_str, kernel_code_str = split_include_kernel(code) + # add matx include + include_code_str += MATX_INCLUDE + + # extract kernel declaration + kernel_declaration_str, kernel_body_str = split_declaration_body(kernel_code_str) + + kernel_declaration = cpp_parse.parse_cpp_declaration(kernel_declaration_str) + kernel_return_type = kernel_declaration.return_type.name + assert kernel_return_type == 'void', f'The kernel return type must be void, Got {kernel_return_type}' + + # TODO: currently, we simply add magic number to avoid conflict + kernel_declaration.func_name += MAGIC_NUMBER + kernel_code_str = str(kernel_declaration) + kernel_body_str + + # here, we keep the original kernel and add a wrapper + kernel_wrapper_declaration = generate_kernel_wrapper_declaration(kernel_name, example_inputs) + kernel_wrapper_body = generate_kernel_wrapper_body(kernel_declaration, fake_output) + + kernel_wrapper_declaration_without_default = copy.deepcopy(kernel_wrapper_declaration) + kernel_wrapper_declaration_without_default.append_arg(SESSION_HANLDER) + kernel_wrapper_declaration_with_default = copy.deepcopy(kernel_wrapper_declaration) + kernel_wrapper_declaration_with_default.append_arg(SESSION_HANLDER_WITH_DEAFULT) + + # create all the declarations strings + CREATE_NDARRAY_DECLARATION = split_declaration_body(CREATE_NDARRAY_IMPLEMENTATION)[0] + ';' + + function_declaration = [ + CREATE_NDARRAY_DECLARATION, + str(kernel_wrapper_declaration_with_default) + ';', + str(kernel_declaration) + ';', + get_c_api_declare( + kernel_wrapper_declaration.func_name)] + + 
function_declaration_str = '\n\n'.join(function_declaration) + '\n' + + # create all the kernel implementation strings including + # 1. create ndarray. 2. kernel wrapper, 3. kernel, 4. kernel-c-api + kernel_wrapper = str(kernel_wrapper_declaration_without_default) + kernel_wrapper_body + kernel_c_api_impl_str = get_c_api( + kernel_name=kernel_wrapper_declaration.func_name, + args=kernel_wrapper_declaration.args, + has_return_value=kernel_wrapper_declaration.return_type.name != 'void') + + implementations = [ + CREATE_NDARRAY_IMPLEMENTATION, + kernel_wrapper, + kernel_code_str, + kernel_c_api_impl_str] + implementations_str = '\n\n'.join(implementations) + '\n' + + # add namespace + kernel_code_str = [ + 'namespace {', + function_declaration_str, + implementations_str, + '} // namespace'] + kernel_code_str = '\n\n'.join(kernel_code_str) + + # registration str + registration_code_str = get_registration_str(kernel_name=kernel_wrapper_declaration.func_name) + + # final code + final_code = [include_code_str, kernel_code_str, registration_code_str] + + final_code = '\n\n'.join(final_code) + + return final_code diff --git a/python/matx/torch_compiler/codegen/utils/__init__.py b/python/matx/torch_compiler/codegen/utils/__init__.py new file mode 100644 index 00000000..9e19ab85 --- /dev/null +++ b/python/matx/torch_compiler/codegen/utils/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2022 ByteDance Ltd. and/or its affiliates. +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/python/matx/torch_compiler/codegen/utils/cpp_parse.py b/python/matx/torch_compiler/codegen/utils/cpp_parse.py new file mode 100644 index 00000000..499c0b9f --- /dev/null +++ b/python/matx/torch_compiler/codegen/utils/cpp_parse.py @@ -0,0 +1,160 @@ +# Copyright 2022 ByteDance Ltd. and/or its affiliates. +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import dataclasses +from typing import List, Union + + +@dataclasses.dataclass +class CPPType(object): + name: str = None + is_pointer: bool = False + + def __str__(self): + result = self.name + if self.is_pointer: + result += '*' + + return result + + +@dataclasses.dataclass +class CPPArg(object): + name: str = None + type: CPPType = dataclasses.field(default_factory=CPPType) + is_const: bool = False + is_restricted: bool = False + default_val: Union[str, None] = None + + def __str__(self): + result = [] + if self.is_const: + result.append('const') + result.append(str(self.type)) + if self.is_restricted: + result.append('__restrict__') + result.append(self.name) + + if self.default_val is not None: + result.append(f'= {self.default_val}') + + return ' '.join(result) + + +def parse_cpp_arg(cpp_arg_str: str) -> CPPArg: + """Parse the C++ arg from a string such as const float* __restrict__ a = null_ptr + + :param cpp_arg_str: the string of the argument + :return: a CPPArg dataclass + """ + + cpp_arg = CPPArg() + + # find if there is a default value + if '=' in cpp_arg_str: + cpp_arg_str, default_val = cpp_arg_str.split('=') + default_val = default_val.replace(' ', '') + cpp_arg.default_val = default_val + + word = cpp_arg_str.split() + + cpp_arg.name = word[-1] + + for w in word[:-1]: + if w == 'const': + cpp_arg.is_const = True + elif w == '*': + cpp_arg.type.is_pointer = True + elif w == '__restrict__': + cpp_arg.is_restricted = True + else: + # type + if w[-1] == '*': + cpp_arg.type.is_pointer = True + w = w[:-1] # remove * + cpp_arg.type.name = w + + return cpp_arg + + +@dataclasses.dataclass +class CPPDeclaration(object): + func_name: str = None + return_type: CPPType = dataclasses.field(default_factory=CPPType) + args: List[CPPArg] = dataclasses.field(default_factory=list) + is_extern_c: bool = False + + def append_arg(self, arg: CPPArg): + self.args.append(arg) + + def __str__(self): + result = [] + if self.is_extern_c: + result.append('extern "C"') + result.append(str(self.return_type)) + result.append(self.func_name) + 
+ front = ' '.join(result) + num_spaces = len(front) + 1 + interval = ',\n' + ' ' * num_spaces + + args_str = interval.join([str(arg) for arg in self.args]) + + return front + '(' + args_str + ')' + + +def parse_cpp_declaration(cpp_declaration_str: str) -> CPPDeclaration: + """Parse the CPP declaration in string and return a CPPDeclaration. + + :param cpp_declaration_str: + :return: + """ + cpp_declaration = CPPDeclaration() + + identifier_return_name, cpp_arg_str = cpp_declaration_str.split('(') + cpp_arg_str = cpp_arg_str.split(')')[0] + cpp_arg_str_lst = cpp_arg_str.split(',') + # arguments + for cpp_arg_str in cpp_arg_str_lst: + cpp_declaration.args.append(parse_cpp_arg(cpp_arg_str)) + + # process return type and function name + identifier_return_name_lst = identifier_return_name.split() + if identifier_return_name_lst[0] == 'extern' and identifier_return_name_lst[1] == '"C"': + cpp_declaration.is_extern_c = True + identifier_return_name_lst = identifier_return_name_lst[2:] + + cpp_declaration.func_name = identifier_return_name_lst[-1] + # remove func_name + return_type_str_lst = identifier_return_name_lst[:-1] + + if len(return_type_str_lst) == 1: + return_type_str = return_type_str_lst[0] + if return_type_str[-1] == '*': + cpp_declaration.return_type.name = return_type_str[:-1] + cpp_declaration.return_type.is_pointer = True + else: + cpp_declaration.return_type.name = return_type_str + else: + assert len(return_type_str_lst) == 2 + assert return_type_str_lst[-1] == '*' + cpp_declaration.return_type.name = return_type_str_lst[0] + cpp_declaration.return_type.is_pointer = True + + return cpp_declaration diff --git a/python/matx/torch_compiler/tests/simple_inductor.py b/python/matx/torch_compiler/tests/simple_inductor.py new file mode 100644 index 00000000..f611743f --- /dev/null +++ b/python/matx/torch_compiler/tests/simple_inductor.py @@ -0,0 +1,48 @@ +import json + +import numpy as np + +import matx +import torch + + 
+@matx.inductor_script(example_inputs=[torch.from_numpy(np.random.randn(5).astype(np.int32)), + torch.from_numpy(np.random.randn(5).astype(np.int32))]) +def add_relu(a: matx.NDArray, b: matx.NDArray): + c = a + b + c = torch.nn.functional.relu(c) + return c, + + +@matx.script +def add_json(a: str, b: str) -> str: + """ + Assume a and b is a json containing 10 digits. We would like to add them and return another json + """ + a_list = json.loads(a) + b_list = json.loads(b) + + a_tensor = matx.NDArray(arr=a_list, shape=[5], dtype='int32') + b_tensor = matx.NDArray(arr=b_list, shape=[5], dtype='int32') + + c_tensor = add_relu(a_tensor, b_tensor)[0] + + result_lst = c_tensor.tolist() + + return json.dumps(result_lst) + + +if __name__ == '__main__': + a_tensor = matx.NDArray(arr=[1, 2, 3, 4, 5], shape=[5], dtype='int32') + b_tensor = matx.NDArray(arr=[6, 7, 8, 8, 10], shape=[5], dtype='int32') + + a_tensor = a_tensor.torch(copy=True) + + c_tensor = add_relu(a_tensor, b_tensor) + print(c_tensor) + + print(f'Pytorch version {torch.__version__}') + a = json.dumps([1, 2, 3, 4, 5]) + b = json.dumps([6, 7, 8, 9, 10]) + result = add_json(a, b) + print(result) diff --git a/test/inductor/test_basic.py b/test/inductor/test_basic.py new file mode 100644 index 00000000..f28f981d --- /dev/null +++ b/test/inductor/test_basic.py @@ -0,0 +1,56 @@ +# Copyright 2022 ByteDance Ltd. and/or its affiliates. +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest +import matx +import torch +import numpy as np + + +class BasicTests(unittest.TestCase): + + def test_basics(self): + def add_relu(a, b): + c = a + b + c = torch.nn.functional.relu(c) + return c, + + sizes = [(5,), (10,), (2, 3), (4, 5, 6)] + dtypes = [np.float32, np.float64, np.int32, np.int64] + + for size in sizes: + for dtype in dtypes: + a_numpy = np.random.randn(*size).astype(dtype) + b_numpy = np.random.randn(*size).astype(dtype) + + example_inputs = [torch.from_numpy(np.random.randn(*size).astype(dtype)), + torch.from_numpy(np.random.randn(*size).astype(dtype))] + + add_relu_kernel = matx.inductor(example_inputs)(add_relu) + + a_tensor = torch.from_numpy(a_numpy) + b_tensor = torch.from_numpy(b_numpy) + + c_tensor_expected = add_relu(a_tensor, b_tensor)[0] + c_tensor = add_relu_kernel(a_tensor, b_tensor)[0] + torch.testing.assert_close(c_tensor_expected, c_tensor) + + +if __name__ == '__main__': + unittest.main()