Fix multimodal hidden-state preparation for Qwen3-VL models#526

Open
liusy58 wants to merge 2 commits into sgl-project:main from liusy58:fix_prepare_hidden_states

Conversation


@liusy58 liusy58 commented Apr 8, 2026

Motivation

This PR updates scripts/prepare_hidden_states.py to better support multimodal target models during hidden-state preparation.

Modifications

Related Issues

Accuracy Test

Benchmark & Profiling

Checklist

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces support for vision-language models (VLMs) by adding pixel_values and image_grid_thw to the DataPoint class and the data generation pipeline. It also makes the torch_dtype resolution logic more flexible. Feedback focuses on preventing KeyError and TypeError exceptions when these new fields are missing in non-VLM contexts, and on simplifying the conditional logic for determining the model's data type.
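As a hedged sketch of what the summary describes (the actual DataPoint definition in scripts/prepare_hidden_states.py may differ; the text fields here are illustrative), the new multimodal fields would plausibly be optional so text-only targets keep working:

```python
from dataclasses import dataclass
from typing import Any, List, Optional

@dataclass
class DataPoint:
    # Existing text field (name is illustrative, not taken from the script).
    input_ids: List[int]
    # New multimodal fields from the review; None for non-VLM targets.
    pixel_values: Optional[Any] = None
    image_grid_thw: Optional[Any] = None

text_only = DataPoint(input_ids=[1, 2, 3])
multimodal = DataPoint(input_ids=[1, 2], pixel_values="fake_tensor",
                       image_grid_thw=(1, 2, 2))
```

Defaulting both fields to None lets downstream code distinguish the two cases with a simple `is not None` check instead of try/except.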

Comment on lines +496 to +497
pixel_values = batch["pixel_values"][valid_indices_in_batch]
image_grid_thw = batch["image_grid_thw"][valid_indices_in_batch]

Severity: high

This implementation will raise a KeyError for non-VLM models because pixel_values and image_grid_thw are missing from the batch. Additionally, these tensors must be added to filtered_batch so they are correctly moved to the GPU and passed to the model's extend method during hidden state generation.

            pixel_values = batch.get("pixel_values")
            if pixel_values is not None:
                pixel_values = pixel_values[valid_indices_in_batch]
                filtered_batch["pixel_values"] = pixel_values
            image_grid_thw = batch.get("image_grid_thw")
            if image_grid_thw is not None:
                image_grid_thw = image_grid_thw[valid_indices_in_batch]
                filtered_batch["image_grid_thw"] = image_grid_thw
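The difference between direct indexing and dict.get() can be seen with a toy batch (plain lists stand in for tensors; the key names are the ones from the review):

```python
# Toy batches: a VLM batch carries image tensors, a text-only batch does not.
vlm_batch = {"input_ids": [[1, 2], [3, 4]],
             "pixel_values": ["img0", "img1"],
             "image_grid_thw": [(1, 2, 2), (1, 4, 4)]}
text_batch = {"input_ids": [[5, 6]]}

def filter_multimodal(batch, valid_indices_in_batch):
    """Copy optional multimodal fields into the filtered batch only when present."""
    filtered = {}
    pixel_values = batch.get("pixel_values")  # None for text-only batches
    if pixel_values is not None:
        filtered["pixel_values"] = [pixel_values[i] for i in valid_indices_in_batch]
    image_grid_thw = batch.get("image_grid_thw")
    if image_grid_thw is not None:
        filtered["image_grid_thw"] = [image_grid_thw[i] for i in valid_indices_in_batch]
    return filtered

out_vlm = filter_multimodal(vlm_batch, [0])    # both keys present
out_text = filter_multimodal(text_batch, [0])  # empty: no KeyError raised
```

With `batch["pixel_values"]` the text-only case would raise KeyError immediately; `.get()` returns None and the guard skips the field entirely.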

Comment on lines +552 to +553
pixel_values=pixel_values[i].clone(),
image_grid_thw=image_grid_thw[i].clone(),

Severity: high

Accessing pixel_values[i] and image_grid_thw[i] will cause a TypeError or AttributeError if the variables are None (which occurs for non-VLM models).

Suggested change:
-    pixel_values=pixel_values[i].clone(),
-    image_grid_thw=image_grid_thw[i].clone(),
+    pixel_values=pixel_values[i].clone() if pixel_values is not None else None,
+    image_grid_thw=image_grid_thw[i].clone() if image_grid_thw is not None else None,
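Why the guard matters can be shown with plain Python (no torch needed; indexing None mimics the non-VLM case where the variable was never populated):

```python
pixel_values = None  # what a text-only (non-VLM) batch yields
i = 0

# Unguarded indexing on None raises TypeError:
try:
    pixel_values[i]
    raised = False
except TypeError:
    raised = True

# The guarded conditional expression degrades gracefully to None instead:
safe = pixel_values[i] if pixel_values is not None else None
```

The same conditional-expression pattern applies to `.clone()` in the suggestion: the attribute access is only reached when the tensor actually exists.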

Comment on lines +192 to +197
if hasattr(model_config, "dtype") and model_config.dtype is not None:
torch_dtype = model_config.dtype
elif hasattr(model_config, "text_config") and hasattr(model_config.text_config, "dtype"):
torch_dtype = model_config.text_config.dtype
else:
torch_dtype = getattr(model_config, "torch_dtype", "bfloat16")

Severity: medium

The logic for determining torch_dtype can be simplified and made more robust by using getattr with defaults and explicitly checking for None values in nested configurations.

Suggested change:
-    if hasattr(model_config, "dtype") and model_config.dtype is not None:
-        torch_dtype = model_config.dtype
-    elif hasattr(model_config, "text_config") and hasattr(model_config.text_config, "dtype"):
-        torch_dtype = model_config.text_config.dtype
-    else:
-        torch_dtype = getattr(model_config, "torch_dtype", "bfloat16")
+    torch_dtype = getattr(model_config, "dtype", None)
+    if torch_dtype is None and hasattr(model_config, "text_config"):
+        torch_dtype = getattr(model_config.text_config, "dtype", None)
+    if torch_dtype is None:
+        torch_dtype = getattr(model_config, "torch_dtype", "bfloat16")
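A minimal, runnable sketch of the suggested getattr chain (attribute names are taken from the review; SimpleNamespace stands in for a real Hugging Face config object):

```python
from types import SimpleNamespace

def resolve_dtype(model_config, default="bfloat16"):
    """Fall back in order: top-level dtype, nested text_config.dtype,
    legacy torch_dtype attribute, then a hard default."""
    torch_dtype = getattr(model_config, "dtype", None)
    if torch_dtype is None and hasattr(model_config, "text_config"):
        torch_dtype = getattr(model_config.text_config, "dtype", None)
    if torch_dtype is None:
        torch_dtype = getattr(model_config, "torch_dtype", default)
    return torch_dtype

# VLM-style config: dtype lives on the nested text_config.
vlm_cfg = SimpleNamespace(dtype=None, text_config=SimpleNamespace(dtype="float16"))
# Plain config exposing only the legacy torch_dtype attribute.
legacy_cfg = SimpleNamespace(torch_dtype="float32")
```

Unlike the original hasattr/elif chain, this version also handles a text_config whose dtype attribute exists but is None, falling through to the later branches instead of returning None.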

@liusy58 changed the title from "fix" to "Fix multimodal hidden-state preparation for Qwen3-VL models" on Apr 8, 2026