Fix VLM preprocessing and add mRoPE position handling in target head#527

Open
liusy58 wants to merge 2 commits intosgl-project:mainfrom
liusy58:fix_train_vlm

Conversation


@liusy58 liusy58 commented Apr 8, 2026

Motivation

This PR improves VLM training support in three areas:

  • preserve dataset-provided system prompts during VLM conversation preprocessing
  • carry image tensors through data processing for downstream model usage
  • add multimodal RoPE position index generation in the target head

Modifications

Related Issues

Accuracy Test

Benchmark & Profiling

Checklist

@liusy58 liusy58 changed the title from "fix" to "Fix VLM preprocessing and add mRoPE position handling in target head" Apr 8, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates VLM conversation preprocessing to support custom system prompts and ensures vision-related data fields are preserved during processing. It also introduces a get_rope_index method for calculating multimodal rotary position embeddings. Feedback includes addressing a potential IndexError in conversation handling, optimizing token search efficiency, and refactoring duplicated logic in the position ID calculation.
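To make the position-embedding change concrete, here is a minimal, hypothetical sketch of the idea behind mRoPE position ids for a text-only span, mirroring the `torch.arange(...).view(1, -1).expand(3, -1)` pattern from the PR excerpt below: text tokens share the same index along all three rotary axes (temporal, height, width), so a 1-D range is replicated into a `(3, text_len)` tensor. The function name and signature here are illustrative, not the PR's actual API.

```python
import torch

def text_position_ids(text_len: int, st_idx: int = 0) -> torch.Tensor:
    # Text tokens advance uniformly on all three mRoPE axes, so one
    # arange is expanded to three identical rows, offset by st_idx.
    return torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx

pos = text_position_ids(5, st_idx=10)
# pos has shape (3, 5); every row is [10, 11, 12, 13, 14]
```

Image and video spans would instead get distinct row values per axis, which is what the `get_rope_index` method in this PR computes.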

for i, image in enumerate(examples["image"]):
    source = examples["conversations"][i]
    messages = [{"role": "system", "content": system_prompt}]
    if source[0]["role"] == "system":
Contributor


critical

There's a potential IndexError here. If source (which is examples["conversations"][i]) is an empty list, accessing source[0] will raise an exception. The check for an empty source on line 224 happens after this access.

To prevent this, you should add a check to ensure source is not empty before accessing its first element.

Suggested change
if source[0]["role"] == "system":
if source and source[0]["role"] == "system":
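The guard above can be demonstrated in isolation. This is a standalone illustration of the reviewer's point, with a hypothetical helper name; an empty conversation list must short-circuit before its first element is indexed:

```python
def starts_with_system(source) -> bool:
    # `source and ...` short-circuits on an empty list, so source[0]
    # is only evaluated when at least one turn exists, avoiding IndexError.
    return bool(source) and source[0]["role"] == "system"

print(starts_with_system([]))                                   # False, no IndexError
print(starts_with_system([{"role": "system", "content": "x"}])) # True
print(starts_with_system([{"role": "user", "content": "x"}]))   # False
```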

Comment on lines +148 to +149
ed_image = input_tokens.index(image_token_id, st) if image_token_id in input_tokens[st:] and remain_images > 0 else len(input_tokens) + 1
ed_video = input_tokens.index(video_token_id, st) if video_token_id in input_tokens[st:] and remain_videos > 0 else len(input_tokens) + 1
Contributor


medium

The current implementation for finding ed_image and ed_video uses an in check followed by .index(), which results in scanning the input_tokens list twice for each token type in each iteration. This can be inefficient for long sequences.

A more efficient and Pythonic approach is to use a try-except block to handle cases where a token is not found. This avoids the redundant scan.

Suggested change
ed_image = input_tokens.index(image_token_id, st) if image_token_id in input_tokens[st:] and remain_images > 0 else len(input_tokens) + 1
ed_video = input_tokens.index(video_token_id, st) if video_token_id in input_tokens[st:] and remain_videos > 0 else len(input_tokens) + 1
ed_image = len(input_tokens) + 1
if remain_images > 0:
    try:
        ed_image = input_tokens.index(image_token_id, st)
    except ValueError:
        pass
ed_video = len(input_tokens) + 1
if remain_videos > 0:
    try:
        ed_video = input_tokens.index(video_token_id, st)
    except ValueError:
        pass

Comment on lines +175 to +178
if st < len(input_tokens):
    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
    text_len = len(input_tokens) - st
    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
Contributor


medium

The logic in this block to handle trailing text is very similar to the logic for handling text between vision tokens (lines 163-166). Specifically, calculating st_idx and appending text position IDs is duplicated.

To improve maintainability and reduce redundancy, consider refactoring this repeated logic. You could potentially merge this trailing text handling into the main loop or use a helper function to encapsulate the logic for appending text positions.
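One possible shape for that refactor is a small helper that owns the "continue from the previous max position" logic; both the between-vision-tokens case and the trailing-text case would call it. The helper name is ours, but the `(3, text_len)` layout and `st_idx` computation mirror the PR excerpt above:

```python
import torch

def append_text_positions(llm_pos_ids_list, text_len: int) -> None:
    # Start one past the largest position id emitted so far (0 if none),
    # then append a (3, text_len) block of uniformly advancing text positions.
    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
    llm_pos_ids_list.append(
        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
    )

positions = []
append_text_positions(positions, 4)  # leading text span: positions 0..3
append_text_positions(positions, 3)  # trailing text span continues at 4..6
```

With this in place, the trailing-text branch reduces to a single call guarded by `if st < len(input_tokens):`, and any future change to the position layout happens in one spot.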
