Skip to content
This repository was archived by the owner on Jan 12, 2026. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions .github/workflows/test_py_inductor.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: Test Python Inductor

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.8'
- name: Prepare PyTorch 2.0 nightly
run: pip3 install --pre torch==2.0.0.dev20230205 --index-url https://download.pytorch.org/whl/nightly/cpu
- name: Echo GCC version
run: gcc --version
- name: Install MATXScript Requirements
run: pip3 install -r python/requirements.txt
- name: Python Inductor Test
run: bash ci/run_py_inductor_test.sh
50 changes: 50 additions & 0 deletions ci/run_py_inductor_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/usr/bin/env bash
# Copyright 2022 ByteDance Ltd. and/or its affiliates.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

set -xue
set -o pipefail

THIS_PATH=$(cd $(dirname "$0"); pwd)
ROOT_PATH=${THIS_PATH}/../

###############################################################################
# build all shared target
###############################################################################
cd "${ROOT_PATH}" || exit 1
BUILD_TESTING=OFF BUILD_BENCHMARK=OFF CPPFLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" bash ci/build_lib.sh

###############################################################################
# install requirements
###############################################################################
PYTHON_MODULE_PATH=${ROOT_PATH}/python
cd "${PYTHON_MODULE_PATH}"
pip3 install -r requirements.txt

###############################################################################
# find all test script
###############################################################################
PYTHONPATH=${PYTHONPATH:-}
TEST_SCRIPT_PATH=${ROOT_PATH}/test/inductor
cd "${TEST_SCRIPT_PATH}"
# shellcheck disable=SC2045
for script_file in $(ls test_*.py); do
echo "test script: ${script_file}"
PYTHONPATH="${ROOT_PATH}/python:${PYTHONPATH}" python3 "${script_file}"
done
19 changes: 18 additions & 1 deletion python/matx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from . import vision
from . import tools


