
Commit 2ef8024

Mandy3311 and hukongyi authored and committed
feat: DFlash VLM training support with SGLang backend
Co-authored-by: hukongyi <hukongyi@cmbchina.com>
1 parent 961ca7c · commit 2ef8024

9 files changed

Lines changed: 384 additions & 97 deletions


scripts/train_dflash.py

Lines changed: 78 additions & 16 deletions
```diff
@@ -20,7 +20,7 @@
 from tqdm import tqdm
 from transformers import AutoConfig, AutoTokenizer
 
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
 from specforge.args import SGLangBackendArgs, TrackerArgs
 from specforge.core.dflash import OnlineDFlashModel
 from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
@@ -35,6 +35,8 @@
 from specforge.tracker import create_tracker
 from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank
 
+logging.getLogger("sglang.srt.mem_cache.memory_pool").setLevel(logging.WARNING)
+
 
 def parse_args():
     parser = argparse.ArgumentParser(description="Train DFlash Draft Model")
@@ -80,6 +82,13 @@ def parse_args():
         help="Gamma for exponential loss decay weighting (paper Eq.4). "
         "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. None disables.",
     )
+    model_group.add_argument(
+        "--embed-key", type=str, default="model.language_model.embed_tokens.weight"
+    )
+    model_group.add_argument("--lm-head-key", type=str, default="lm_head.weight")
+    model_group.add_argument("--is-vlm", action="store_true")
+    model_group.add_argument("--min-pixels", type=int, default=50176)
+    model_group.add_argument("--max-pixels", type=int, default=802816)
 
     dataset_group = parser.add_argument_group("dataset")
     dataset_group.add_argument("--train-data-path", type=str, required=True)
```
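The `--min-pixels`/`--max-pixels` defaults are forwarded verbatim to the processor below. Assuming a Qwen2-VL-style processor that resizes images onto a 28×28-pixel patch grid (an assumption consistent with these defaults, not stated in the diff), they bound the per-image visual-token budget:

```python
# Sketch: what the --min-pixels / --max-pixels defaults imply, assuming a
# Qwen2-VL-style processor whose images are resized to multiples of 28x28 patches.
PATCH = 28  # assumed spatial patch size

min_pixels, max_pixels = 50176, 802816
print(min_pixels // (PATCH * PATCH))  # 64   -> at least ~64 patches per image
print(max_pixels // (PATCH * PATCH))  # 1024 -> at most ~1024 patches per image
```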
```diff
@@ -190,7 +199,16 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
     return target_model, draft_model
 
 
-def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]:
+def _load_raw_dataset(data_path: str):
+    """Load a dataset saved with save_to_disk (directory) or a JSONL file."""
+    if os.path.isdir(data_path):
+        return load_from_disk(data_path)
+    return load_dataset("json", data_files=data_path)["train"]
+
+
+def build_dataloader(
+    args, tokenizer, processor=None
+) -> Tuple[DataLoader, Optional[DataLoader]]:
     """Build train and eval dataloaders."""
     import hashlib
 
@@ -202,7 +220,7 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
     )
     cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
 
-    train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]
+    train_dataset = _load_raw_dataset(args.train_data_path)
     train_eagle3_dataset = build_eagle3_dataset(
         dataset=train_dataset,
         tokenizer=tokenizer,
@@ -212,8 +230,9 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
         cache_dir=os.path.join(args.cache_dir, "processed_dataset"),
         cache_key=cache_key,
         num_proc=args.build_dataset_num_proc,
+        is_vlm=args.is_vlm,
+        processor=processor,
     )
-
     min_loss_tokens = 2 * args.block_size
     original_size = len(train_eagle3_dataset)
     train_eagle3_dataset = train_eagle3_dataset.filter(
@@ -229,24 +248,28 @@
         num_workers=args.dataloader_num_workers,
         shuffle=True,
         process_group=get_dp_group(),
+        is_vlm=args.is_vlm,
     )
 
     eval_dataloader = None
     if args.eval_data_path:
-        eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"]
+        eval_dataset = _load_raw_dataset(args.eval_data_path)
         eval_eagle3_dataset = build_eagle3_dataset(
             dataset=eval_dataset,
             tokenizer=tokenizer,
            chat_template=args.chat_template,
            max_length=args.max_length,
            is_preformatted=args.is_preformatted,
+            is_vlm=args.is_vlm,
+            processor=processor,
        )
        eval_dataloader = prepare_dp_dataloaders(
            eval_eagle3_dataset,
            args.batch_size,
            num_workers=args.dataloader_num_workers,
            shuffle=False,
            process_group=get_dp_group(),
+            is_vlm=args.is_vlm,
        )
 
     return train_dataloader, eval_dataloader
```
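`_load_raw_dataset` lets `--train-data-path` point either at a directory produced by `datasets.Dataset.save_to_disk` or at a JSONL file. A quick round-trip sketch (the path is hypothetical, for illustration only):

```python
from datasets import Dataset, load_from_disk

# Hypothetical demo path; any save_to_disk directory works the same way.
rows = [{"conversations": [{"role": "user", "content": "hi"}], "image": None}]
Dataset.from_list(rows).save_to_disk("/tmp/dflash_demo")

ds = load_from_disk("/tmp/dflash_demo")  # directory input -> load_from_disk
print(len(ds))  # 1
# A plain file path such as train.jsonl would instead be read via
# load_dataset("json", data_files=...)["train"].
```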
```diff
@@ -353,11 +376,18 @@ def main():
             f"Provided ckpt dir {args.ckpt_dir} is not a valid directory."
         )
 
+    start_epoch = 0
+    global_step = 0
+    ckpt_info = None
+
     if args.resume and os.path.isdir(args.output_dir):
         draft_model_last_checkpoint, ckpt_info = get_last_checkpoint(
             args.output_dir, prefix=r"epoch_\d+_step"
         )
-        print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}")
+        if ckpt_info:
+            start_epoch = ckpt_info[0]
+            global_step = ckpt_info[1]
+        print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}")
 
     resume_state = None
     if draft_model_last_checkpoint:
@@ -380,7 +410,23 @@
                 f"step {resume_state['global_step']}"
             )
 
-    tokenizer = AutoTokenizer.from_pretrained(args.target_model_path)
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.target_model_path, trust_remote_code=args.trust_remote_code
+    )
+
+    processor = None
+    if args.is_vlm:
+        from transformers import AutoProcessor
+
+        processor = AutoProcessor.from_pretrained(
+            args.target_model_path,
+            min_pixels=args.min_pixels,
+            max_pixels=args.max_pixels,
+            trust_remote_code=args.trust_remote_code,
+        )
+        print_on_rank0(
+            f"Loaded VLM processor (min_pixels={args.min_pixels}, max_pixels={args.max_pixels})"
+        )
 
     if args.mask_token_id is not None:
         mask_token_id = args.mask_token_id
```
```diff
@@ -396,18 +442,17 @@
     draft_model.config.dflash_config["target_layer_ids"] = draft_model.target_layer_ids
     print_on_rank0(f"dflash_config: {draft_model.config.dflash_config}")
 
-    train_dataloader, eval_dataloader = build_dataloader(args, tokenizer)
-
+    train_dataloader, eval_dataloader = build_dataloader(args, tokenizer, processor)
     steps_per_epoch = math.ceil(len(train_dataloader) / args.accumulation_steps)
     total_steps = args.num_epochs * steps_per_epoch
     print_on_rank0(f"Total training steps: {total_steps}")
 
     print_on_rank0("Loading target embeddings and head...")
     target_components = TargetEmbeddingsAndHead.from_pretrained(
         args.target_model_path,
-        embed_key="model.embed_tokens.weight",  # Adjust if Qwen/Llama differs
-        lm_head_key="lm_head.weight",
-        device="cuda",
+        embed_key=args.embed_key,
+        lm_head_key=args.lm_head_key,
+        device=torch.cuda.current_device(),
         trust_remote_code=args.trust_remote_code,
     )
 
@@ -441,8 +486,6 @@
         total_steps=total_steps,
     )
 
-    start_epoch = ckpt_info[0]
-    global_step = ckpt_info[1]
     if resume_state is not None:
         optimizer.scheduler.load_state_dict(resume_state["scheduler_state_dict"])
         start_epoch = resume_state["epoch"]
```
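The hardcoded `model.embed_tokens.weight` becomes `--embed-key`, whose new default matches VLM checkpoints that nest the language model under `model.language_model.`. To find the right key for an arbitrary checkpoint, the safetensors index can be scanned; a minimal sketch, assuming a sharded checkpoint with a `model.safetensors.index.json` (the helper name is ours, not from the repo):

```python
import json
import os

def find_weight_key(checkpoint_dir: str, suffix: str = "embed_tokens.weight") -> str:
    """Return the first tensor name in the safetensors index ending in `suffix`."""
    index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json")
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
    for name in weight_map:
        if name.endswith(suffix):
            return name
    raise KeyError(f"no tensor name ends with {suffix!r}")

# Typically yields "model.language_model.embed_tokens.weight" for a nested VLM
# layout, or "model.embed_tokens.weight" for a text-only Llama/Qwen layout.
```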
```diff
@@ -475,11 +518,30 @@
                 continue
             global_step += 1
 
+            input_ids_cpu = data["input_ids"]
+            attention_mask_cpu = data["attention_mask"]
+            loss_mask_cpu = data["loss_mask"]
+
             input_ids = data["input_ids"].cuda()
-            attention_mask = data["attention_mask"].cuda()
             loss_mask = data["loss_mask"].cuda()
+            pixel_values = None
+            image_grid_thw_cpu = None
+            if (
+                args.is_vlm
+                and "pixel_values" in data
+                and data["pixel_values"] is not None
+            ):
+                pixel_values = data["pixel_values"].cuda()
+                image_grid_thw_cpu = [
+                    thw.squeeze() if thw is not None else None
+                    for thw in data["image_grid_thw"]
+                ]
             target_output = target_model.generate_dflash_data(
-                input_ids, attention_mask, loss_mask
+                input_ids_cpu,
+                attention_mask_cpu,
+                loss_mask_cpu,
+                pixel_values=pixel_values,
+                image_grid_thw=image_grid_thw_cpu,
             )
             hidden_states = target_output.hidden_states.cuda()  # Ensure on GPU
```
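The call above implies the following target-model interface. This is a hypothetical stub for orientation only; the real `DFlashTargetModel.generate_dflash_data` is backed by SGLang elsewhere in the repo, and output fields beyond `hidden_states` are not shown in this diff:

```python
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class DFlashTargetOutput:
    hidden_states: torch.Tensor  # moved to GPU by the training loop afterwards


class DFlashTargetModelStub:
    """Hypothetical stub of the interface the training loop relies on."""

    def generate_dflash_data(
        self,
        input_ids: torch.Tensor,       # the loop passes CPU tensors for these three
        attention_mask: torch.Tensor,
        loss_mask: torch.Tensor,
        pixel_values: Optional[torch.Tensor] = None,  # CUDA tensor, or None for text-only
        image_grid_thw: Optional[List[Optional[torch.Tensor]]] = None,  # per-sample (t, h, w)
    ) -> DFlashTargetOutput:
        raise NotImplementedError
```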

specforge/data/preprocessing.py

Lines changed: 59 additions & 13 deletions
```diff
@@ -215,33 +215,79 @@ def preprocess_vlm_conversations(
     # Note: currently, we assume that each example has only one image
     for i, image in enumerate(examples["image"]):
         source = examples["conversations"][i]
-        messages = [{"role": "system", "content": system_prompt}]
+        messages = []
+        # messages = [{"role": "system", "content": system_prompt}]
         if not source:
             # if the source is None, skip it
             continue
 
+        if not image:
+            text_messages = []
+            convroles = ["user", "assistant"]
+            for j, sentence in enumerate(source):
+                role = sentence["role"]
+                assert role == convroles[j % 2], f"unexpected role {role}"
+                text_messages.append({"role": role, "content": sentence["content"]})
+            conversation = processor.apply_chat_template(
+                text_messages,
+                tokenize=False,
+                add_generation_prompt=False,
+            )
+            encoding = processor(
+                text=[conversation],
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+                return_offsets_mapping=True,
+                add_special_tokens=False,
+            )
+
+            input_ids = encoding.input_ids[0]
+            offsets = encoding.offset_mapping[0]
+
+            # decode the conversation for loss mask generation
+            decoded_conversation = processor.tokenizer.decode(
+                encoding.input_ids[0], skip_special_tokens=False
+            )
+
+            # Apply loss mask
+            loss_mask = _apply_loss_mask_from_chat_template(
+                decoded_conversation, offsets, chat_template
+            )
+            results["input_ids"].append(input_ids[None, :])
+            results["loss_mask"].append(loss_mask[None, :])
+            results["attention_mask"].append(torch.ones_like(loss_mask)[None, :])
+            results["pixel_values"].append(torch.empty(0, 0).float())
+            results["image_grid_thw"].append([])
+            continue
+
         if source[0]["role"] != "user":
             # if the first message is not from user, skip it
             source = source[1:]
 
         convroles = ["user", "assistant"]
+        has_added_image = False
         for j, sentence in enumerate(source):
             role = sentence["role"]
             assert role == convroles[j % 2], f"unexpected role {role}"
             if role == "user":
                 # if the message is from user and has image, process the image
-                messages.append(
-                    {
-                        "role": role,
-                        "content": [
-                            {
-                                "type": "image",
-                                "image": image,
-                            },
-                            {"type": "text", "text": sentence["content"]},
-                        ],
-                    }
-                )
+                if not has_added_image:
+                    messages.append(
+                        {
+                            "role": role,
+                            "content": [
+                                {
+                                    "type": "image",
+                                    "image": image,
+                                },
+                                {"type": "text", "text": sentence["content"]},
+                            ],
+                        }
+                    )
+                    has_added_image = True
+                else:
+                    messages.append({"role": role, "content": sentence["content"]})
             else:
                 messages.append({"role": role, "content": sentence["content"]})
```
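After this change, only the first user turn carries the image block; every later user turn is appended as plain text. An illustrative shape of the resulting `messages` for a two-turn image sample (values made up):

```python
# Illustrative only: shape of `messages` built for an image-bearing sample.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "<PIL.Image or path>"},  # first user turn only
            {"type": "text", "text": "What is in the picture?"},
        ],
    },
    {"role": "assistant", "content": "A cat on a sofa."},
    # A second user turn is appended as plain text, not another image block:
    {"role": "user", "content": "What color is it?"},
]
```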

specforge/data/utils.py

Lines changed: 25 additions & 19 deletions
```diff
@@ -218,12 +218,30 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
         batch_loss_mask = torch.cat(
             [self.paddingtensor2D(item["loss_mask"], max_length) for item in features]
         )
-        batch_pixel_values = torch.cat(
-            [item["pixel_values"] for item in features], dim=0
-        )
-        batch_image_grid_thw = torch.cat(
-            [item["image_grid_thw"] for item in features], dim=0
-        )
+        # Collect pixel_values and image_grid_thw per sample.
+        # Image samples have non-empty pixel_values; text-only samples have empty tensors.
+        all_pixel_values = []
+        all_image_grid_thw = []
+        for item in features:
+            pv = item.get("pixel_values")
+            thw = item.get("image_grid_thw")
+            if pv is not None and isinstance(pv, torch.Tensor) and pv.numel() > 0:
+                all_pixel_values.append(pv)
+                all_image_grid_thw.append(thw)
+            else:
+                all_image_grid_thw.append(None)
+
+        if all_pixel_values:
+            batch_pixel_values = torch.cat(all_pixel_values, dim=0)
+        else:
+            batch_pixel_values = None
+
+        # If all samples are text-only, set image_grid_thw to None
+        if all(thw is None for thw in all_image_grid_thw):
+            batch_image_grid_thw = None
+        else:
+            batch_image_grid_thw = all_image_grid_thw
+
         batch = {
             "input_ids": batch_input_ids,
             "attention_mask": batch_attention_mask,
```
```diff
@@ -304,17 +322,10 @@ def prepare_dp_dataloaders(
 
 
 def parse_harmony_message_content(content):
-    """
-    Parse Harmony-format markup in a content string.
-    If Harmony markup is matched, return a list of {channel, content} entries;
-    otherwise, return the original content tagged with the default channel.
-    """
-    # matches <|channel|>xxx<|message|>yyy<|end|>
     pattern = r"<\|channel\|>(.*?)<\|message\|>(.*?)<\|end|>"
     matches = re.findall(pattern, content, re.DOTALL)
 
     if not matches:
-        # no Harmony tags matched; treat the content as plain text
         return [{"channel": "text", "content": content}]
 
     results = []
@@ -324,22 +335,17 @@ def parse_harmony_message_content(content):
 
 
 def process_harmony_conversations(conversation):
-    """
-    Process an incoming list[list[dict]] structure.
-    """
     new_conversation = []
     for msg in conversation:
         role = msg.get("role")
         original_content = msg.get("content", "")
 
-        # parse the Harmony structure inside content
         segments = parse_harmony_message_content(original_content)
 
-        # emit a new message dict for each parsed channel
        for seg in segments:
             new_msg = {
                 "role": role,
-                "channel": seg["channel"],  # new field identifying the channel
+                "channel": seg["channel"],
                 "content": seg["content"],
             }
             new_conversation.append(new_msg)
```
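For reference, a usage sketch of the Harmony markup these helpers split. The sketch escapes the terminator fully as `<\|end\|>`; the pattern kept in the code above leaves the final `|` unescaped, so its trailing alternation also matches bare `>` characters and can emit empty segments:

```python
import re

# Fully escaped Harmony pattern (assumption: the intended terminator is <|end|>).
pattern = r"<\|channel\|>(.*?)<\|message\|>(.*?)<\|end\|>"

content = (
    "<|channel|>analysis<|message|>Let me think this through.<|end|>"
    "<|channel|>final<|message|>42<|end|>"
)
print(re.findall(pattern, content, re.DOTALL))
# [('analysis', 'Let me think this through.'), ('final', '42')]
```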
