Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,907 changes: 0 additions & 2,907 deletions examples/BuddyLlama/op.txt

This file was deleted.

4,894 changes: 0 additions & 4,894 deletions examples/BuddyLlama/subgraph.mlir

This file was deleted.

2 changes: 1 addition & 1 deletion examples/LlamaTest/llama-import.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def forward(self, input_ids, attention_mask, position_ids):
graph = graphs[0]
params = dynamo_compiler.imported_params[graph]

driver = GraphDriver(graphs[0])
driver = GraphDriver(graphs[0], 2)
for i in range(len(driver.subgraphs)):
driver.subgraphs[i].lower_to_top_level_ir()
driver.construct_main_graph(True)
Expand Down
2 changes: 1 addition & 1 deletion examples/SplitLlama/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ add_custom_command(
COMMENT "Building subgraph193.o "
VERBATIM)

set(Boost_INCLUDE_DIR "/home/chenweiwei/boost_1_86_0/include")
# set(Boost_INCLUDE_DIR "/home/chenweiwei/boost_1_86_0/include")
add_library(SPLITLLAMA5 STATIC forward193.o subgraph193.o)

# 查找 Boost 库
Expand Down
2 changes: 1 addition & 1 deletion examples/SplitLlama/llama-import.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def forward(self, input_ids, attention_mask, position_ids):
graph = graphs[0]
params = dynamo_compiler.imported_params[graph]

driver = GraphDriver(graphs[0])
driver = GraphDriver(graphs[0], 2)
for i in range(len(driver.subgraphs)):
driver.subgraphs[i].lower_to_top_level_ir()

Expand Down
20 changes: 14 additions & 6 deletions frontend/Python/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class OutputDescriptor(ctypes.Structure):

return OutputDescriptor


# Graph类, 表示Buddy编译器前端的图级表达式。
class Graph:
"""
Graph is a graph-level expression for the Buddy Compiler frontends.
Expand Down Expand Up @@ -136,7 +136,8 @@ def __init__(
self._output_memref = None
self._output_descriptor = None
self.execution_engine = None
self.paral_group: Dict[str, List[int]] = {}
self.op_groups: Dict[str, List[Op]] = {}
self.group_map_device: Dict[str, DeviceType] = {}

@property
def body(self):
Expand Down Expand Up @@ -514,8 +515,7 @@ def addsymbol(self) -> None:
"""
for key, value in self._symbol_table.items():
print(f"Key: {key}, Value: {value}")



def import_graph(self) -> ir.Module:
"""
Imports buddy graph and generates an MLIR module in high-level dialects.
Expand All @@ -524,6 +524,7 @@ def import_graph(self) -> ir.Module:
mlir.ir.Module: An MLIR module in high-level dialects.
"""
assert self._do_param_pack == False
# 创建一个新Module, 根据计算图中的算子添加对应的MLIR操作
with ir.InsertionPoint(self._module.body):
arguments = []
inputs = self._params + self._inputs
Expand All @@ -538,7 +539,7 @@ def import_graph(self) -> ir.Module:
if isinstance(node, FuncOp):
extern_func.append(node)
self._import_op(node)

#
@func.FuncOp.from_py_func(*arguments, name=self._func_name)
def generated_func(*args):
args_list = list(args)
Expand Down Expand Up @@ -590,6 +591,7 @@ def import_main_graph(self) -> ir.Module:
Returns:
mlir.ir.Module: An MLIR module in high-level dialects.
"""
# 创建一个新Module, 根据计算图中的算子添加对应的MLIR操作
with ir.InsertionPoint(self._module.body):
arguments = []
if self._do_param_pack:
Expand All @@ -610,10 +612,13 @@ def import_main_graph(self) -> ir.Module:
extern_func.append(node)
self._import_op(node)

# 将下方的Python函数包装为MLIR中的函数操作(FuncOp)
@func.FuncOp.from_py_func(*arguments, name=self._func_name)
def generated_func(*args):
args_list = list(args)
# 遍历计算图的节点进行针对性处理
for node in self._body:
# 外部函数无需处理
if node in extern_func:
continue
if isinstance(node, OutputOp):
Expand All @@ -628,14 +633,15 @@ def generated_func(*args):
node.tensor_meta['shape'] = torch.Size(list(node._newshape))
self._import_placeholder(node, args_list)
elif isinstance(node, GetItemOp):
# print(self._symbol_table)
self._symbol_table[(str(node.name), 0)] = (
self._symbol_table[
(str(node.args[0]), node.args[1])
]
)
else:
self._import_op(node)

return self._symbol_table.get(("output", 0))

return self._module
Expand Down Expand Up @@ -693,9 +699,11 @@ def _import_op(self, node: Op):

"""
op_name = node.__class__.__name__
# 根据算子类型自动调用MLIR操作注册表中的对应函数生成MLIR操作
op_ret: ir.Operation | ir.Value | tuple | List | ir.OpResult = (
self._ops_registry[op_name](node, self._symbol_table)
)
# 根据返回值类型将MLIR操作结果添加到符号表中
if isinstance(op_ret, tuple | List):
for i, operation in enumerate(op_ret):
if isinstance(operation, ir.Operation) or isinstance(
Expand Down
Loading
Loading