Fix multimodal hidden-state preparation for Qwen3-VL models #526
liusy58 wants to merge 2 commits into sgl-project:main
Conversation
Code Review
This pull request introduces support for vision-language models (VLM) by adding pixel_values and image_grid_thw to the DataPoint class and the data generation pipeline. It also updates the torch_dtype resolution logic to be more flexible. Feedback focuses on preventing KeyError and TypeError exceptions when these new fields are missing in non-VLM contexts, as well as simplifying the conditional logic for determining the model's data type.
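To make the summary concrete, the shape of the change to `DataPoint` can be sketched as an optional-field extension. This is an illustrative sketch, not the script's actual class: the existing field names (`input_ids`, `hidden_states`) are assumptions, while `pixel_values` and `image_grid_thw` are the fields named in the PR.

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class DataPoint:
    # Pre-existing text-only fields (names here are illustrative).
    input_ids: Any
    hidden_states: Any
    # New multimodal fields; stay None for non-VLM models so
    # downstream code can branch on their presence.
    pixel_values: Optional[Any] = None
    image_grid_thw: Optional[Any] = None

# A text-only data point simply omits the multimodal fields.
dp = DataPoint(input_ids=[1, 2, 3], hidden_states=[[0.1, 0.2]])
print(dp.pixel_values is None)  # True
```

Defaulting the new fields to `None` keeps every existing non-VLM call site valid without changes.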
scripts/prepare_hidden_states.py (Outdated)
```python
pixel_values=batch["pixel_values"][valid_indices_in_batch]
image_grid_thw=batch["image_grid_thw"][valid_indices_in_batch]
```
This implementation will raise a KeyError for non-VLM models because pixel_values and image_grid_thw are missing from the batch. Additionally, these tensors must be added to filtered_batch so they are correctly moved to the GPU and passed to the model's extend method during hidden state generation.
```python
pixel_values = batch.get("pixel_values")
if pixel_values is not None:
    pixel_values = pixel_values[valid_indices_in_batch]
    filtered_batch["pixel_values"] = pixel_values

image_grid_thw = batch.get("image_grid_thw")
if image_grid_thw is not None:
    image_grid_thw = image_grid_thw[valid_indices_in_batch]
    filtered_batch["image_grid_thw"] = image_grid_thw
```

```python
pixel_values=pixel_values[i].clone(),
image_grid_thw=image_grid_thw[i].clone(),
```
Accessing pixel_values[i] and image_grid_thw[i] will cause a TypeError or AttributeError if the variables are None (which occurs for non-VLM models).
Suggested change:
```diff
-pixel_values=pixel_values[i].clone(),
-image_grid_thw=image_grid_thw[i].clone(),
+pixel_values=pixel_values[i].clone() if pixel_values is not None else None,
+image_grid_thw=image_grid_thw[i].clone() if image_grid_thw is not None else None,
```
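The None-guard pattern in the suggestion can be checked without torch. In this sketch, plain lists stand in for tensors and `list(...)` stands in for `.clone()`; the helper name `build_datapoint_fields` is hypothetical and not part of the script.

```python
def build_datapoint_fields(pixel_values, image_grid_thw, i):
    """Index and copy the multimodal tensors only when they are present.

    For non-VLM batches both inputs are None, and both output fields
    stay None instead of raising TypeError on indexing.
    """
    return {
        "pixel_values": list(pixel_values[i]) if pixel_values is not None else None,
        "image_grid_thw": list(image_grid_thw[i]) if image_grid_thw is not None else None,
    }

# VLM batch: multimodal inputs present.
print(build_datapoint_fields([[1, 2]], [[3, 4]], 0))
# Non-VLM batch: both fields absent, no exception raised.
print(build_datapoint_fields(None, None, 0))
```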
```python
if hasattr(model_config, "dtype") and model_config.dtype is not None:
    torch_dtype = model_config.dtype
elif hasattr(model_config, "text_config") and hasattr(model_config.text_config, "dtype"):
    torch_dtype = model_config.text_config.dtype
else:
    torch_dtype = getattr(model_config, "torch_dtype", "bfloat16")
```
The logic for determining torch_dtype can be simplified and made more robust by using getattr with defaults and explicitly checking for None values in nested configurations.
Suggested change:
```diff
-if hasattr(model_config, "dtype") and model_config.dtype is not None:
-    torch_dtype = model_config.dtype
-elif hasattr(model_config, "text_config") and hasattr(model_config.text_config, "dtype"):
-    torch_dtype = model_config.text_config.dtype
-else:
-    torch_dtype = getattr(model_config, "torch_dtype", "bfloat16")
+torch_dtype = getattr(model_config, "dtype", None)
+if torch_dtype is None and hasattr(model_config, "text_config"):
+    torch_dtype = getattr(model_config.text_config, "dtype", None)
+if torch_dtype is None:
+    torch_dtype = getattr(model_config, "torch_dtype", "bfloat16")
```
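The suggested resolution order (top-level `dtype`, then `text_config.dtype`, then the legacy `torch_dtype`, defaulting to `"bfloat16"`) can be exercised with mock configs. The wrapper function name `resolve_torch_dtype` is an assumption for illustration; its body mirrors the suggestion above.

```python
from types import SimpleNamespace

def resolve_torch_dtype(model_config):
    # 1) Prefer an explicit top-level dtype.
    torch_dtype = getattr(model_config, "dtype", None)
    # 2) Fall back to the nested text_config (common for VLM configs).
    if torch_dtype is None and hasattr(model_config, "text_config"):
        torch_dtype = getattr(model_config.text_config, "dtype", None)
    # 3) Finally, the legacy torch_dtype attribute, defaulting to bfloat16.
    if torch_dtype is None:
        torch_dtype = getattr(model_config, "torch_dtype", "bfloat16")
    return torch_dtype

# VLM-style config: dtype lives on the nested text_config.
vlm = SimpleNamespace(dtype=None, text_config=SimpleNamespace(dtype="float16"))
print(resolve_torch_dtype(vlm))  # float16

# Bare config with no dtype fields at all falls back to the default.
print(resolve_torch_dtype(SimpleNamespace()))  # bfloat16
```

Note that a `dtype` attribute set to `None` correctly falls through to the nested config, which the original `hasattr`-chain only handled for the top level.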
Motivation
This PR updates scripts/prepare_hidden_states.py to better support multimodal target models during hidden-state preparation.
Modifications
Related Issues
Accuracy Test
Benchmark & Profiling
Checklist