Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion specforge/data/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,11 @@ def preprocess_vlm_conversations(
# Note: currently, we assume that each example has only one image
for i, image in enumerate(examples["image"]):
source = examples["conversations"][i]
messages = [{"role": "system", "content": system_prompt}]
if source[0]["role"] == "system":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There's a potential IndexError here. If source (which is examples["conversations"][i]) is an empty list, accessing source[0] will raise an exception. The check for an empty source on line 224 happens after this access.

To prevent this, you should add a check to ensure source is not empty before accessing its first element.

Suggested change
if source[0]["role"] == "system":
if source and source[0]["role"] == "system":

messages = [{"role": "system", "content": source[0]["content"]}]
source = source[1:]
else:
messages = [{"role": "system", "content": system_prompt}]
if not source:
# if the source is None, skip it
continue
Expand Down Expand Up @@ -533,6 +537,12 @@ def process_data(data, max_len, transform=None):
new_data["target"] = target
new_data["hidden_state"] = hidden_state
new_data["input_ids"] = input_ids

if "pixel_values" in data:
new_data["pixel_values"] = data["pixel_values"]
if "image_grid_thw" in data:
new_data["image_grid_thw"] = data["image_grid_thw"]

if transform:
new_data = transform(new_data)
return new_data
Expand Down
166 changes: 166 additions & 0 deletions specforge/modeling/target/target_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,169 @@ def preprocess(self, input_ids, target, loss_mask):
input_ids = padding(input_ids, left=False)
loss_mask = loss_mask[..., None]
return input_ids, target, loss_mask

def get_rope_index(
    self,
    input_ids: torch.LongTensor,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute multimodal rotary (M-RoPE) position ids for a mixed text/vision sequence.

    Builds a 3-row position-id tensor (temporal, height, width axes) per batch
    element: plain text tokens get identical, monotonically increasing positions
    on all three axes, while image/video patch tokens get positions derived from
    their (t, h, w) grid, downscaled spatially by ``spatial_merge_size``.

    Args:
        input_ids: ``(batch, seq_len)`` token ids.
        image_grid_thw: ``(num_images, 3)`` per-image (t, h, w) patch-grid sizes,
            or ``None`` when the batch has no images.
        video_grid_thw: ``(num_videos, 3)`` per-video (t, h, w) patch-grid sizes,
            or ``None`` when the batch has no videos.
        attention_mask: ``(batch, seq_len)`` mask with 1 for real tokens; if
            ``None`` (vision branch only) all tokens are treated as real.

    Returns:
        A ``(position_ids, mrope_deltas)`` pair where ``position_ids`` is
        ``(3, batch, seq_len)`` and ``mrope_deltas`` is ``(batch, 1)`` — the
        offset between the largest assigned position and the sequence length,
        used by callers to continue positions during generation.

    NOTE(review): mirrors the Qwen2-VL-style ``get_rope_index`` — assumes each
    vision span is introduced by ``vision_start_token_id`` immediately followed
    by an image/video token; confirm against the tokenizer's template.
    """

    # Spatial merge factor: prefer the vision sub-config; fall back to a
    # top-level attribute, defaulting to 2 when neither is present.
    if hasattr(self.config, "vision_config"):
        spatial_merge_size = self.config.vision_config.spatial_merge_size
    else:
        spatial_merge_size = getattr(self.config, "spatial_merge_size", 2)

    image_token_id = self.config.image_token_id
    video_token_id = self.config.video_token_id
    vision_start_token_id = self.config.vision_start_token_id

    # Expand each video's grid into one row per temporal frame (t copies with
    # t = 1), so the per-span loop below can treat every row uniformly.
    if video_grid_thw is not None:
        video_grid_thw = torch.repeat_interleave(
            video_grid_thw, video_grid_thw[:, 0], dim=0
        )
        video_grid_thw[:, 0] = 1

    if input_ids is not None and (
        image_grid_thw is not None or video_grid_thw is not None
    ):
        # --- Vision branch: at least one image or video in the batch. ---
        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)

        # Initialized to ones; padded positions keep this filler value.
        position_ids = torch.ones(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=input_ids.dtype,
            device=input_ids.device,
        )

        # Global cursors into image_grid_thw / video_grid_thw across the batch.
        image_index, video_index = 0, 0
        mrope_position_deltas = []

        for i, curr_input_ids in enumerate(total_input_ids):
            curr_mask = attention_mask[i] == 1
            masked_ids = curr_input_ids[curr_mask]

            # The token right after each vision-start marker identifies the
            # span type (image vs. video).
            # NOTE(review): indexing with +1 assumes a start marker is never
            # the final token — confirm upstream preprocessing guarantees this.
            vision_start_indices = torch.argwhere(
                masked_ids == vision_start_token_id
            ).squeeze(1)
            vision_tokens = masked_ids[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()

            input_tokens = masked_ids.tolist()
            llm_pos_ids_list = []
            st = 0  # scan cursor into input_tokens
            remain_images, remain_videos = image_nums, video_nums

            # Walk the sequence span by span: text up to the next vision
            # placeholder, then the placeholder's patch grid.
            for _ in range(image_nums + video_nums):
                # len(input_tokens) + 1 acts as "not found" sentinel so the
                # min-comparison below picks the other modality.
                ed_image = (
                    input_tokens.index(image_token_id, st)
                    if image_token_id in input_tokens[st:] and remain_images > 0
                    else len(input_tokens) + 1
                )
                ed_video = (
                    input_tokens.index(video_token_id, st)
                    if video_token_id in input_tokens[st:] and remain_videos > 0
                    else len(input_tokens) + 1
                )

                # Consume whichever vision span occurs first.
                if ed_image < ed_video:
                    t, h, w = image_grid_thw[image_index]
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image
                else:
                    t, h, w = video_grid_thw[video_index]
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video

                # Grid size in LLM tokens: spatial dims shrink by the merge
                # factor; temporal dim is unchanged.
                llm_grid_t, llm_grid_h, llm_grid_w = (
                    t.item(),
                    h.item() // spatial_merge_size,
                    w.item() // spatial_merge_size,
                )
                text_len = ed - st

                # Text segment before this vision span: same position on all
                # three axes, continuing after the previous segment's maximum.
                st_idx = (
                    llm_pos_ids_list[-1].max() + 1
                    if len(llm_pos_ids_list) > 0
                    else 0
                )
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                )

                # Per-axis indices for the patch grid, flattened in
                # (t, h, w) order to match the placeholder token layout.
                t_index = (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                    .flatten()
                )
                h_index = (
                    torch.arange(llm_grid_h)
                    .view(1, -1, 1)
                    .expand(llm_grid_t, -1, llm_grid_w)
                    .flatten()
                )
                w_index = (
                    torch.arange(llm_grid_w)
                    .view(1, 1, -1)
                    .expand(llm_grid_t, llm_grid_h, -1)
                    .flatten()
                )

                # Vision positions start right after the preceding text.
                llm_pos_ids_list.append(
                    torch.stack([t_index, h_index, w_index]) + text_len + st_idx
                )
                # Advance past the vision span's placeholder tokens.
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w

            # Trailing text after the last vision span, if any.
            if st < len(input_tokens):
                st_idx = (
                    llm_pos_ids_list[-1].max() + 1
                    if len(llm_pos_ids_list) > 0
                    else 0
                )
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                )

            # Scatter the computed positions into the non-padded slots only.
            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            position_ids[..., i, curr_mask] = llm_positions.to(position_ids.device)
            # Delta lets generation continue positions past seq_len.
            mrope_position_deltas.append(
                llm_positions.max() + 1 - len(curr_input_ids)
            )

        mrope_deltas = torch.tensor(
            mrope_position_deltas, device=input_ids.device
        ).unsqueeze(1)
        return position_ids, mrope_deltas
    else:
        # --- Text-only branch: 1-D positions replicated across all 3 axes. ---
        if attention_mask is not None:
            # Standard left-/right-pad-aware positions; padded slots are
            # filled with 1 (their value is masked out downstream).
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = (
                position_ids.unsqueeze(0)
                .expand(3, -1, -1)
                .to(attention_mask.device)
            )
            max_pos = position_ids.max(0)[0].max(-1, keepdim=True)[0]
            mrope_deltas = max_pos + 1 - attention_mask.shape[-1]
        else:
            # No mask: positions are simply 0..seq_len-1 for every sample,
            # so the delta is zero.
            position_ids = (
                torch.arange(input_ids.shape[1], device=input_ids.device)
                .view(1, 1, -1)
                .expand(3, input_ids.shape[0], -1)
            )
            mrope_deltas = torch.zeros(
                [input_ids.shape[0], 1],
                device=input_ids.device,
                dtype=input_ids.dtype,
            )

        return position_ids, mrope_deltas
Loading