Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions dynamic-input-notes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
目前已经成功将 prefill 的 input 设置为动态形状,同时修改部分前端算子生成的逻辑以适配该形状。(但需要注意的是,这些修改的正确性目前无法验证)

现在遇到的问题是,在编译`subgraph0_prefill.o`时,会有报错:

```text
[1/6] Building subgraph_prefill-f16.o
/home/cyanic/repos/buddy-mlir/build/examples/BuddyDeepSeekR1/subgraph0_prefill-f16.mlir:136:11: error: 'tosa.mul' op operands don't have matching ranks
%95 = "tosa.mul"(%arg3, %94) : (tensor<1536xf16>, tensor<1x?x1536xf16>) -> tensor<1x?x1536xf16>
^
/home/cyanic/repos/buddy-mlir/build/examples/BuddyDeepSeekR1/subgraph0_prefill-f16.mlir:136:11: note: see current operation: %95 = "tosa.mul"(%arg3, %94) : (tensor<1536xf16>, tensor<1x?x1536xf16>) -> tensor<1x?x1536xf16>
```

可以看到,动态形状已经从 forward_prefill 传导到其他阶段了。仔细看这个报错,可以发现其来自于 `tosa.mul` 的期望参数类型和实际传入参数类型的不一致。而产生这个不一致的原因,我觉得是四个 mlir 文件的*分离生成*,以及相关 tosa 算子没有正确适配动态形状。
17 changes: 14 additions & 3 deletions examples/BuddyDeepSeekR1/buddy-deepseek-r1-cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ constexpr float RopeTheta = 10000.0f;

using RopeFreqArray = std::array<float, HiddenSize / 2>;

// ANSI Color Codes
constexpr const char* COLOR_RESET = "\033[0m";
constexpr const char* COLOR_BLUE = "\033[34;1m";
constexpr const char* COLOR_YELLOW = "\033[33;1m";

