Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions frontend/Python/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
primary_registry: Optional[dict] = None,
aot_autograd_decomposition: Optional[dict] = None,
verbose=False,
enable_profile=False,
enable_external_calls: bool = False,
) -> None:
"""
Expand Down Expand Up @@ -99,6 +100,7 @@ def __init__(
self._func_name = func_name
self._aot_autograd_decomposition = aot_autograd_decomposition
self._verbose = verbose
self._enable_profile = enable_profile
self._enable_external_calls = enable_external_calls
self._imported_graphs = []
self._ops_registry = {}
Expand Down Expand Up @@ -537,6 +539,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
self._func_name,
DeviceType.CPU,
self._verbose,
self._enable_profile,
self._enable_external_calls,
)
graph._params_ref = params_flat
Expand Down
87 changes: 82 additions & 5 deletions frontend/Python/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from mlir.passmanager import *
from mlir.execution_engine import *
from mlir import runtime as rt
from mlir.dialects import arith, llvm

from .operation import *
from .type import *
Expand Down Expand Up @@ -107,6 +108,7 @@ def __init__(
func_name: str,
device: DeviceType = DeviceType.CPU,
verbose=False,
enable_profile: bool = False,
enable_external_calls: bool = False,
) -> None:
"""
Expand All @@ -132,6 +134,7 @@ def __init__(
self._imported_module = None
self._params_ref = None
self._verbose = verbose
self._enable_profile = enable_profile
self._ops_registry = ops_registry
self._func_name = func_name
self._ctx = ir.Context()
Expand Down Expand Up @@ -242,15 +245,13 @@ def displace_node(self, node: Op, newnode: Op):
self.node_table.pop(node.name)
self.node_table[newnode.name] = newnode


def displace_node_with_chain(self, node: Op, chain: list[Op]):
"""
Replaces an existing node with a chain of new nodes.
- The first node is taken to be the "head" of the chain, and all parents of the
current node will have this node as their child instead of `node`
- The last node is taken to be the "tail" of the chain, and all children of `node`
will have this node as their parent instead.

Args:
node (Op): The operation to be replaced.
chain (list[Op]): The list of nodes to be inserted in place of `node`.
Expand Down Expand Up @@ -281,7 +282,7 @@ def displace_node_with_chain(self, node: Op, chain: list[Op]):
node._children.clear()

node_idx = self._body.index(node)
self._body = self.body[:node_idx] + chain + self.body[node_idx+1:]
self._body = self.body[:node_idx] + chain + self.body[node_idx + 1 :]

def init_op_group(self):
"""
Expand Down Expand Up @@ -358,6 +359,7 @@ def lower_to_top_level_ir(self):
False,
self.device,
verbose=self._verbose,
enable_profile=self._enable_profile,
enable_external_calls=self._enable_external_calls,
)
self._imported_module = fx_importer.import_graph()
Expand Down Expand Up @@ -474,6 +476,7 @@ def __init__(
do_param_pack: bool = False,
device: DeviceType = DeviceType.CPU,
verbose=False,
enable_profile=True,
enable_external_calls: bool = False,
):
"""
Expand All @@ -495,6 +498,7 @@ def __init__(
self._params = params
self._inputs = inputs
self._verbose = verbose
self._enable_profile = enable_profile
self._do_param_pack = do_param_pack
self._param_packs = []
self._num_input_visited = 0
Expand Down Expand Up @@ -571,7 +575,44 @@ def import_graph(self) -> ir.Module:
mlir.ir.Module: An MLIR module in high-level dialects.
"""
assert self._do_param_pack == False
node_names = [node.name for node in self._body]
with ir.InsertionPoint(self._module.body):
if self._enable_profile:
# Add clock func signature
f64 = ir.F64Type.get()
ptr = llvm.PointerType.get()
# func.func private @rtclock() -> f64
func.FuncOp(
name="rtclock",
type=ir.FunctionType.get(inputs=[], results=[f64]),
visibility="private",
)
# func.func private @record_timing(!llvm.ptr, f64)
func.FuncOp(
name="record_timing",
type=ir.FunctionType.get(inputs=[ptr, f64], results=[]),
visibility="private",
)

width = len(str(len(node_names)))
for i, name in enumerate(node_names):
sym_name = f"op_name_{i:0{width}d}_{name}"
content = sym_name + "\00"
content_type = ir.Type.parse(
f"!llvm.array<{len(content)} x i8>"
)
linkage_attr = ir.Attribute.parse("#llvm.linkage<private>")
llvm.GlobalOp(
content_type,
ir.StringAttr.get(sym_name),
linkage_attr,
constant=True,
value=ir.StringAttr.get(content),
addr_space=ir.IntegerAttr.get(
ir.IntegerType.get_signless(32), 0
), # keyword argument
)

arguments = []
inputs = self._params + self._inputs
for arg in inputs:
Expand All @@ -589,10 +630,21 @@ def import_graph(self) -> ir.Module:
@func.FuncOp.from_py_func(*arguments, name=self._func_name)
def generated_func(*args):
args_list = list(args)
func_op = self._module.body.operations[0]
for node in self._body:
func_op_index = (
len(node_names) + 2 if self._enable_profile else 0
)
func_op = self._module.body.operations[func_op_index]
for i, node in enumerate(self._body):
if node in extern_func:
continue

if self._enable_profile:
# %t_start = call @rtclock() : () -> f64
f64_type = ir.F64Type.get()
t_start = func.CallOp(
[f64_type], ir.FlatSymbolRefAttr.get("rtclock"), []
).results[0]

old_ops = [op for op in func_op.body.blocks[0].operations]
if isinstance(node, OutputOp):
output_node_args = node.args
Expand Down Expand Up @@ -625,6 +677,31 @@ def generated_func(*args):
print(op)
print("")

if self._enable_profile:
# %t_end = call @rtclock() : () -> f64
t_end = func.CallOp(
[f64_type], ir.FlatSymbolRefAttr.get("rtclock"), []
).results[0]

# %duration = arith.subf %t_end, %t_start : f64
duration = arith.SubFOp(t_end, t_start).result

# %name_ptr = llvm.mlir.addressof @op_name_<index>_<node name> : !llvm.ptr
ptr_type = llvm.PointerType.get()
name_ptr = llvm.AddressOfOp(
ptr_type,
ir.FlatSymbolRefAttr.get(
f"op_name_{i:0{width}d}_{node.name}"
),
).result

# call @record_timing(%name_ptr, %duration) : (!llvm.ptr, f64) -> ()
func.CallOp(
[],
ir.FlatSymbolRefAttr.get("record_timing"),
[name_ptr, duration],
)

return self._symbol_table.get(("output", 0))

# Generate external function declarations for CallExternalOp nodes (only if enabled)
Expand Down