Fix VLM preprocessing and add mRoPE position handling in target head #527
liusy58 wants to merge 2 commits into sgl-project:main from
Conversation
Code Review
This pull request updates VLM conversation preprocessing to support custom system prompts and ensures vision-related data fields are preserved during processing. It also introduces a get_rope_index method for calculating multimodal rotary position embeddings. Feedback includes addressing a potential IndexError in conversation handling, optimizing token search efficiency, and refactoring duplicated logic in the position ID calculation.
```python
for i, image in enumerate(examples["image"]):
    source = examples["conversations"][i]
    messages = [{"role": "system", "content": system_prompt}]
    if source[0]["role"] == "system":
```
There's a potential IndexError here. If source (which is examples["conversations"][i]) is an empty list, accessing source[0] will raise an IndexError. The check for an empty source on line 224 happens after this access.
To prevent this, add a check that source is non-empty before accessing its first element.
```diff
-if source[0]["role"] == "system":
+if source and source[0]["role"] == "system":
```
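For illustration, the short-circuit guard never touches `source[0]` on an empty conversation; a minimal sketch (the helper name and data shapes are hypothetical, mirroring the `{"role": ..., "content": ...}` turns above):

```python
# Minimal check of the short-circuited empty-list guard.
def starts_with_system(source):
    # `source and ...` short-circuits, so source[0] is never evaluated
    # when the conversation list is empty -- no IndexError.
    return bool(source) and source[0]["role"] == "system"

print(starts_with_system([]))                                     # False, no IndexError
print(starts_with_system([{"role": "system", "content": "hi"}]))  # True
```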
```python
ed_image = input_tokens.index(image_token_id, st) if image_token_id in input_tokens[st:] and remain_images > 0 else len(input_tokens) + 1
ed_video = input_tokens.index(video_token_id, st) if video_token_id in input_tokens[st:] and remain_videos > 0 else len(input_tokens) + 1
```
The current implementation for finding ed_image and ed_video uses an in check followed by .index(), which scans the input_tokens list twice per token type in each iteration; the input_tokens[st:] slice also allocates a new list each time. This can be inefficient for long sequences.
A more efficient and Pythonic approach is to use a try-except block to handle cases where a token is not found. This avoids the redundant scan and the slice copy.
```diff
-ed_image = input_tokens.index(image_token_id, st) if image_token_id in input_tokens[st:] and remain_images > 0 else len(input_tokens) + 1
-ed_video = input_tokens.index(video_token_id, st) if video_token_id in input_tokens[st:] and remain_videos > 0 else len(input_tokens) + 1
+ed_image = len(input_tokens) + 1
+if remain_images > 0:
+    try:
+        ed_image = input_tokens.index(image_token_id, st)
+    except ValueError:
+        pass
+ed_video = len(input_tokens) + 1
+if remain_videos > 0:
+    try:
+        ed_video = input_tokens.index(video_token_id, st)
+    except ValueError:
+        pass
```
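The single-scan pattern can be sketched standalone; `find_next` is a hypothetical helper name, and the `len(tokens) + 1` sentinel matches the "not found" value used above:

```python
def find_next(tokens, token_id, start):
    """Return the index of token_id at or after `start`, or len(tokens) + 1
    if absent (the sentinel used above).

    list.index accepts a start offset, so this is a single scan with no
    tokens[start:] slice copy; `in` + `.index` would scan twice.
    """
    try:
        return tokens.index(token_id, start)
    except ValueError:
        return len(tokens) + 1

tokens = [1, 2, 7, 2, 7]
print(find_next(tokens, 7, 0))  # 2
print(find_next(tokens, 7, 3))  # 4
print(find_next(tokens, 9, 0))  # 6 (len + 1 sentinel: token absent)
```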
```python
if st < len(input_tokens):
    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
    text_len = len(input_tokens) - st
    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
```
The logic in this block to handle trailing text is very similar to the logic for handling text between vision tokens (lines 163-166). Specifically, calculating st_idx and appending text position IDs is duplicated.
To improve maintainability and reduce redundancy, consider refactoring this repeated logic. You could potentially merge this trailing text handling into the main loop or use a helper function to encapsulate the logic for appending text positions.
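One possible shape for such a helper, as a sketch only (the name `append_text_positions` is hypothetical; it assumes the 3-row mRoPE layout, where a text token shares one index across the temporal/height/width axes):

```python
import torch

def append_text_positions(llm_pos_ids_list, text_len):
    """Append 3D position ids for a plain-text span of `text_len` tokens.

    Hypothetical helper: text tokens use the same index on all three mRoPE
    axes, continuing from the largest position emitted so far.
    """
    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
    llm_pos_ids_list.append(
        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
    )

pos_list = []
append_text_positions(pos_list, 4)  # positions 0..3 on all 3 axes
append_text_positions(pos_list, 2)  # positions 4..5 continue the sequence
print(pos_list[1])                  # a (3, 2) tensor, rows identical
```

Both the in-loop text segments and the trailing-text block could then call this helper instead of repeating the `st_idx` computation.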
Motivation
This PR improves VLM training support in two areas:
Modifications
Related Issues
Accuracy Test
Benchmark & Profiling
Checklist