struct MemRefContainer {
MemRef<float, 4> kv0;
MemRef<float, 4> kv1;
Expand Down Expand Up @@ -581,9 +586,10 @@ GenerationResult runGeneration(const std::string &prompt,
const std::chrono::duration<double, std::milli> prefillMs =
prefillEnd - prefillStart;
const double prefillSeconds = prefillMs.count() / 1000.0;
const size_t actualPrefillTokens = inputContainerPrefill.getTokenCnt();
if (prefillSeconds > 0.0) {
stats.prefillTokensPerSec =
static_cast<double>(MaxTokenLength) / prefillSeconds;
static_cast<double>(actualPrefillTokens) / prefillSeconds;
}

std::string streamed;
Expand All @@ -606,15 +612,20 @@ GenerationResult runGeneration(const std::string &prompt,

// Copy KV cache from prefill to decode.
getInfoStream() << "[Debug] Copying KV cache...\n";
const auto copyStart = std::chrono::high_resolution_clock::now();
copyKVByCachePositionBlock(prefillResult, decodeResult,
inputContainerPrefill.getTokenCnt());
const auto copyEnd = std::chrono::high_resolution_clock::now();
const std::chrono::duration<double, std::milli> copyMs =
copyEnd - copyStart;
const double copySeconds = copyMs.count() / 1000.0;
getInfoStream() << "[Debug] KV cache copy finished.\n";

cachePosition.getData()[0] = inputContainerPrefill.getTokenCnt();
inputContainerDecode.getData()[0] = static_cast<long long>(maxIndex);
if (maxIndex == eosTokenId) {
tokenStream << std::endl;
stats.totalSeconds = prefillSeconds;
stats.totalSeconds = prefillSeconds + copySeconds;
stats.finalText = streamed;
return stats;
}
Expand Down Expand Up @@ -725,7 +736,7 @@ GenerationResult runGeneration(const std::string &prompt,
tokenStream << std::endl;
stats.generatedTokens = outputContainer.getTokenCnt();
stats.finalText = streamed;
stats.totalSeconds = prefillSeconds + decodeSeconds;
stats.totalSeconds = prefillSeconds + copySeconds + decodeSeconds;
return stats;
}

Expand Down
42 changes: 32 additions & 10 deletions examples/BuddyDeepSeekR1/buddy-deepseek-r1-main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ constexpr size_t NUM_LAYERS = 56;
constexpr size_t HiddenSize = 128;
constexpr size_t HeadNum = 2;

// ANSI Color Codes
constexpr const char* COLOR_RESET = "\033[0m";
constexpr const char* COLOR_BLUE = "\033[34;1m";
constexpr const char* COLOR_GREEN = "\033[32;1m";
constexpr const char* COLOR_YELLOW = "\033[33;1m";

struct MemRefContainer {

MemRef<float, 4> kv0;
Expand Down Expand Up @@ -181,12 +187,13 @@ void getUserInput(std::string &inputStr) {
}

/// Print [Log] label in bold blue format.
void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; }
void printLogLabel() {
std::cout << COLOR_BLUE << "[Log] " << COLOR_RESET;
}

/// Print information for each iteration.
void printIterInfo(size_t iterIdx, std::string str, double time) {
total_time += time;
std::cout << "\033[32;1m[Iteration " << iterIdx << "] \033[0m";
std::cout << COLOR_GREEN << "[Iteration " << iterIdx << "] " << COLOR_RESET;
std::cout << "Token: " << str << " | "
<< "Time: " << time << "s" << std::endl;
}
Expand Down Expand Up @@ -269,7 +276,7 @@ void copy_kv_by_cache_position_block(const MemRefContainer &prefill,
int main() {
/// Print the title of this example.
const std::string title = "DeepSeekR1 Inference Powered by Buddy Compiler";
std::cout << "\033[33;1m" << title << "\033[0m" << std::endl;
std::cout << COLOR_YELLOW << title << COLOR_RESET << std::endl;

/// Define directories of vocabulary and parameter file.
std::string deepSeekR1Dir = DEEPSEEKR1_EXAMPLE_PATH;
Expand Down Expand Up @@ -393,9 +400,11 @@ int main() {
int maxIndex = findMaxIndex(startPtr, endPtr);
std::string tok = inputContainerPrefill.getStr(maxIndex);
printIterInfo(0, tok, inferenceTime.count() / 1000);
total_time += inferenceTime.count() / 1000.0;
const double prefillSeconds = inferenceTime.count() / 1000.0;
size_t actualPrefillTokens = inputContainerPrefill.getTokenCnt();
if (prefillSeconds > 0.0) {
prefillTokensPerSec = static_cast<double>(MaxTokenLength) / prefillSeconds;
prefillTokensPerSec = static_cast<double>(actualPrefillTokens) / prefillSeconds;
}
inputContainerDecode.getData()[0] = (long long)maxIndex;
outputContainer.appendTokenIdx(maxIndex);
Expand All @@ -411,8 +420,13 @@ int main() {

MemRefContainer *ptrDecodeResultContainer = &decodeResultContainer;

const auto copyStart = std::chrono::high_resolution_clock::now();
copy_kv_by_cache_position_block(prefillResultContainer, decodeResultContainer,
inputContainerPrefill.getTokenCnt());
const auto copyEnd = std::chrono::high_resolution_clock::now();
const std::chrono::duration<double, std::milli> copyTime =
copyEnd - copyStart;
total_time += copyTime.count() / 1000.0;

cachePosition.getData()[0] = inputContainerPrefill.getTokenCnt();
int generateLen = MaxTokenLength - inputContainerPrefill.getTokenCnt();
Expand Down Expand Up @@ -476,19 +490,27 @@ int main() {
cachePosition.getData()[0] += 1;
}

total_time += decodeTimeAccumMs / 1000.0;

const double decodeSeconds = decodeTimeAccumMs / 1000.0;
const double decodeTokensPerSec =
decodeSeconds > 0.0 ? static_cast<double>(decodeTokens) / decodeSeconds
: 0.0;

/// Print the final result
std::cout << "\n\033[33;1m[Total time]\033[0m " << total_time << std::endl;
std::cout << "\033[33;1m[Prefilling]\033[0m " << prefillTokensPerSec
const double prefillTime = prefillSeconds;
const double decodeTime = decodeTimeAccumMs / 1000.0;
const double copyTimeSec = copyTime.count() / 1000.0;
std::cout << "\n" << COLOR_YELLOW << "[Total time]" << COLOR_RESET << " " << total_time << "s" << std::endl;
std::cout << COLOR_YELLOW << "[Prefill time]" << COLOR_RESET << " " << prefillTime << "s" << std::endl;
std::cout << COLOR_YELLOW << "[KV cache copy time]" << COLOR_RESET << " " << copyTimeSec << "s" << std::endl;
std::cout << COLOR_YELLOW << "[Decode time]" << COLOR_RESET << " " << decodeTime << "s" << std::endl;
std::cout << COLOR_YELLOW << "[Prefilling]" << COLOR_RESET << " " << prefillTokensPerSec
<< " tokens/s" << std::endl;
std::cout << "\033[33;1m[Decoding]\033[0m " << decodeTokensPerSec
std::cout << COLOR_YELLOW << "[Decoding]" << COLOR_RESET << " " << decodeTokensPerSec
<< " tokens/s" << std::endl;
std::cout << "\033[33;1m[Input]\033[0m " << inputStr << std::endl;
std::cout << "\033[33;1m[Output]\033[0m "
std::cout << COLOR_YELLOW << "[Input]" << COLOR_RESET << " " << inputStr << std::endl;
std::cout << COLOR_YELLOW << "[Output]" << COLOR_RESET << " "
<< outputContainer.revertDeepSeekR1() << std::endl;

return 0;
Expand Down
26 changes: 19 additions & 7 deletions examples/BuddyDeepSeekR1/import-deepseek-r1.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@
choices=["f32", "f16", "bf16"],
help="Precision mode for generated MLIR and input data. Choose from 'f32', 'f16', or 'bf16'.",
)
parser.add_argument(
"--max-seq-len",
type=int,
default=128,
help="Maximum sequence length for f16/bf16 prefill. For f32, defaults to 1024.",
)
args = parser.parse_args()

# Ensure the output directory exists.
Expand Down Expand Up @@ -94,6 +100,7 @@
primary_registry=tosa.ops_registry,
aot_autograd_decomposition=inductor_decomp,
func_name="forward_prefill",
dynamic_dims={0: [1]}, # input_ids (input parameter, not cache) has dynamic sequence length
)

dynamo_compiler_decode = DynamoCompiler(
Expand All @@ -117,15 +124,19 @@
# Import the model into MLIR module and parameters.
with torch.no_grad():
if args.precision == "f16":
past_key_values_prefill = StaticCache(
config=model.config, max_cache_len=20
)
max_seq_len = args.max_seq_len
# Don't initialize cache for prefill since it's not used (commented out below)
# past_key_values_prefill = StaticCache(
# config=model.config, max_cache_len=max_seq_len
# )
past_key_values_decode = StaticCache(
config=model.config, max_cache_len=20
config=model.config, max_cache_len=max_seq_len
)

# Create dynamic shape tensor for prefill
# Using a sample tensor, but torch.export will trace with dynamic dimensions
data_prefill = {
"input_ids": torch.zeros((1, 20), dtype=torch.int64),
"input_ids": torch.zeros((1, max_seq_len), dtype=torch.int64),
}
data_decode = {
"input_ids": torch.zeros((1, 1), dtype=torch.int64),
Expand All @@ -136,9 +147,9 @@
graphs_prefill = dynamo_compiler_prefill.importer(
model,
input_ids=data_prefill["input_ids"],
use_cache=True,
# past_key_values=past_key_values_prefill,
use_cache=False, # Don't use cache for prefill
cache_implementation="static",
dynamic=True
)
# Initialize past_key_values once during the first forward call
model(
Expand All @@ -155,6 +166,7 @@
cache_position=cache_position,
past_key_values=past_key_values_decode,
cache_implementation="static",
dynamic=True
)
else:
past_key_values_prefill = StaticCache(
Expand Down
48 changes: 43 additions & 5 deletions frontend/Python/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
aot_autograd_decomposition: Optional[dict] = None,
verbose=False,
enable_external_calls: bool = False,
dynamic_dims: Optional[dict] = None,
) -> None:
"""
Initializes the Dynamo Compiler.
Expand All @@ -79,6 +80,8 @@ def __init__(
debugging purposes. The default value is False, indicating that
no extra debug information will be printed.
enable_external_calls (bool): Enable external function call support (for oneDNN, etc.)
dynamic_dims (dict, optional): Maps input indices to lists of dynamic dimension indices.
Example: {0: [1]} means input 0's dimension 1 is dynamic.
Attributes:
_func_name: The function name to be used.
_aot_autograd_decomposition (Optional[dict], optional):
Expand All @@ -100,6 +103,7 @@ def __init__(
self._aot_autograd_decomposition = aot_autograd_decomposition
self._verbose = verbose
self._enable_external_calls = enable_external_calls
self._dynamic_dims = dynamic_dims if dynamic_dims is not None else {}
self._imported_graphs = []
self._ops_registry = {}
self._imported_params = {}
Expand Down Expand Up @@ -267,7 +271,13 @@ def _create_node(
if node_kwargs is None:
node_kwargs = {}
buddy_node._keyword_arguments.update(node_kwargs)
buddy_node._tensor_meta["shape"] = node_output_shape

# Convert None to MLIR dynamic size marker for tensor metadata
converted_shape = [
ir.ShapedType.get_dynamic_size() if dim is None else dim
for dim in node_output_shape
]
buddy_node._tensor_meta["shape"] = converted_shape
buddy_node._tensor_meta["dtype"] = node_output_dtype
return buddy_node

Expand Down Expand Up @@ -318,12 +328,16 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
if self._model_config.decode_with_cache:
num_cached_kv = self._model_config.num_hidden_layers * 2
func_inputs = []
for i in inputs_pos:
for input_idx, i in enumerate(inputs_pos):
# for inp in _inputs[len(params_flat) :]:
inp = _inputs[i + num_cached_kv]
inp_shape = inp.shape
inp_shape = list(inp.shape)
# Apply dynamic dimensions if specified
if input_idx in self._dynamic_dims:
for dim_idx in self._dynamic_dims[input_idx]:
inp_shape[dim_idx] = None # Mark as dynamic
inp_dtype = self._torch_dtype_translate(str(inp.dtype))
func_inputs.append(TensorMeta(inp_shape, inp_dtype))
func_inputs.append(TensorMeta(tuple(inp_shape), inp_dtype))
for inp in _inputs[:num_cached_kv]:
inp = _inputs[i]
inp_shape = inp.shape
Expand Down Expand Up @@ -361,6 +375,21 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
input_nodes.extend(list(_gm.graph.nodes)[:num_cached_kv])
gm_nodes = param_nodes + buffers_nodes + input_nodes + other_nodes

# Create a mapping from node to input index for dynamic dim lookup
# Only include nodes that are actual user inputs (in inputs_pos), not parameters/buffers
node_to_input_idx = {}
for node in input_nodes:
# Find this node's position in inputs_pos
try:
node_idx = list(_gm.graph.nodes).index(node)
if node_idx in inputs_pos:
input_idx = inputs_pos.index(node_idx)
node_to_input_idx[node] = input_idx
except (ValueError, AttributeError):
# Node might not be in graph.nodes or nodes might not be indexable
# Fall back to using enumerate index for input_nodes
pass

for gm_node in gm_nodes:
node_users = []
for user in gm_node.users.keys():
Expand All @@ -369,12 +398,21 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]):
node_dtype = self._torch_dtype_translate(
str(gm_node.meta["tensor_meta"].dtype)
)

# Override shape with dynamic dimensions if specified
node_shape = list(gm_node.meta["tensor_meta"].shape)
if gm_node in node_to_input_idx:
input_idx = node_to_input_idx[gm_node]
if input_idx in self._dynamic_dims:
for dim_idx in self._dynamic_dims[input_idx]:
node_shape[dim_idx] = None # Mark as dynamic

buddy_node = self._create_node(
gm_node.op,
gm_node.name,
gm_node.args,
node_users,
gm_node.meta["tensor_meta"].shape,
node_shape,
node_dtype,
)

Expand Down
13 changes: 11 additions & 2 deletions frontend/Python/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,12 @@ def import_graph(self) -> ir.Module:
inputs = self._params + self._inputs
for arg in inputs:
shape_list = list(arg.shape)
dtype = arg.dtype
mlir_dtype = self._str_to_mlir_dtype(dtype)
# Convert None or -1 to MLIR dynamic dimension marker
shape_list = [
ir.ShapedType.get_dynamic_size() if (dim is None or dim == -1) else dim
for dim in shape_list
]
mlir_dtype = self._str_to_mlir_dtype(arg.dtype)
tensor_arg = ir.RankedTensorType.get(shape_list, mlir_dtype)
arguments.append(tensor_arg)
extern_func = []
Expand Down Expand Up @@ -614,6 +618,11 @@ def import_main_graph(self) -> ir.Module:
inputs = self._params + self._inputs
for arg in inputs:
shape_list = list(arg.shape)
# Convert None or -1 to MLIR dynamic dimension marker
shape_list = [
ir.ShapedType.get_dynamic_size() if (dim is None or dim == -1) else dim
for dim in shape_list
]
dtype = arg.dtype
mlir_dtype = self._str_to_mlir_dtype(dtype)
tensor_arg = ir.MemRefType.get(shape_list, mlir_dtype)
Expand Down
Loading