 from tqdm import tqdm
 from transformers import AutoConfig, AutoTokenizer

-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
 from specforge.args import SGLangBackendArgs, TrackerArgs
 from specforge.core.dflash import OnlineDFlashModel
 from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
 from specforge.tracker import create_tracker
 from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank

+logging.getLogger("sglang.srt.mem_cache.memory_pool").setLevel(logging.WARNING)
+

 def parse_args():
     parser = argparse.ArgumentParser(description="Train DFlash Draft Model")
@@ -80,6 +82,13 @@ def parse_args():
         help="Gamma for exponential loss decay weighting (paper Eq.4). "
         "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. None disables.",
     )
+    model_group.add_argument(
+        "--embed-key", type=str, default="model.language_model.embed_tokens.weight"
+    )
+    model_group.add_argument("--lm-head-key", type=str, default="lm_head.weight")
+    model_group.add_argument("--is-vlm", action="store_true")
+    model_group.add_argument("--min-pixels", type=int, default=50176)
+    model_group.add_argument("--max-pixels", type=int, default=802816)

     dataset_group = parser.add_argument_group("dataset")
     dataset_group.add_argument("--train-data-path", type=str, required=True)
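The new `--embed-key` default targets checkpoints whose language model is nested under `model.language_model` (the layout the VLM path expects); plain text-only checkpoints typically store embeddings at `model.embed_tokens.weight` instead. A hedged sketch of how a caller could probe a sharded checkpoint's safetensors index to choose the flag value (`resolve_embed_key` is a hypothetical helper, not part of this patch):

```python
import json
import os

def resolve_embed_key(model_dir: str) -> str:
    # Sharded HF checkpoints ship a weight_map listing every tensor name.
    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    with open(index_path) as f:
        keys = json.load(f)["weight_map"].keys()
    for candidate in (
        "model.language_model.embed_tokens.weight",  # nested VLM layout (new default)
        "model.embed_tokens.weight",                 # plain decoder-only layout
    ):
        if candidate in keys:
            return candidate
    raise KeyError("no known embed_tokens key in checkpoint index")
```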
@@ -190,7 +199,16 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
     return target_model, draft_model


-def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]:
+def _load_raw_dataset(data_path: str):
+    """Load an Arrow dataset directory (via `load_from_disk`) or a JSON/JSONL file."""
+    if os.path.isdir(data_path):
+        return load_from_disk(data_path)
+    return load_dataset("json", data_files=data_path)["train"]
+
+
+def build_dataloader(
+    args, tokenizer, processor=None
+) -> Tuple[DataLoader, Optional[DataLoader]]:
     """Build train and eval dataloaders."""
     import hashlib

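`_load_raw_dataset` (added above) accepts either an Arrow dataset directory produced by `Dataset.save_to_disk` or a JSON/JSONL file. A minimal round-trip exercising both branches (paths are illustrative):

```python
from datasets import Dataset, load_dataset, load_from_disk

ds = Dataset.from_list([{"text": "hello"}])

ds.save_to_disk("/tmp/demo_arrow")   # directory -> the load_from_disk branch
assert len(load_from_disk("/tmp/demo_arrow")) == 1

ds.to_json("/tmp/demo.jsonl")        # single file -> the load_dataset("json", ...) branch
assert len(load_dataset("json", data_files="/tmp/demo.jsonl")["train"]) == 1
```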
@@ -202,7 +220,7 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
     )
     cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()

-    train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]
+    train_dataset = _load_raw_dataset(args.train_data_path)
     train_eagle3_dataset = build_eagle3_dataset(
         dataset=train_dataset,
         tokenizer=tokenizer,
@@ -212,8 +230,9 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
         cache_dir=os.path.join(args.cache_dir, "processed_dataset"),
         cache_key=cache_key,
         num_proc=args.build_dataset_num_proc,
+        is_vlm=args.is_vlm,
+        processor=processor,
     )
-
     min_loss_tokens = 2 * args.block_size
     original_size = len(train_eagle3_dataset)
     train_eagle3_dataset = train_eagle3_dataset.filter(
@@ -229,24 +248,28 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
         num_workers=args.dataloader_num_workers,
         shuffle=True,
         process_group=get_dp_group(),
+        is_vlm=args.is_vlm,
     )

     eval_dataloader = None
     if args.eval_data_path:
-        eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"]
+        eval_dataset = _load_raw_dataset(args.eval_data_path)
         eval_eagle3_dataset = build_eagle3_dataset(
             dataset=eval_dataset,
             tokenizer=tokenizer,
             chat_template=args.chat_template,
             max_length=args.max_length,
             is_preformatted=args.is_preformatted,
+            is_vlm=args.is_vlm,
+            processor=processor,
         )
         eval_dataloader = prepare_dp_dataloaders(
             eval_eagle3_dataset,
             args.batch_size,
             num_workers=args.dataloader_num_workers,
             shuffle=False,
             process_group=get_dp_group(),
+            is_vlm=args.is_vlm,
         )

     return train_dataloader, eval_dataloader
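The `min_loss_tokens = 2 * args.block_size` filter opened two hunks up drops samples with too little supervision to cover a full DFlash block; its predicate falls outside the diff context, so this is only a plausible sketch of the check, assuming a per-token `loss_mask` column:

```python
def has_enough_supervision(example, min_loss_tokens):
    # Keep the sample only if enough positions contribute to the loss.
    return sum(example["loss_mask"]) >= min_loss_tokens

# e.g. train_eagle3_dataset.filter(lambda ex: has_enough_supervision(ex, min_loss_tokens))
```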
@@ -353,11 +376,18 @@ def main():
             f"Provided ckpt dir {args.ckpt_dir} is not a valid directory."
         )

+    start_epoch = 0
+    global_step = 0
+    ckpt_info = None
+
     if args.resume and os.path.isdir(args.output_dir):
         draft_model_last_checkpoint, ckpt_info = get_last_checkpoint(
             args.output_dir, prefix=r"epoch_\d+_step"
         )
-        print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}")
+        if ckpt_info:
+            start_epoch = ckpt_info[0]
+            global_step = ckpt_info[1]
+        print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}")

     resume_state = None
     if draft_model_last_checkpoint:
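This hunk fixes resume: `start_epoch`, `global_step`, and `ckpt_info` are now initialized before the branch, so a fresh run with no checkpoint found no longer reads an undefined `ckpt_info` (the old unconditional reads are deleted in a later hunk). For reference, the `prefix=r"epoch_\d+_step"` regex matches checkpoint directories named like `epoch_3_step_1200`; a small sketch of the `(epoch, step)` pair such a name encodes, assuming that layout:

```python
import re

def parse_ckpt_name(name: str):
    # "epoch_3_step_1200" -> (3, 1200); None for non-checkpoint entries
    m = re.match(r"epoch_(\d+)_step_(\d+)$", name)
    return (int(m.group(1)), int(m.group(2))) if m else None

assert parse_ckpt_name("epoch_3_step_1200") == (3, 1200)
assert parse_ckpt_name("final") is None
```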
@@ -380,7 +410,23 @@ def main():
             f"step {resume_state['global_step']}"
         )

-    tokenizer = AutoTokenizer.from_pretrained(args.target_model_path)
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.target_model_path, trust_remote_code=args.trust_remote_code
+    )
+
+    processor = None
+    if args.is_vlm:
+        from transformers import AutoProcessor
+
+        processor = AutoProcessor.from_pretrained(
+            args.target_model_path,
+            min_pixels=args.min_pixels,
+            max_pixels=args.max_pixels,
+            trust_remote_code=args.trust_remote_code,
+        )
+        print_on_rank0(
+            f"Loaded VLM processor (min_pixels={args.min_pixels}, max_pixels={args.max_pixels})"
+        )

     if args.mask_token_id is not None:
         mask_token_id = args.mask_token_id
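The `--min-pixels`/`--max-pixels` defaults forwarded to `AutoProcessor` correspond to a visual-token budget in Qwen-VL-style processors, which resize images so their area stays between the two bounds measured in 28x28-pixel patches (this interpretation is an assumption about the target processor, not something the patch states):

```python
# 64 and 1024 visual tokens per image, at 28x28 pixels per token:
assert 50176 == 64 * 28 * 28     # --min-pixels default
assert 802816 == 1024 * 28 * 28  # --max-pixels default
```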
@@ -396,18 +442,17 @@ def main():
     draft_model.config.dflash_config["target_layer_ids"] = draft_model.target_layer_ids
     print_on_rank0(f"dflash_config: {draft_model.config.dflash_config}")

-    train_dataloader, eval_dataloader = build_dataloader(args, tokenizer)
-
+    train_dataloader, eval_dataloader = build_dataloader(args, tokenizer, processor)
     steps_per_epoch = math.ceil(len(train_dataloader) / args.accumulation_steps)
     total_steps = args.num_epochs * steps_per_epoch
     print_on_rank0(f"Total training steps: {total_steps}")

     print_on_rank0("Loading target embeddings and head...")
     target_components = TargetEmbeddingsAndHead.from_pretrained(
         args.target_model_path,
-        embed_key="model.embed_tokens.weight",  # Adjust if Qwen/Llama differs
-        lm_head_key="lm_head.weight",
-        device="cuda",
+        embed_key=args.embed_key,
+        lm_head_key=args.lm_head_key,
+        device=torch.cuda.current_device(),
         trust_remote_code=args.trust_remote_code,
     )

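Replacing the hard-coded `device="cuda"` string with `torch.cuda.current_device()` makes the placement explicit per rank: the call returns the integer index of the GPU previously selected with `torch.cuda.set_device(...)`, typically from `LOCAL_RANK` under torchrun. Illustrative check on a single GPU:

```python
import torch

idx = torch.cuda.current_device()  # index of this process's active GPU
t = torch.empty(1, device=idx)     # PyTorch accepts a bare index as a device
assert t.device == torch.device("cuda", idx)
```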
@@ -441,8 +486,6 @@ def main():
         total_steps=total_steps,
     )

-    start_epoch = ckpt_info[0]
-    global_step = ckpt_info[1]
     if resume_state is not None:
         optimizer.scheduler.load_state_dict(resume_state["scheduler_state_dict"])
         start_epoch = resume_state["epoch"]
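With the duplicated `ckpt_info` reads gone (they now happen once, guarded, earlier in `main`), resume bookkeeping comes entirely from `resume_state`. For orientation, these are the keys this diff reads from it; any other contents are not shown here and the values are illustrative:

```python
resume_state = {
    "scheduler_state_dict": {...},  # restored into optimizer.scheduler
    "epoch": 3,                     # -> start_epoch
    "global_step": 1200,            # -> global_step
}
```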
@@ -475,11 +518,30 @@ def main():
                 continue
             global_step += 1

+            input_ids_cpu = data["input_ids"]
+            attention_mask_cpu = data["attention_mask"]
+            loss_mask_cpu = data["loss_mask"]
+
             input_ids = data["input_ids"].cuda()
-            attention_mask = data["attention_mask"].cuda()
             loss_mask = data["loss_mask"].cuda()
+            pixel_values = None
+            image_grid_thw_cpu = None
+            if (
+                args.is_vlm
+                and "pixel_values" in data
+                and data["pixel_values"] is not None
+            ):
+                pixel_values = data["pixel_values"].cuda()
+                image_grid_thw_cpu = [
+                    thw.squeeze() if thw is not None else None
+                    for thw in data["image_grid_thw"]
+                ]
             target_output = target_model.generate_dflash_data(
-                input_ids, attention_mask, loss_mask
+                input_ids_cpu,
+                attention_mask_cpu,
+                loss_mask_cpu,
+                pixel_values=pixel_values,
+                image_grid_thw=image_grid_thw_cpu,
             )
             hidden_states = target_output.hidden_states.cuda()  # Ensure on GPU

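The loop now keeps CPU copies of the batch for `generate_dflash_data` while only the tensors the draft-model loss needs (`input_ids`, `loss_mask`) move to the GPU. The per-image `squeeze()` strips the batch dimension the collate step adds to each `(t, h, w)` patch grid; a small illustration, assuming a Qwen-VL-style `image_grid_thw` layout:

```python
import torch

thw = torch.tensor([[1, 34, 46]])             # collated as (1, 3): 1 frame, 34x46 patches
assert thw.squeeze().tolist() == [1, 34, 46]  # (3,) tensor handed to the target model
```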