[Bug] --max_seq_len ignored for sliding window models with --use_custom_kv_cache --use_custom_sdpa (cache capped at sliding_window size) #218

@rhn19

Description

Environment

  • optimum-executorch: main
  • Model: google/gemma-3-1b-it
  • Python: 3.12
  • Ubuntu 24.04

Setup

git clone https://github.com/pytorch/executorch.git
pushd executorch
git submodule update --init --recursive
bash install_requirements.sh
popd

git clone https://github.com/huggingface/optimum-executorch.git
pushd optimum-executorch
python install_dev.py --skip_override_torch
popd

pip install triton

Export (succeeds)

optimum-cli export executorch \
    --model "google/gemma-3-1b-it" \
    --task "text-generation" \
    --recipe "xnnpack" \
    --use_custom_sdpa \
    --use_custom_kv_cache \
    --qlinear 8da4w \
    --qembedding 8w \
    --max_seq_len 1024 \
    --dtype "float32" \
    --device "cpu" \
    --output_dir "gemma3_1b_export"

Reproduction (validate_export.py)

from optimum.executorch import ExecuTorchModelForCausalLM
from transformers import AutoTokenizer

model = ExecuTorchModelForCausalLM.from_pretrained("gemma3_1b_export")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

# Simulate a prompt exceeding 512 tokens
long_prompt = " ".join(["hello"] * 750)

generated_text = model.text_generation(
    tokenizer=tokenizer,
    prompt=long_prompt,
    max_seq_len=1024,
)
print(generated_text)

tokens = tokenizer.encode(long_prompt)
print("Number of tokens:", len(tokens))

python validate_export.py

Error

[tensor_impl.cpp:129] Attempted to resize a bounded tensor with a maximum capacity of 511 elements to 751 elements.
[method.cpp:1136] Error resizing tensor at input 0
Exception: Failed to execute method forward, error: 0x10.
MethodMeta(name='forward', num_inputs=2, input_tensor_meta=['TensorInfo(sizes=[1, 511], dtype=Long, ...)' 'TensorInfo(sizes=[511], dtype=Long, ...)'], num_outputs=1)
arg shapes: {'input_ids': torch.Size([1, 751]), 'cache_position': torch.Size([751])}
RuntimeError: Failed to execute method forward, error: 0x10

The export succeeds, and short prompts (< 512 tokens) work fine. The error only surfaces at runtime, once the prompt exceeds 512 tokens (the model's sliding_window size), despite --max_seq_len 1024.
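The arithmetic behind the cap can be checked directly. A minimal sketch (values taken from the error log and the gemma-3-1b-it config; pure arithmetic, not the exporter's actual code):

```python
# Hypothetical sanity check: relate the error log's numbers to the config.
max_seq_len = 1024     # requested via --max_seq_len
sliding_window = 512   # gemma-3-1b-it's sliding_window
prompt_tokens = 751    # 750 words plus BOS, as printed by the repro script

# The export bakes in min(max_seq_len, sliding_window) - 1 as the max capacity,
# which matches the "maximum capacity of 511 elements" from tensor_impl.cpp.
baked_capacity = min(max_seq_len, sliding_window) - 1
print(baked_capacity)                  # 511
print(prompt_tokens > baked_capacity)  # True -> resize fails at runtime
```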

Root cause

Two issues in CausalLMExportableModule in optimum/exporters/executorch/integrations.py:

_prepare_export_inputs: max_dim is computed as min(max_seq_len, sliding_window) - 1 = 511, so torch.export bakes a <= 511 guard on the sequence length dimension regardless of --max_seq_len.

export: TorchExportableModuleWithHybridCache.__init__ calls StaticCache, which internally creates StaticSlidingWindowLayer(max_cache_len=max_seq_len, sliding_window=512). That class sets effective_max_cache_len = min(sliding_window, max_cache_len) = 512, and its get_mask_sizes() returns kv_length = 512 as a constant during tracing, baking a second <= 512 guard into the exported graph.
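The second cap can be sketched in isolation (a simplified stand-in for the transformers cache layer; class and method names approximate the real ones, which also manage key/value tensors and masks):

```python
# Simplified sketch of the capping behavior described above; illustrative only.
class SlidingWindowLayerSketch:
    def __init__(self, max_cache_len, sliding_window):
        # The layer clamps its capacity to the sliding window, so any
        # max_cache_len above sliding_window is silently ignored.
        self.effective_max_cache_len = min(sliding_window, max_cache_len)

    def get_mask_sizes(self):
        # Returned as a constant during tracing, baking the bound into the graph.
        return self.effective_max_cache_len


layer = SlidingWindowLayerSketch(max_cache_len=1024, sliding_window=512)
print(layer.effective_max_cache_len)  # 512, not the requested 1024
print(layer.get_mask_sizes())         # 512
```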

Fix

See linked PR.
