Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions scripts/prepare_hidden_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class DataPoint:
loss_mask: torch.Tensor
hidden_state: torch.Tensor
aux_hidden_state: Optional[torch.Tensor] = None
pixel_values: Optional[torch.Tensor] = None
image_grid_thw: Optional[torch.Tensor] = None


def parse_args():
Expand Down Expand Up @@ -187,14 +189,19 @@ def build_target_model(
)
else:
target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs()
if hasattr(model_config, "dtype") and model_config.dtype is not None:
torch_dtype = model_config.dtype
elif hasattr(model_config, "text_config") and hasattr(
model_config.text_config, "dtype"
):
torch_dtype = model_config.text_config.dtype
else:
torch_dtype = getattr(model_config, "torch_dtype", "bfloat16")
Comment on lines +192 to +199
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for determining torch_dtype can be simplified and made more robust by using getattr with defaults and explicitly checking for None values in nested configurations.

Suggested change
if hasattr(model_config, "dtype") and model_config.dtype is not None:
torch_dtype = model_config.dtype
elif hasattr(model_config, "text_config") and hasattr(model_config.text_config, "dtype"):
torch_dtype = model_config.text_config.dtype
else:
torch_dtype = getattr(model_config, "torch_dtype", "bfloat16")
torch_dtype = getattr(model_config, "dtype", None)
if torch_dtype is None and hasattr(model_config, "text_config"):
torch_dtype = getattr(model_config.text_config, "dtype", None)
if torch_dtype is None:
torch_dtype = getattr(model_config, "torch_dtype", "bfloat16")


target_model = get_eagle3_target_model(
pretrained_model_name_or_path=args.target_model_path,
backend="sglang", # we set this as the default backend to minimize precision mismatch in training and serving
torch_dtype=(
model_config.dtype
if hasattr(model_config, "dtype")
else model_config.torch_dtype
),
torch_dtype=torch_dtype,
device="cuda",
cache_dir=args.model_download_dir,
trust_remote_code=args.trust_remote_code,
Expand Down Expand Up @@ -488,6 +495,8 @@ def generate(
"attention_mask": batch["attention_mask"][valid_indices_in_batch],
"loss_mask": batch["loss_mask"][valid_indices_in_batch],
}
pixel_values = batch["pixel_values"][valid_indices_in_batch]
image_grid_thw = batch["image_grid_thw"][valid_indices_in_batch]
del batch
if num_valid == 0:
# Data has already been generated, no sample processing, update progress bar.
Expand Down Expand Up @@ -542,6 +551,8 @@ def generate(
data_point = DataPoint(
input_ids=filtered_batch["input_ids"][i].clone(),
loss_mask=filtered_batch["loss_mask"][i].clone(),
pixel_values=pixel_values[i].clone(),
image_grid_thw=image_grid_thw[i].clone(),
Comment on lines +554 to +555
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Accessing pixel_values[i] and image_grid_thw[i] will cause a TypeError or AttributeError if the variables are None (which occurs for non-VLM models).

Suggested change
pixel_values=pixel_values[i].clone(),
image_grid_thw=image_grid_thw[i].clone(),
pixel_values=pixel_values[i].clone() if pixel_values is not None else None,
image_grid_thw=image_grid_thw[i].clone() if image_grid_thw is not None else None,

hidden_state=last_hidden_states,
aux_hidden_state=aux_hidden_states,
)
Expand Down
Loading