From d839c2c1d0df46da9df400cdbf9481240c55bfa7 Mon Sep 17 00:00:00 2001 From: Mingzhu Yan Date: Tue, 13 Jan 2026 09:58:17 +0000 Subject: [PATCH] [frontend]: automatically inject profiling instructions in DynamoCompiler This commit introduces the `enable_profile` option to automatically measure the execution time of each graph node. When enabled, the DynamoCompiler performs the following instrumentation steps after `lower_to_top_level_ir`: - Injects declarations of `rtclock` (timestamp acquisition) and `record_timing` (data recording). - Injects a global string for each node in the format "op_name_{node_index}_{node_name}". - Injects timing probes around each node to calculate and record the elapsed execution time. --- frontend/Python/frontend.py | 3 ++ frontend/Python/graph/graph.py | 87 ++++++++++++++++++++++++++++++++-- 2 files changed, 85 insertions(+), 5 deletions(-) diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index c8e5ba6958..70f9972c25 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -65,6 +65,7 @@ def __init__( primary_registry: Optional[dict] = None, aot_autograd_decomposition: Optional[dict] = None, verbose=False, + enable_profile=False, enable_external_calls: bool = False, ) -> None: """ @@ -99,6 +100,7 @@ def __init__( self._func_name = func_name self._aot_autograd_decomposition = aot_autograd_decomposition self._verbose = verbose + self._enable_profile = enable_profile self._enable_external_calls = enable_external_calls self._imported_graphs = [] self._ops_registry = {} @@ -537,6 +539,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): self._func_name, DeviceType.CPU, self._verbose, + self._enable_profile, self._enable_external_calls, ) graph._params_ref = params_flat diff --git a/frontend/Python/graph/graph.py b/frontend/Python/graph/graph.py index f9f2021236..92ddae7232 100644 --- a/frontend/Python/graph/graph.py +++ b/frontend/Python/graph/graph.py @@ -29,6 +29,7 @@ from mlir.passmanager import * from 
mlir.execution_engine import * from mlir import runtime as rt +from mlir.dialects import arith, llvm from .operation import * from .type import * @@ -107,6 +108,7 @@ def __init__( func_name: str, device: DeviceType = DeviceType.CPU, verbose=False, + enable_profile: bool = False, enable_external_calls: bool = False, ) -> None: """ @@ -132,6 +134,7 @@ def __init__( self._imported_module = None self._params_ref = None self._verbose = verbose + self._enable_profile = enable_profile self._ops_registry = ops_registry self._func_name = func_name self._ctx = ir.Context() @@ -242,7 +245,6 @@ def displace_node(self, node: Op, newnode: Op): self.node_table.pop(node.name) self.node_table[newnode.name] = newnode - def displace_node_with_chain(self, node: Op, chain: list[Op]): """ Replaces an existing node with a chain of new nodes. @@ -250,7 +252,6 @@ def displace_node_with_chain(self, node: Op, chain: list[Op]): current node will have this node as their child instead of `node` - The last node is taken to be the "tail" of the chain, and all children of `node` will have this node as their parent instead. - Args: node (Op): The operation to be replaced. 
chain (list[Op]): The a list of nodes to be inserted instead of Op @@ -281,7 +282,7 @@ def displace_node_with_chain(self, node: Op, chain: list[Op]): node._children.clear() node_idx = self._body.index(node) - self._body = self.body[:node_idx] + chain + self.body[node_idx+1:] + self._body = self.body[:node_idx] + chain + self.body[node_idx + 1 :] def init_op_group(self): """ @@ -358,6 +359,7 @@ def lower_to_top_level_ir(self): False, self.device, verbose=self._verbose, + enable_profile=self._enable_profile, enable_external_calls=self._enable_external_calls, ) self._imported_module = fx_importer.import_graph() @@ -474,6 +476,7 @@ def __init__( do_param_pack: bool = False, device: DeviceType = DeviceType.CPU, verbose=False, + enable_profile=True, enable_external_calls: bool = False, ): """ @@ -495,6 +498,7 @@ def __init__( self._params = params self._inputs = inputs self._verbose = verbose + self._enable_profile = enable_profile self._do_param_pack = do_param_pack self._param_packs = [] self._num_input_visited = 0 @@ -571,7 +575,44 @@ def import_graph(self) -> ir.Module: mlir.ir.Module: An MLIR module in high-level dialects. 
""" assert self._do_param_pack == False + node_names = [node.name for node in self._body] with ir.InsertionPoint(self._module.body): + if self._enable_profile: + # Add clock func signature + f64 = ir.F64Type.get() + ptr = llvm.PointerType.get() + # func.func private @rtclock() -> f64 + func.FuncOp( + name="rtclock", + type=ir.FunctionType.get(inputs=[], results=[f64]), + visibility="private", + ) + # func.func private @record_timing(!llvm.ptr, f64) + func.FuncOp( + name="record_timing", + type=ir.FunctionType.get(inputs=[ptr, f64], results=[]), + visibility="private", + ) + + width = len(str(len(node_names))) + for i, name in enumerate(node_names): + sym_name = f"op_name_{i:0{width}d}_{name}" + content = sym_name + "\00" + content_type = ir.Type.parse( + f"!llvm.array<{len(content)} x i8>" + ) + linkage_attr = ir.Attribute.parse("#llvm.linkage") + llvm.GlobalOp( + content_type, + ir.StringAttr.get(sym_name), + linkage_attr, + constant=True, + value=ir.StringAttr.get(content), + addr_space=ir.IntegerAttr.get( + ir.IntegerType.get_signless(32), 0 + ), # 关键字参数 + ) + arguments = [] inputs = self._params + self._inputs for arg in inputs: @@ -589,10 +630,21 @@ def import_graph(self) -> ir.Module: @func.FuncOp.from_py_func(*arguments, name=self._func_name) def generated_func(*args): args_list = list(args) - func_op = self._module.body.operations[0] - for node in self._body: + func_op_index = ( + len(node_names) + 2 if self._enable_profile else 0 + ) + func_op = self._module.body.operations[func_op_index] + for i, node in enumerate(self._body): if node in extern_func: continue + + if self._enable_profile: + # %t_start = call @rtclock() : () -> f64 + f64_type = ir.F64Type.get() + t_start = func.CallOp( + [f64_type], ir.FlatSymbolRefAttr.get("rtclock"), [] + ).results[0] + old_ops = [op for op in func_op.body.blocks[0].operations] if isinstance(node, OutputOp): output_node_args = node.args @@ -625,6 +677,31 @@ def generated_func(*args): print(op) print("") + if 
self._enable_profile: + # %t_end = call @rtclock() : () -> f64 + t_end = func.CallOp( + [f64_type], ir.FlatSymbolRefAttr.get("rtclock"), [] + ).results[0] + + # %duration = arith.subf %t_end, %t_start : f64 + duration = arith.SubFOp(t_end, t_start).result + + # %name_ptr = llvm.mlir.addressof @op_name_post_attn_layernorm : !llvm.ptr + ptr_type = llvm.PointerType.get() + name_ptr = llvm.AddressOfOp( + ptr_type, + ir.FlatSymbolRefAttr.get( + f"op_name_{i:0{width}d}_{node.name}" + ), + ).result + + # call @record_timing(%name_ptr, %duration) : (!llvm.ptr, f64) -> () + func.CallOp( + [], + ir.FlatSymbolRefAttr.get("record_timing"), + [name_ptr, duration], + ) + return self._symbol_table.get(("output", 0)) # Generate external function declarations for CallExternalOp nodes (only if enabled)