# APIs
__all__ = [
# functions
Expand All @@ -41,6 +40,7 @@
"trace",
"script",
"script_embedded_class",
"inductor",
"save",
"load",
"get_cflags",
Expand Down Expand Up @@ -352,6 +352,23 @@ def script(compiling_obj, *args, backend=None, **kwargs):
return toolchain.script(compiling_obj, *args, **kwargs)


def inductor(example_inputs, **kwargs):
"""

Args:
example_inputs: any nested structure of torch.Tensor that passed into the kernel
**kwargs: other keyword arguments passed into toolchain.inductor

Returns: a wrapper that compiles the compiling_obj into a JIT FUNCTION

"""

def inner_inductor(compiling_obj):
return toolchain.inductor(compiling_obj, example_inputs, **kwargs)

return inner_inductor


def script_embedded_class(code, is_path=False):
return toolchain.script_embedded_class(code, is_path)

Expand Down
11 changes: 9 additions & 2 deletions python/matx/contrib/cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,16 @@ def find_sys_cc_path():
raise RuntimeError("win32 is not supported")
elif sys.platform.startswith('darwin'):
# maybe we can use clang++
cc_bin = "g++"
# prioritized compiler defined in CXX
if 'CXX' in os.environ:
cc_bin = os.environ['CXX']
else:
cc_bin = "g++"
else:
cc_bin = "g++"
if 'CXX' in os.environ:
cc_bin = os.environ['CXX']
else:
cc_bin = "g++"
return cc_bin


Expand Down
14 changes: 14 additions & 0 deletions python/matx/pipeline/_register_conveter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@
# specific language governing permissions and limitations
# under the License.


try:
import torch
import torch.utils.dlpack

HAS_TORCH = True
except:
HAS_TORCH = False

import matx
from .._ffi._selector import _set_fast_pipeline_object_converter
from .._ffi._selector import _set_class_symbol
from .symbol import BaseSymbol
Expand All @@ -29,9 +39,13 @@ def _pipeline_object_converter(value):
return value.native_op
if isinstance(value, OpKernel):
return value.native_op
if HAS_TORCH and isinstance(value, torch.Tensor):
return matx.array.from_dlpack(torch.utils.dlpack.to_dlpack(value))
return value


_PipelineClasses = (JitObject, OpKernel,)
if HAS_TORCH:
_PipelineClasses += (torch.Tensor,)
_set_fast_pipeline_object_converter(_PipelineClasses, _pipeline_object_converter)
_set_class_symbol(BaseSymbol)
2 changes: 1 addition & 1 deletion python/matx/script/analysis/build_type_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def run(self, sc_ctx: context.ScriptContext):
node_ctx = sc_ctx.main_node.context
if isinstance(node_ctx, context.ClassContext):
build_type = context.BuildType.JIT_OBJECT
elif isinstance(node_ctx, context.FunctionContext):
elif isinstance(node_ctx, (context.FunctionContext, context.InductorContext)):
build_type = context.BuildType.FUNCTION
else:
raise RuntimeError("Only one-function, one-class source code is allowed")
Expand Down
1 change: 1 addition & 0 deletions python/matx/script/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
from .class_context import ClassContext, GetClassAttr
from .function_context import FunctionContext, FunctionType
from .scope_context import ScopeContext
from .inductor_context import InductorContext
3 changes: 2 additions & 1 deletion python/matx/script/context/ast_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from matx._typed_ast import ast
from .class_context import ClassContext
from .function_context import FunctionContext
from .inductor_context import InductorContext
from ... import ir as _ir


Expand Down Expand Up @@ -49,7 +50,7 @@ def __init__(self, ):
self.raw: Optional[type] = None
self.span: Span = Span()
self.ast: Optional[ast.AST] = None
self.context: Union[ClassContext, FunctionContext, None] = None
self.context: Union[ClassContext, FunctionContext, InductorContext, None] = None
self.module: Optional[ModuleInfo] = None
self.deps: Optional[List[ASTNode]] = None
self.ir_schema = None
Expand Down
33 changes: 33 additions & 0 deletions python/matx/script/context/inductor_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


class InductorContext(object):
def __init__(self,
fn_name: str = '<unknown>',
example_inputs_spec=None):
self.fn_name = fn_name
self.unbound_name = fn_name
self.return_type = None
self.arg_types = {} # Deferred?
self.example_inputs_spec = example_inputs_spec

@property
def name(self):
return self.fn_name
78 changes: 78 additions & 0 deletions python/matx/script/inductor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import inspect
from typing import List

import torch

from matx.torch_compiler.codegen import extract_inductor_code, matx_cpp_code_format
from .tensor_spec import TensorSpec
from .. import context, analysis
from ... import _ffi
from ... import ir
from ...env import MATX_DEV_MODE


def _embedded_inductor_ctx(compiling_obj, example_inputs):
code = _obtain_inductor_code(compiling_obj, example_inputs)
build_module = _ffi.get_global_func("embedded.build.c")
sc_ctx = context.ScriptContext()
sc_ctx.main_node.raw = compiling_obj
if isinstance(code, str):
code = code.encode()
sc_ctx.rt_module = build_module(code)
example_inputs_spec = [TensorSpec.from_tensor(inputs) for inputs in example_inputs]
sc_ctx.main_node.context = context.InductorContext(fn_name=compiling_obj.__name__,
example_inputs_spec=example_inputs_spec)
return sc_ctx


def _pass(sc_ctx: context.ScriptContext):
src_anls = analysis.SourceAnalysis()
src_anls.run(sc_ctx)


def _obtain_inductor_code(compiling_obj, example_inputs):
# compile the kernel and set the code
code, kernel_name, fake_output = extract_inductor_code(compiling_obj, example_inputs)
code = matx_cpp_code_format(code, kernel_name, example_inputs, fake_output)
return code


def from_source(compiling_obj: type, example_inputs: List[torch.Tensor]) -> context.ScriptContext:
try:
# TODO: allow generalized way to specify example_inputs
sc_ctx = _embedded_inductor_ctx(compiling_obj, example_inputs)
# set filename.
_pass(sc_ctx)
analysis.BuildTypeAnalysis().run(sc_ctx)

# set args types.
# TODO: currently, we only support argument as NDArray. We may support nested inputs later
signature = inspect.signature(compiling_obj)
for param in signature.parameters.values():
sc_ctx.main_node.context.arg_types[param.name] = ir.type.NDArrayType()

return sc_ctx
except BaseException as e:
if MATX_DEV_MODE:
raise
else:
raise Exception(str(e)) from None
49 changes: 49 additions & 0 deletions python/matx/script/inductor/tensor_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2022 ByteDance Ltd. and/or its affiliates.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

def convert_torch_dtype(dtype):
import torch
table = {
torch.int32: 'int32',
torch.int64: 'int64',
torch.float32: 'float32',
torch.float64: 'float64'
}
if dtype not in table:
raise NotImplementedError(f'Unsupport torch.Tensor dtype {dtype}')

return table[dtype]


class TensorSpec(object):
def __init__(self, shape, dtype):
self.shape = tuple(shape)
self.dtype = dtype

@classmethod
def from_tensor(cls, tensor):
import torch
assert isinstance(tensor, torch.Tensor)
return cls(shape=tuple(tensor.shape), dtype=convert_torch_dtype(tensor.dtype))

def __str__(self):
return str(self.shape) + ', ' + self.dtype

def __repr__(self):
return f'TensorSpec({str(self)})'
Loading