diff --git a/.gitignore b/.gitignore index 5486de379e..39991378c0 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,17 @@ Sessionx.vim # macOS dir files .DS_Store + +# Ignore everything inside scripts/ +scripts/* + +# Keep Python files in scripts/ +!scripts/*.py + +# Keep contents of these subdirectories +!scripts/example/ +!scripts/example/** +!scripts/generate/ +!scripts/generate/** +!scripts/estimate/ +!scripts/estimate/** diff --git a/README.md b/README.md index fa9b3ba35c..074189aa5e 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,29 @@ srun torchrun --nnodes 2 If your gpu count per node is not 8, adjust `--nproc_per_node` in the torchrun command and `#SBATCH --gpus-per-task` in the SBATCH command section. +## (NOUS) training with sample packing and multimodality + +### Training Qwen3-8B with sample packing +To preprocess and pack a textonly chat dataset, run `scripts/preprocess_data.py`: +``` +python3 scripts/preprocess_data.py --dataset NousResearch/Hermes-3-Dataset --tokenizer Qwen/Qwen3-8B --chat --pack-to-sequence-length 8000 --split "train[:1000]" --save-to-disk ./dataset +``` + +Qwen3-8B can be trained using this dataset: +``` +CONFIG_FILE="./torchtitan/models/qwen3/train_configs/qwen3_8b_finetuning.toml" ./run_train.sh +``` + +## Training Mistral Small 3.1 with multimodal sample packing +To preprocess and pack a multimodal chat dataset, run `scripts/preprocess_multimodal_data.py`: +``` +python3 scripts/preprocess_multimodal_data.py --dataset /home/shared/datasets/cambrian_sample.json --preprocessor mistralai/Mistral-Small-3.1-24B-Instruct-2503 --chat --pack-to-sequence-length 8000 --split "train" --save-to-disk ./multimodal_dataset --limit 1000 +``` + +Mistral Small 3.1 can be trained using this dataset: +``` +CONFIG_FILE="./torchtitan/models/mistral3/train_configs/mistral24b_finetuning.toml" ./run_train.sh +``` ## Citation diff --git a/scripts/convert.py b/scripts/convert.py new file mode 100644 index 0000000000..0534192fbe --- /dev/null +++ 
b/scripts/convert.py @@ -0,0 +1,59 @@ +import re +from datasets import load_dataset +from datasets.utils.info_utils import VerificationMode + +# Load your dataset (replace 'your_dataset_name' with the actual Hugging Face dataset name or path) +ds = load_dataset("nvidia/Llama-Nemotron-VLM-Dataset-v1", verification_mode=VerificationMode.NO_CHECKS, split="vqa_8") + +def process_conversation(row): + image_path = row["image"] + original_conv = row["conversations"] + + messages = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."} + ] + } + ] + + # Assuming the conversation alternates starting with human, and image is only in the first human message + for i, turn in enumerate(original_conv): + role = "user" if turn["from"] == "human" else "assistant" + value = turn["value"] + + if role == "user" and i == 0 and "" in value: + # Split the value around + parts = re.split(r'(\n?\n?)', value) + content = [] + for part in parts: + if re.match(r'\n?\n?', part): + content.append({"type": "image", "path": './ChartQA Dataset/' + image_path}) + elif part.strip(): + content.append({"type": "text", "text": part.strip()}) + else: + content = [{"type": "text", "text": value.strip()}] + + messages.append({ + "role": role, + "content": content + }) + + # The format has an outer list with a dict containing "messages" + new_conv = [{"messages": messages}] + + return {"conversations": messages} + +# Apply the transformation and keep only the new "conversations" column +new_ds = ds.map( + process_conversation, + remove_columns=ds.column_names +) + +# Optionally, push to Hugging Face or save +# new_ds.push_to_hub("new_dataset_name") +# or +new_ds.save_to_disk("ChartQA_Subset") + +#print(new_ds[0]) diff --git a/scripts/convert_deepseek_hf_to_dcp.py b/scripts/convert_deepseek_hf_to_dcp.py index ae01eed0dd..3bc04b5019 100644 --- a/scripts/convert_deepseek_hf_to_dcp.py +++ b/scripts/convert_deepseek_hf_to_dcp.py @@ -1,195 +1,196 @@ +# Copyright (c) 
Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import argparse -import glob -import os from pathlib import Path import torch import torch.distributed.checkpoint as DCP -from huggingface_hub import snapshot_download -from safetensors import safe_open -from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.logging import logger +from torchtitan.models.mistral3.model.model import precompute_freqs_cis -from transformers import AutoConfig, AutoTokenizer # noqa F401 +from transformers import AutoTokenizer, Mistral3Config -@torch.inference_mode() -def convert_deepseekv3_weights(deepseek_model, output_dir): - # Download the model files locally - if os.path.exists(deepseek_model): - local_path = deepseek_model - else: - local_path = snapshot_download( - repo_id=deepseek_model, - allow_patterns=["*.safetensors", "config.json", "tokenizer*"], - ) - tok = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True) - config = AutoConfig.from_pretrained(local_path, trust_remote_code=True) - n_layers = config.num_hidden_layers - dim = config.hidden_size - - logger.info( - f"Loading original DeepSeek V3 weights from {deepseek_model} using safetensors" +# permute for sliced rotary +def permute(w, n_heads, dim1, dim2): + return ( + w.view(n_heads, dim1 // n_heads // 2, 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) ) - # Find all safetensors files - safetensors_files = glob.glob(f"{local_path}/*.safetensors") - if not safetensors_files: - raise ValueError( - "No safetensors files found in the downloaded model directory. Ensure the model uses safetensors format." 
- ) - # Load state dict directly from safetensors files - hf_state_dict = {} - for file in safetensors_files: - with safe_open(file, framework="pt", device="cpu") as f: - for key in f.keys(): - tensor = f.get_tensor(key) - logger.info(f"Read {key}, shape={tensor.shape}, dtype{tensor.dtype}") - hf_state_dict[key] = tensor +# And reversed +def reverse_permute(w, n_heads, dim1, dim2): + return ( + w.view(n_heads, 2, dim1 // n_heads // 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + ) + - logger.info("Converting...") +@torch.inference_mode() +def convert_mistral3_weights(mistral_model, output_dir, max_seq_len: int): + # Loading the model directly might be too large, so we'll use safetensors to load the weights + from safetensors import safe_open + import os + import json + + # Load the config + config_path = os.path.join(mistral_model, "config.json") + with open(config_path, "r") as f: + config_dict = json.load(f) + + config = Mistral3Config.from_dict(config_dict) + tok = AutoTokenizer.from_pretrained(mistral_model) + + # Find all safetensors files + index_path = os.path.join(mistral_model, "model.safetensors.index.json") + if os.path.exists(index_path): + with open(index_path, "r") as f: + index = json.load(f) + weight_map = index["weight_map"] + all_files = set(weight_map.values()) + else: + # If no index, look for a single safetensors file + safetensors_files = [f for f in os.listdir(mistral_model) if f.endswith(".safetensors")] + if len(safetensors_files) == 1: + all_files = [safetensors_files[0]] + else: + raise ValueError("Multiple safetensors files found without an index file") + + # Extract language model parameters from the text config + text_config = config.text_config + n_layers = text_config.num_hidden_layers + n_heads = text_config.num_attention_heads + dim = text_config.hidden_size + dims_per_head = dim // n_heads + + logger.info(f"Loading original Mistral3 weights from {mistral_model}") + state_dict = {} + n_heads_per_shard = n_heads + 
num_key_value_heads = text_config.num_key_value_heads + n_kv_heads_per_shard = num_key_value_heads + key_value_dim = dims_per_head * num_key_value_heads + + # Load and process weights + hf_state_dict = {} + for filename in all_files: + filepath = os.path.join(mistral_model, filename) + with safe_open(filepath, framework="pt") as f: + for key in f.keys(): + hf_state_dict[key] = f.get_tensor(key) + + # Process language model layers for layer in range(n_layers): - state_dict[f"layers.{layer}.attention_norm.weight"] = hf_state_dict[ - f"model.layers.{layer}.input_layernorm.weight" - ] - state_dict[f"layers.{layer}.ffn_norm.weight"] = hf_state_dict[ - f"model.layers.{layer}.post_attention_layernorm.weight" - ] + # Map from HF to torchtitan structure + # Based on the keys we found in the checkpoint - # Attention weights - if config.q_lora_rank is None: - state_dict[f"layers.{layer}.attention.wq.weight"] = hf_state_dict[ - f"model.layers.{layer}.self_attn.q_proj.weight" - ] - else: - state_dict[f"layers.{layer}.attention.wq_a.weight"] = hf_state_dict[ - f"model.layers.{layer}.self_attn.q_a_proj.weight" - ] - state_dict[f"layers.{layer}.attention.q_norm.weight"] = hf_state_dict[ - f"model.layers.{layer}.self_attn.q_a_layernorm.weight" - ] - state_dict[f"layers.{layer}.attention.wq_b.weight"] = hf_state_dict[ - f"model.layers.{layer}.self_attn.q_b_proj.weight" - ] - - state_dict[f"layers.{layer}.attention.wkv_a.weight"] = hf_state_dict[ - f"model.layers.{layer}.self_attn.kv_a_proj_with_mqa.weight" - ] - state_dict[f"layers.{layer}.attention.kv_norm.weight"] = hf_state_dict[ - f"model.layers.{layer}.self_attn.kv_a_layernorm.weight" - ] - state_dict[f"layers.{layer}.attention.wkv_b.weight"] = hf_state_dict[ - f"model.layers.{layer}.self_attn.kv_b_proj.weight" + # Norm layers + state_dict[f"language_model.layers.{layer}.ln_attn.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.input_layernorm.weight" ] - state_dict[f"layers.{layer}.attention.wo.weight"] = 
hf_state_dict[ - f"model.layers.{layer}.self_attn.o_proj.weight" + state_dict[f"language_model.layers.{layer}.ln_mlp.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.post_attention_layernorm.weight" ] - # MLP or MoE - if layer < config.first_k_dense_replace: - # Regular MLP (FeedForward) - state_dict[f"layers.{layer}.feed_forward.w1.weight"] = hf_state_dict[ - f"model.layers.{layer}.mlp.gate_proj.weight" - ] - state_dict[f"layers.{layer}.feed_forward.w3.weight"] = hf_state_dict[ - f"model.layers.{layer}.mlp.up_proj.weight" - ] - state_dict[f"layers.{layer}.feed_forward.w2.weight"] = hf_state_dict[ - f"model.layers.{layer}.mlp.down_proj.weight" - ] - else: - # MoE - num_experts = config.n_routed_experts - hidden_dim = config.moe_intermediate_size - - # Router - state_dict[f"layers.{layer}.moe.router.gate.weight"] = hf_state_dict[ - f"model.layers.{layer}.mlp.gate.weight" - ] - - # Experts - w1_list = [] - w2_list = [] - w3_list = [] - for i in range(num_experts): - w1_list.append( - hf_state_dict[ - f"model.layers.{layer}.mlp.experts.{i}.gate_proj.weight" - ].t() - ) - w3_list.append( - hf_state_dict[ - f"model.layers.{layer}.mlp.experts.{i}.up_proj.weight" - ].t() - ) - w2_list.append( - hf_state_dict[ - f"model.layers.{layer}.mlp.experts.{i}.down_proj.weight" - ].t() + # Attention layers + for wn, hn, nh in [ + ("wq", "q_proj", n_heads_per_shard), + ("wk", "k_proj", n_kv_heads_per_shard), + ("wv", "v_proj", n_kv_heads_per_shard), + ]: + if wn != "wv": + # Need to reverse the permutation for sliced rotary + + state_dict[f"language_model.layers.{layer}.attn.{wn}.weight"] = reverse_permute( + hf_state_dict[f"language_model.model.layers.{layer}.self_attn.{hn}.weight"], + n_heads if wn == "wq" else num_key_value_heads, + dim1=4096 if wn == "wq" else int(key_value_dim*0.8), + dim2=dim, ) + else: - state_dict[f"layers.{layer}.moe.experts.w1"] = torch.stack(w1_list) - state_dict[f"layers.{layer}.moe.experts.w3"] = torch.stack(w3_list) - 
state_dict[f"layers.{layer}.moe.experts.w2"] = torch.stack(w2_list) - - bias_key = f"model.layers.{layer}.mlp.gate.e_score_correction_bias" - if bias_key in hf_state_dict: - state_dict[f"layers.{layer}.moe.router.expert_bias"] = hf_state_dict[ - bias_key + state_dict[f"language_model.layers.{layer}.attn.{wn}.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.self_attn.{hn}.weight" ] - # Ephemeral for training, but torchtitan expects it - state_dict[f"layers.{layer}.moe.tokens_per_expert"] = torch.zeros( - num_experts, dtype=torch.float32 - ) - # Shared expert (if exists) - if config.n_shared_experts is not None: - state_dict[f"layers.{layer}.moe.shared_expert.w1"] = ( - hf_state_dict[ - f"model.layers.{layer}.mlp.shared_experts.gate_proj.weight" - ] - .t() - .unsqueeze(0) - ) - state_dict[f"layers.{layer}.moe.shared_expert.w3"] = ( - hf_state_dict[ - f"model.layers.{layer}.mlp.shared_experts.up_proj.weight" - ] - .t() - .unsqueeze(0) - ) - state_dict[f"layers.{layer}.moe.shared_expert.w2"] = ( - hf_state_dict[ - f"model.layers.{layer}.mlp.shared_experts.down_proj.weight" - ] - .t() - .unsqueeze(0) - ) + state_dict[f"language_model.layers.{layer}.attn.wo.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.self_attn.o_proj.weight" + ] - state_dict["norm.weight"] = hf_state_dict["model.norm.weight"] - state_dict["tok_embeddings.weight"] = hf_state_dict["model.embed_tokens.weight"] - state_dict["output.weight"] = hf_state_dict["lm_head.weight"] + # Feed-forward layers + state_dict[f"language_model.layers.{layer}.mlp.w1.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.mlp.gate_proj.weight" + ] + state_dict[f"language_model.layers.{layer}.mlp.w2.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.mlp.down_proj.weight" + ] + state_dict[f"language_model.layers.{layer}.mlp.w3.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.mlp.up_proj.weight" + ] + + # Language model norm and embeddings + 
state_dict["language_model.norm.weight"] = hf_state_dict["language_model.model.norm.weight"] + + # Handling embeddings + state_dict["language_model.tok_embeddings.weight"] = hf_state_dict["language_model.model.embed_tokens.weight"] + # If fusion embedding exists in the HF model + + state_dict["language_model.output.weight"] = hf_state_dict["language_model.lm_head.weight"] + + # Vision tower components + if "vision_tower.ln_pre.weight" in hf_state_dict: + # Copy over vision tower weights, restructuring to put them under model.vision_encoder.pixtral_vision + vision_keys = [k for k in hf_state_dict.keys() if k.startswith("vision_tower.")] + for key in vision_keys: + # Replace vision_tower with vision_encoder.pixtral_vision in the key path + new_key = key.replace("vision_tower", "vision_encoder.pixtral_vision") + state_dict[new_key] = hf_state_dict[key] + + # Multi-modal projector + mm_keys = [k for k in hf_state_dict.keys() if k.startswith("multi_modal_projector.")] + for key in mm_keys: + state_dict["vision_encoder." + key] = hf_state_dict[key] + + # TODO figure out how to not hardcode + dims_per_head = 128 + + # NOTE: precompute freqs_cis because must be persisted by default in torchtitan + state_dict["language_model.freqs_cis"] = precompute_freqs_cis( + dims_per_head, + max_seq_len, + text_config.rope_theta, + ) logger.info(f"Writing to DCP at '{output_dir}'") output_dir.mkdir(parents=True, exist_ok=True) - storage_writer = DCP.filesystem.FileSystemWriter(output_dir, thread_count=1) - DCP.save(state_dict, storage_writer=storage_writer, no_dist=True) + storage_writer = DCP.filesystem.FileSystemWriter(output_dir, thread_count=8) + + DCP.save({"model": state_dict}, storage_writer=storage_writer) tokenizer_dir = output_dir / "tokenizer" tokenizer_dir.mkdir(parents=True, exist_ok=True) tok.save_pretrained(tokenizer_dir) if __name__ == "__main__": - init_logger() - parser = argparse.ArgumentParser( - description="Convert DeepSeek V3 weights to DCP format." 
- ) + parser = argparse.ArgumentParser(description="Convert Mistral3 weights to DCP format.") + parser.add_argument("mistral_model", type=Path, help="HF Model in Mistral3 format") + parser.add_argument("output_dir", type=Path, help="Output directory for DCP.") parser.add_argument( - "deepseek_model", type=str, help="HF Model in DeepSeek V3 format" + "--max_seq_len", + type=int, + default=131072, + help="The maximum sequence length of the model.", ) - parser.add_argument("output_dir", type=Path, help="Output directory for DCP.") args = parser.parse_args() - convert_deepseekv3_weights(args.deepseek_model, args.output_dir) + convert_mistral3_weights( + args.mistral_model, args.output_dir, max_seq_len=args.max_seq_len + ) diff --git a/scripts/convert_mistral_dcp_to_hf.py b/scripts/convert_mistral_dcp_to_hf.py new file mode 100644 index 0000000000..6868965a87 --- /dev/null +++ b/scripts/convert_mistral_dcp_to_hf.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +from pathlib import Path + +import torch +import torch.distributed.checkpoint +import torch.distributed.checkpoint.format_utils + +from torchtitan.tools.logging import init_logger, logger + +from transformers import AutoTokenizer, AutoConfig, Mistral3ForConditionalGeneration# noqa F401 + + + + + +def permute_1d(w: torch.Tensor, n_heads: int, dim1: int) -> torch.Tensor: + return w.view(n_heads, dim1 // n_heads // 2, 2).transpose(1, 2).reshape(dim1) + + + + +# permute for sliced rotary +def permute(w, n_heads, dim1, dim2): + print("n_heads ", n_heads) + print("dim1 ", dim1) + print("dim2 ", dim2) + return ( + w.view(n_heads, dim1 // n_heads // 2, 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + ) + + +# And reversed +def reverse_permute(w, n_heads, dim1, dim2): + print("n_heads ", n_heads) + print("dim1 ", dim1) + print("dim2 ", dim2) + return ( + w.view(n_heads, 2, dim1 // n_heads // 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + ) + + +def param_to_hf_processing( + name: str, + param: torch.Tensor, + num_heads: int, + hidden_size: int, + kv_dim: int, + num_kv_heads: int, +): + # language model layers + if name.startswith("language_model.layers."): + # Example: language_model.layers.{L}.ln_attn.weight + if name.endswith(".ln_attn.weight"): + out_name = name.replace( + "language_model.layers.", "model.language_model.layers." + ).replace(".ln_attn.weight", ".input_layernorm.weight") + return out_name, param + if name.endswith(".ln_mlp.weight"): + out_name = name.replace( + "language_model.layers.", "model.language_model.layers." + ).replace(".ln_mlp.weight", ".post_attention_layernorm.weight") + return out_name, param + + # Attention projections + if ".attn.wq.weight" in name: + out_name = name.replace( + "language_model.layers.", "model.language_model.layers." 
+ ).replace(".attn.wq.weight", ".self_attn.q_proj.weight") + #out_param = permute(param.detach(), num_heads, 4096, hidden_size) + out_param = reverse_permute(param.detach(), 32, 4096, 5120) + return out_name, out_param + if ".attn.wk.weight" in name: + out_name = name.replace( + "language_model.layers.", "model.language_model.layers." + ).replace(".attn.wk.weight", ".self_attn.k_proj.weight") + #out_param = permute(param.detach(), num_kv_heads, kv_dim, hidden_size) + out_param = reverse_permute(param.detach(), 8, 1024, 5120) + return out_name, out_param + if ".attn.wv.weight" in name: + out_name = name.replace( + "language_model.layers.", "model.language_model.layers." + ).replace(".attn.wv.weight", ".self_attn.v_proj.weight") + return out_name, param + if ".attn.wo.weight" in name: + out_name = name.replace( + "language_model.layers.", "model.language_model.layers." + ).replace(".attn.wo.weight", ".self_attn.o_proj.weight") + return out_name, param + + # Optional q_norm / k_norm 1D weights (not present in current Mistral3 export, but safe to handle) + if ".attn.q_norm.weight" in name: + out_name = name.replace( + "language_model.layers.", "model.language_model.layers." + ).replace(".attn.q_norm.weight", ".self_attn.q_norm.weight") + out_param = permute_1d(param.detach(), 1, param.shape[0]) + return out_name, out_param + if ".attn.k_norm.weight" in name: + out_name = name.replace( + "language_model.layers.", "model.language_model.layers." + ).replace(".attn.k_norm.weight", ".self_attn.k_norm.weight") + out_param = permute_1d(param.detach(), 1, param.shape[0]) + return out_name, out_param + + # MLP + if ".mlp.w1.weight" in name: + out_name = name.replace( + "language_model.layers.", "model.language_model.layers." + ).replace(".mlp.w1.weight", ".mlp.gate_proj.weight") + return out_name, param + if ".mlp.w2.weight" in name: + out_name = name.replace( + "language_model.layers.", "model.language_model.layers." 
+ ).replace(".mlp.w2.weight", ".mlp.down_proj.weight") + return out_name, param + if ".mlp.w3.weight" in name: + out_name = name.replace( + "language_model.layers.", "model.language_model.layers." + ).replace(".mlp.w3.weight", ".mlp.up_proj.weight") + return out_name, param + + # language model top-level components + if name == "language_model.norm.weight": + return "model.language_model.norm.weight", param + if name == "language_model.tok_embeddings.weight": + return "model.language_model.embed_tokens.weight", param + if name == "language_model.output.weight": + return "lm_head.weight", param + + # Vision tower and multimodal projector: keep the same keys (HF uses the same prefix here in our export) + if name.startswith("vision_tower."): + return name.replace("vision_tower.", "model.vision_tower."), param + if name.startswith("multi_modal_projector."): + return name.replace("multi_modal_projector.", "model.multi_modal_projector."), param + + # Unknown/unhandled: return None to skip + return None, None + + +@torch.inference_mode() +def convert_mistral_weights(input_dir: Path, output_dir: Path, model_base: str, tokenizer: str | None) -> None: + hf_model = Mistral3ForConditionalGeneration.from_pretrained(model_base) + tok = AutoTokenizer.from_pretrained(tokenizer if tokenizer else model_base) + config: AutoConfig = hf_model.config + + # Mistral3 models store text config under text_config; fall back to top-level for generic models + text_cfg = getattr(config, "text_config", config) + hidden_size: int = int(text_cfg.hidden_size) + num_heads: int = int(text_cfg.num_attention_heads) + num_kv_heads: int = int(getattr(text_cfg, "num_key_value_heads", num_heads)) + dims_per_head: int = 128 #hidden_size // num_heads + kv_dim: int = dims_per_head * num_kv_heads + + hf_state_dict = hf_model.state_dict() + + logger.info(f"Loading TorchTitan Mistral DCP weights from {input_dir}") + sd: dict[str, torch.Tensor] = {} + torch.distributed.checkpoint.format_utils._load_state_dict( + 
sd, + torch.distributed.checkpoint.filesystem.FileSystemReader(input_dir), + planner=torch.distributed.checkpoint.format_utils._EmptyStateDictLoadPlanner(), + no_dist=True, + ) + + # Some checkpoints might nest parameters under 'model' + if "model" in sd: + sd = sd["model"] + + skipped = {"language_model.freqs_cis", "train_state", "optimizer", "dataloader", "lr_scheduler"} + + for name, param in sd.items(): + if any(name == s or name.startswith(s + ".") for s in skipped): + continue + out_name, out_param = param_to_hf_processing( + name, param, num_heads, hidden_size, kv_dim, num_kv_heads + ) + if out_name is None: + logger.debug(f"Skipping unrecognized key: {name}") + continue + hf_state_dict[out_name] = out_param + logger.info(f"Converted {name} -> {out_name}") + + # Load updated weights into the HF model + + print("my state dict", hf_state_dict.keys()) + print("model state dict", hf_model.state_dict().keys()) + hf_model.load_state_dict(hf_state_dict) + + # Save in bf16 + hf_model = hf_model.to(dtype=torch.bfloat16) + hf_model.save_pretrained(output_dir) + tok.save_pretrained(output_dir) + + +if __name__ == "__main__": + init_logger() + parser = argparse.ArgumentParser(description="Convert TorchTitan Mistral DCP weights to HF format.") + parser.add_argument("input_dir", type=Path, help="Input directory containing DCP checkpoint") + parser.add_argument("output_dir", type=Path, help="Output directory for HF model") + parser.add_argument("mistral_model", type=str, help="Base HF model to load config and architecture from") + parser.add_argument("--tokenizer", type=str, help="Optional tokenizer source; defaults to base model", default=None) + args = parser.parse_args() + + convert_mistral_weights(args.input_dir, args.output_dir, args.mistral_model, args.tokenizer) + + diff --git a/scripts/convert_mistral_hf_to_dcp.py b/scripts/convert_mistral_hf_to_dcp.py new file mode 100644 index 0000000000..23d30aa75a --- /dev/null +++ b/scripts/convert_mistral_hf_to_dcp.py @@ -0,0 
+1,200 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from pathlib import Path + +import torch +import torch.distributed.checkpoint as DCP + +from torchtitan.tools.logging import logger +from torchtitan.models.mistral3.model.model import precompute_freqs_cis + +from transformers import AutoTokenizer, Mistral3Config + + +# permute for sliced rotary +def permute(w, n_heads, dim1, dim2): + return ( + w.view(n_heads, dim1 // n_heads // 2, 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + ) + + +# And reversed +def reverse_permute(w, n_heads, dim1, dim2): + return ( + w.view(n_heads, 2, dim1 // n_heads // 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + ) + + +@torch.inference_mode() +def convert_mistral3_weights(mistral_model, output_dir, max_seq_len: int): + # Loading the model directly might be too large, so we'll use safetensors to load the weights + from safetensors import safe_open + import os + import json + + # Load the config + config_path = os.path.join(mistral_model, "config.json") + with open(config_path, "r") as f: + config_dict = json.load(f) + + config = Mistral3Config.from_dict(config_dict) + tok = AutoTokenizer.from_pretrained(mistral_model) + + # Find all safetensors files + index_path = os.path.join(mistral_model, "model.safetensors.index.json") + if os.path.exists(index_path): + with open(index_path, "r") as f: + index = json.load(f) + weight_map = index["weight_map"] + all_files = set(weight_map.values()) + else: + # If no index, look for a single safetensors file + safetensors_files = [f for f in os.listdir(mistral_model) if f.endswith(".safetensors")] + if len(safetensors_files) == 1: + all_files = [safetensors_files[0]] + else: + raise ValueError("Multiple safetensors files found without an index file") + + # Extract language model parameters from the 
text config + text_config = config.text_config + n_layers = text_config.num_hidden_layers + n_heads = text_config.num_attention_heads + dim = text_config.hidden_size + dims_per_head = dim // n_heads + + logger.info(f"Loading original Mistral3 weights from {mistral_model}") + + state_dict = {} + n_heads_per_shard = n_heads + num_key_value_heads = text_config.num_key_value_heads + n_kv_heads_per_shard = num_key_value_heads + key_value_dim = dims_per_head * num_key_value_heads + + # Load and process weights + hf_state_dict = {} + for filename in all_files: + filepath = os.path.join(mistral_model, filename) + with safe_open(filepath, framework="pt") as f: + for key in f.keys(): + hf_state_dict[key] = f.get_tensor(key) + + # Process language model layers + for layer in range(n_layers): + # Map from HF to torchtitan structure + # Based on the keys we found in the checkpoint + + # Norm layers + state_dict[f"language_model.layers.{layer}.ln_attn.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.input_layernorm.weight" + ] + state_dict[f"language_model.layers.{layer}.ln_mlp.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.post_attention_layernorm.weight" + ] + + # Attention layers + for wn, hn, nh in [ + ("wq", "q_proj", n_heads_per_shard), + ("wk", "k_proj", n_kv_heads_per_shard), + ("wv", "v_proj", n_kv_heads_per_shard), + ]: + if wn != "wv": + # Need to reverse the permutation for sliced rotary + + state_dict[f"language_model.layers.{layer}.attn.{wn}.weight"] = reverse_permute( + hf_state_dict[f"language_model.model.layers.{layer}.self_attn.{hn}.weight"], + n_heads if wn == "wq" else num_key_value_heads, + dim1=4096 if wn == "wq" else int(key_value_dim*0.8), + dim2=dim, + ) + else: + + state_dict[f"language_model.layers.{layer}.attn.{wn}.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.self_attn.{hn}.weight" + ] + + state_dict[f"language_model.layers.{layer}.attn.wo.weight"] = hf_state_dict[ + 
f"language_model.model.layers.{layer}.self_attn.o_proj.weight" + ] + + # Feed-forward layers + state_dict[f"language_model.layers.{layer}.mlp.w1.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.mlp.gate_proj.weight" + ] + state_dict[f"language_model.layers.{layer}.mlp.w2.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.mlp.down_proj.weight" + ] + state_dict[f"language_model.layers.{layer}.mlp.w3.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.mlp.up_proj.weight" + ] + + # Language model norm and embeddings + state_dict["language_model.norm.weight"] = hf_state_dict["language_model.model.norm.weight"] + + # Handling embeddings + state_dict["language_model.tok_embeddings.weight"] = hf_state_dict["language_model.model.embed_tokens.weight"] + # If fusion embedding exists in the HF model + + state_dict["language_model.output.weight"] = hf_state_dict["language_model.lm_head.weight"] + + # Vision tower components + if "vision_tower.ln_pre.weight" in hf_state_dict: + # Copy over vision tower weights, restructuring to put them under model.vision_encoder.pixtral_vision + vision_keys = [k for k in hf_state_dict.keys() if k.startswith("vision_tower.")] + for key in vision_keys: + state_dict[key] = hf_state_dict[key] + # # Replace vision_tower with vision_encoder.pixtral_vision in the key path + # new_key = key.replace("vision_tower", "vision_encoder.pixtral_vision") + # state_dict[new_key] = hf_state_dict[key] + #state_dict[''] + + # Multi-modal projector + mm_keys = [k for k in hf_state_dict.keys() if k.startswith("multi_modal_projector.")] + for key in mm_keys: + state_dict[ key] = hf_state_dict[key] + + # TODO figure out how to not hardcode + dims_per_head = 128 + + # NOTE: precompute freqs_cis because must be persisted by default in torchtitan + state_dict["language_model.freqs_cis"] = precompute_freqs_cis( + dims_per_head, + max_seq_len, + text_config.rope_theta, + ) + + print(state_dict.keys()) + + logger.info(f"Writing 
to DCP at '{output_dir}'") + output_dir.mkdir(parents=True, exist_ok=True) + storage_writer = DCP.filesystem.FileSystemWriter(output_dir, thread_count=8) + + DCP.save(state_dict, storage_writer=storage_writer) + tokenizer_dir = output_dir / "tokenizer" + tokenizer_dir.mkdir(parents=True, exist_ok=True) + tok.save_pretrained(tokenizer_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Mistral3 weights to DCP format.") + parser.add_argument("mistral_model", type=Path, help="HF Model in Mistral3 format") + parser.add_argument("output_dir", type=Path, help="Output directory for DCP.") + parser.add_argument( + "--max_seq_len", + type=int, + default=131072, + help="The maximum sequence length of the model.", + ) + args = parser.parse_args() + + convert_mistral3_weights( + args.mistral_model, args.output_dir, max_seq_len=args.max_seq_len + ) diff --git a/scripts/convert_mistral_hf_to_dcp_2.py b/scripts/convert_mistral_hf_to_dcp_2.py new file mode 100644 index 0000000000..ef544345ff --- /dev/null +++ b/scripts/convert_mistral_hf_to_dcp_2.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +from pathlib import Path + +import torch +import torch.distributed.checkpoint as DCP + +from torchtitan.tools.logging import logger +from torchtitan.models.mistral3.model.model import precompute_freqs_cis + +from transformers import AutoTokenizer, Mistral3Config + + +# permute for sliced rotary +def permute(w, n_heads, dim1, dim2): + return ( + w.view(n_heads, dim1 // n_heads // 2, 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + ) + + +# And reversed +def reverse_permute(w, n_heads, dim1, dim2): + return ( + w.view(n_heads, 2, dim1 // n_heads // 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + ) + + +@torch.inference_mode() +def convert_mistral3_weights(mistral_model, output_dir, max_seq_len: int): + # Loading the model directly might be too large, so we'll use safetensors to load the weights + from safetensors import safe_open + import os + import json + + # Load the config + config_path = os.path.join(mistral_model, "config.json") + with open(config_path, "r") as f: + config_dict = json.load(f) + + config = Mistral3Config.from_dict(config_dict) + tok = AutoTokenizer.from_pretrained(mistral_model) + + # Find all safetensors files + index_path = os.path.join(mistral_model, "model.safetensors.index.json") + if os.path.exists(index_path): + with open(index_path, "r") as f: + index = json.load(f) + weight_map = index["weight_map"] + all_files = set(weight_map.values()) + else: + # If no index, look for a single safetensors file + safetensors_files = [f for f in os.listdir(mistral_model) if f.endswith(".safetensors")] + if len(safetensors_files) == 1: + all_files = [safetensors_files[0]] + else: + raise ValueError("Multiple safetensors files found without an index file") + + # Extract language model parameters from the text config + text_config = config.text_config + n_layers = text_config.num_hidden_layers + n_heads = text_config.num_attention_heads + dim = text_config.hidden_size + dims_per_head = dim // n_heads + + logger.info(f"Loading 
original Mistral3 weights from {mistral_model}") + + state_dict = {} + n_heads_per_shard = n_heads + num_key_value_heads = text_config.num_key_value_heads + n_kv_heads_per_shard = num_key_value_heads + key_value_dim = dims_per_head * num_key_value_heads + + # Load and process weights + hf_state_dict = {} + for filename in all_files: + filepath = os.path.join(mistral_model, filename) + with safe_open(filepath, framework="pt") as f: + for key in f.keys(): + hf_state_dict[key] = f.get_tensor(key) + + # Process language model layers + for layer in range(n_layers): + # Map from HF to torchtitan structure + # Based on the keys we found in the checkpoint + + # Norm layers + state_dict[f"language_model.layers.{layer}.ln_attn.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.input_layernorm.weight" + ] + state_dict[f"language_model.layers.{layer}.ln_mlp.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.post_attention_layernorm.weight" + ] + + # Attention layers + for wn, hn, nh in [ + ("wq", "q_proj", n_heads_per_shard), + ("wk", "k_proj", n_kv_heads_per_shard), + ("wv", "v_proj", n_kv_heads_per_shard), + ]: + if wn != "wv": + # Need to reverse the permutation for sliced rotary + + state_dict[f"language_model.layers.{layer}.attn.{wn}.weight"] = reverse_permute( + hf_state_dict[f"language_model.model.layers.{layer}.self_attn.{hn}.weight"], + n_heads if wn == "wq" else num_key_value_heads, + dim1=4096 if wn == "wq" else int(key_value_dim*0.8), + dim2=dim, + ) + else: + + state_dict[f"language_model.layers.{layer}.attn.{wn}.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.self_attn.{hn}.weight" + ] + + state_dict[f"language_model.layers.{layer}.attn.wo.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.self_attn.o_proj.weight" + ] + + # Feed-forward layers + state_dict[f"language_model.layers.{layer}.mlp.w1.weight"] = hf_state_dict[ + f"language_model.model.layers.{layer}.mlp.gate_proj.weight" + ] + 
state_dict[f"language_model.layers.{layer}.mlp.w2.weight"] = hf_state_dict[
+            f"language_model.model.layers.{layer}.mlp.down_proj.weight"
+        ]
+        state_dict[f"language_model.layers.{layer}.mlp.w3.weight"] = hf_state_dict[
+            f"language_model.model.layers.{layer}.mlp.up_proj.weight"
+        ]
+
+    # Language model norm and embeddings
+    state_dict["language_model.norm.weight"] = hf_state_dict["language_model.model.norm.weight"]
+
+    # Handling embeddings
+    state_dict["language_model.tok_embeddings.weight"] = hf_state_dict["language_model.model.embed_tokens.weight"]
+    # If fusion embedding exists in the HF model
+
+    state_dict["language_model.output.weight"] = hf_state_dict["language_model.lm_head.weight"]
+
+
+
+    # Vision tower components
+    if "vision_tower.ln_pre.weight" in hf_state_dict:
+        # Copy vision tower weights through under their original key prefix
+        vision_keys = [k for k in hf_state_dict.keys() if k.startswith("vision_tower.")]
+        for key in vision_keys:
+            state_dict[key] = hf_state_dict[key]
+            # # Replace vision_tower with vision_encoder.pixtral_vision in the key path
+            # new_key = key.replace("vision_tower", "vision_encoder.pixtral_vision")
+            # state_dict[new_key] = hf_state_dict[key]
+            #state_dict['']
+
+    # Multi-modal projector
+    mm_keys = [k for k in hf_state_dict.keys() if k.startswith("multi_modal_projector.")]
+    for key in mm_keys:
+        state_dict[key] = hf_state_dict[key]
+
+    # Use the model's actual head dim from the config instead of hardcoding 128
+    dims_per_head = getattr(text_config, "head_dim", None) or dim // n_heads
+
+    # NOTE: precompute freqs_cis because must be persisted by default in torchtitan
+    state_dict["language_model.freqs_cis"] = precompute_freqs_cis(
+        dims_per_head,
+        max_seq_len,
+        text_config.rope_theta,
+    )
+
+    # replace all language_model. with model.language_model.
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        if "language_model." in key:
+            new_key = key.replace("language_model.", "model.language_model.")
+            new_state_dict[new_key] = value
+        elif "vision_tower." 
in key: + new_key = key.replace("vision_tower.", "model.vision_tower.") + new_state_dict[new_key] = value + elif "multi_modal_projector." in key: + new_key = key.replace("multi_modal_projector.", "model.multi_modal_projector.") + new_state_dict[new_key] = value + else: + new_state_dict[key] = value + + state_dict = new_state_dict + + + logger.info(f"Writing to DCP at '{output_dir}'") + output_dir.mkdir(parents=True, exist_ok=True) + storage_writer = DCP.filesystem.FileSystemWriter(output_dir, thread_count=8) + + DCP.save(state_dict, storage_writer=storage_writer) + tokenizer_dir = output_dir / "tokenizer" + tokenizer_dir.mkdir(parents=True, exist_ok=True) + tok.save_pretrained(tokenizer_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Mistral3 weights to DCP format.") + parser.add_argument("mistral_model", type=Path, help="HF Model in Mistral3 format") + parser.add_argument("output_dir", type=Path, help="Output directory for DCP.") + parser.add_argument( + "--max_seq_len", + type=int, + default=131072, + help="The maximum sequence length of the model.", + ) + args = parser.parse_args() + + convert_mistral3_weights( + args.mistral_model, args.output_dir, max_seq_len=args.max_seq_len + ) diff --git a/scripts/generate/__pycache__/_generation.cpython-311.pyc b/scripts/generate/__pycache__/_generation.cpython-311.pyc new file mode 100644 index 0000000000..405d0d5c98 Binary files /dev/null and b/scripts/generate/__pycache__/_generation.cpython-311.pyc differ diff --git a/scripts/generate/__pycache__/_vision_generation.cpython-311.pyc b/scripts/generate/__pycache__/_vision_generation.cpython-311.pyc new file mode 100644 index 0000000000..bc519c0016 Binary files /dev/null and b/scripts/generate/__pycache__/_vision_generation.cpython-311.pyc differ diff --git a/scripts/generate/__pycache__/test_generate.cpython-311.pyc b/scripts/generate/__pycache__/test_generate.cpython-311.pyc new file mode 100644 index 0000000000..8a23057627 
Binary files /dev/null and b/scripts/generate/__pycache__/test_generate.cpython-311.pyc differ diff --git a/scripts/generate/_generation.py b/scripts/generate/_generation.py index 6cd3a844d7..09f6bf238b 100644 --- a/scripts/generate/_generation.py +++ b/scripts/generate/_generation.py @@ -77,4 +77,4 @@ def generate( generated_tokens = torch.cat([generated_tokens, next_token], dim=1) - return generated_tokens + return generated_tokens \ No newline at end of file diff --git a/scripts/generate/_vision_generation.py b/scripts/generate/_vision_generation.py new file mode 100644 index 0000000000..0c79911fcb --- /dev/null +++ b/scripts/generate/_vision_generation.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + + +def multinomial_sample_one( + probs: torch.Tensor, rng: Optional[torch.Generator] = None +) -> torch.Tensor: + q = torch.empty_like(probs).exponential_(1, generator=rng) + return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long) + + +def logits_to_probs( + logits: torch.Tensor, + temperature: float = 1.0, + top_k: Optional[int] = None, +) -> torch.Tensor: + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) + pivot = v.select(dim=-1, index=-1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def generate_next_token( + model, + x: torch.Tensor, + *, + images: Optional[list] = None, + image_features: Optional[torch.Tensor] = None, + temperature: float = 1.0, + top_k: Optional[int] = None, + rng: Optional[torch.Generator] = None, +) -> torch.Tensor: + logits = model(x, images=images, image_features=image_features) # (B, T, vocab_size) + probs = 
logits_to_probs(logits[:, -1, :], temperature, top_k) + next_token = multinomial_sample_one(probs, rng=rng) + return next_token + + +@torch.no_grad() +def generate( + model, + input_ids: torch.Tensor, + *, + images: Optional[torch.Tensor] = None, + image_features: Optional[torch.Tensor] = None, + max_new_tokens: int, + temperature: float = 1.0, + top_k: Optional[int] = None, + seed: Optional[int] = None, +) -> torch.Tensor: + # ensure batch dimension (T,) --> (B, T) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + rng = None + if seed is not None: + rng = torch.Generator(input_ids.device).manual_seed(seed) + + generated_tokens = input_ids.clone() + + for _ in range(max_new_tokens): + next_token = generate_next_token( + model, + images=images, + image_features=image_features, + x=generated_tokens, + temperature=temperature, + top_k=top_k, + rng=rng, + ) + + generated_tokens = torch.cat([generated_tokens, next_token], dim=1) + + return generated_tokens diff --git a/scripts/generate/run_llama_generate.sh b/scripts/generate/run_llama_generate.sh index 49b70535dd..dbd6f1b0e6 100755 --- a/scripts/generate/run_llama_generate.sh +++ b/scripts/generate/run_llama_generate.sh @@ -41,4 +41,4 @@ torchrun --standalone \ --config="${CONFIG_FILE}" \ --checkpoint="${CHECKPOINT_DIR}" \ --prompt="${PROMPT}" \ - "${overrides[@]}" + "${overrides[@]}" \ No newline at end of file diff --git a/scripts/generate/run_vision_mistral_generate.sh b/scripts/generate/run_vision_mistral_generate.sh new file mode 100755 index 0000000000..7bef6322a2 --- /dev/null +++ b/scripts/generate/run_vision_mistral_generate.sh @@ -0,0 +1,44 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +# use envs as local overrides for convenience +# e.g. 
+# LOG_RANK=0,1 NGPU=4 ./run_llama_generate.sh +NGPU=${NGPU:-"1"} +LOG_RANK=${LOG_RANK:-0} +CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} +CHECKPOINT_DIR=${CHECKPOINT_DIR:-"./outputs/checkpoint/"} +PROMPT=${PROMPT:-""} + +overrides=() +if [ $# -ne 0 ]; then + for arg in "$@"; do + # special case to handle prompt in quotes + if [[ "$arg" == --prompt=* ]]; then + PROMPT="${arg#--prompt=}" + # check if file + if [[ -f "$PROMPT" ]]; then + PROMPT=$(<"$PROMPT") + fi + else + # handle other args + overrides+=("$arg") + fi + done +fi + +set -x +torchrun --standalone \ + --nproc_per_node="${NGPU}" \ + --local-ranks-filter="${LOG_RANK}" \ + scripts/generate/test_vision_generate.py \ + --config="${CONFIG_FILE}" \ + --checkpoint="${CHECKPOINT_DIR}" \ + --prompt="${PROMPT}" \ + "${overrides[@]}" diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index 9207a01416..a55a80219e 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -315,4 +315,4 @@ def test_generate( ) if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() + torch.distributed.destroy_process_group() \ No newline at end of file diff --git a/scripts/generate/test_vision_generate.py b/scripts/generate/test_vision_generate.py new file mode 100644 index 0000000000..066944ebec --- /dev/null +++ b/scripts/generate/test_vision_generate.py @@ -0,0 +1,367 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import json +import os +import sys +import time +from pathlib import Path + +from typing import Optional + +import torch +import torch.distributed.checkpoint as dcp +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed._tensor import Replicate +from torch.distributed.elastic.multiprocessing.errors import record +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + RowwiseParallel, +) +from transformers import AutoProcessor, AutoModelForImageTextToText + +from torchtitan.tools import utils + + + +from torchtitan.components.checkpoint import excluded_parameters_for_model_only +from torchtitan.components.metrics import build_device_memory_monitor +from torchtitan.config_manager import ConfigManager +from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.protocols.train_spec import get_train_spec +from torchtitan.tools import utils +from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.utils import device_module, device_type + +from transformers import AutoProcessor + +from PIL import Image +import requests + +# support running w/o installing as package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from generate._vision_generation import generate + + +def apply_tp_minus_sp(model: nn.Module, tp_mesh: DeviceMesh): + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel(input_layouts=Replicate()), + "output": ColwiseParallel(output_layouts=Replicate()), + }, + ) + + + for _, transformer_block in model.language_model.layers.items(): + layer_plan = { + "attn.wq": ColwiseParallel(), + "attn.wk": ColwiseParallel(), + "attn.wv": ColwiseParallel(), + "attn.wo": RowwiseParallel(), + "mlp.w1": ColwiseParallel(), + "mlp.w2": RowwiseParallel(), + "mlp.w3": ColwiseParallel(), + } + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + 
parallelize_plan=layer_plan,
+        )
+
+
+@record
+def test_generate(
+    config_path: str,
+    checkpoint_path: str,
+    prompt: str,
+    *,
+    temperature: float = 1.0,
+    max_new_tokens: int = 32,
+    batch_size: int = 1,
+    top_k: Optional[int] = None,
+    seed: Optional[int] = None,
+    deterministic: bool = False,
+):
+    init_logger()
+    color = utils.Color
+
+    # Load configuration from toml file
+    config_manager = ConfigManager()
+    config = config_manager.parse_args([f"--job.config_file={config_path}"])
+
+    # Check the function parameter, not the module-global `args` (NameError on import)
+    if len(prompt) == 0:
+        logger.warning(
+            "The input prompt is empty, model will respond from a empty sequence."
+        )
+
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    device = torch.device(f"{device_type}:{local_rank}")
+    device_module.set_device(device)
+    device_memory_monitor = build_device_memory_monitor()
+
+    train_spec = get_train_spec(config.model.name)
+
+    logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}")
+
+    # Tokenizer setup
+
+
+    #tokenizer = train_spec.tokenizer_cls(config.model.tokenizer_path)
+    tokenizer = train_spec.build_tokenizer_fn(config)
+
+    model_args = train_spec.model_args[config.model.flavor]
+    model_args.update_from_config(config, tokenizer)
+
+    init_device = "meta" if world_size > 1 else device
+    with torch.device(init_device):
+        logger.info(f"Init model on init_device: {init_device}")
+        model = train_spec.model_cls(model_args)
+
+    world_mesh = None
+    # Init distributed env
+    if world_size > 1:
+        dist_utils.init_distributed(config)
+        parallel_dims = ParallelDims(
+            dp_replicate=1,
+            dp_shard=-1,
+            cp=1,
+            tp=world_size,
+            pp=1,
+            ep=1,
+            world_size=world_size
+        )
+        # Build world mesh for parallelism
+        world_mesh = parallel_dims.world_mesh
+
+        # apply_tp (with Sequence Parallel) on unevenly sharded
+        # sequences would require https://github.com/pytorch/torchtitan/pull/686
+        apply_tp_minus_sp(model, world_mesh["tp"])
+
+    dist_utils.set_determinism(world_mesh, 
device, seed, deterministic) + + # materalize model + model.to_empty(device=device_type) + #with torch.no_grad(): + # model.init_weights() + #model.eval() + + + state_dict = model.state_dict() + + #state_dict = {"model": model.state_dict()} + + # Checkpoint Loading + begin = time.monotonic() + logger.info(f"Loading chkpt at: {checkpoint_path}") + + dcp.load(state_dict, checkpoint_id=checkpoint_path) + logger.info(f"Finished loading chkpt in {time.monotonic() - begin:.2f} seconds.") + + device_mem_stats = device_memory_monitor.get_peak_stats() + logger.info( + f"{device_type.upper()} memory usage for model: " + f"{device_mem_stats.max_reserved_gib:.2f}GiB" + f"({device_mem_stats.max_reserved_pct:.2f}%)" + ) + + processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}, + {"type": "text", "text": prompt}, + + ], + } + ] + + image = Image.open(requests.get(url, stream=True).raw) + + inputs = processor.apply_chat_template(messages, tokenize=True, return_dict=True, return_tensors="pt").to(device_type, dtype=torch.bfloat16) + #tokenized = tokenizer.apply_chat_template(conversation, tokenize=True, return_dict=True, return_tensors="pt") + + torch_device='cuda:0' + hf_model = AutoModelForImageTextToText.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503", device_map=torch_device, torch_dtype=torch.bfloat16) + + #print(hf_model.vision_tower.config) + #print(hf_model.vision_tower.patch_conv.weight) + #print(hf_model.multi_modal_projector.linear_1.weight) + + pixel_values = inputs["pixel_values"].to(device_type) + image_sizes = inputs["image_sizes"] + input_ids = inputs['input_ids'].to(device_type) + + #import copy + + #model.vision_tower = copy.deepcopy(model.vision_tower).to('cuda:0', dtype=torch.bfloat16) + 
#model.multi_modal_projector = copy.deepcopy(model.multi_modal_projector).to('cuda:0', dtype=torch.bfloat16) + + #hf_model.vision_tower = model.vision_tower.to('cuda:0', dtype=torch.bfloat16) + #hf_model.multi_modal_projector = model.multi_modal_projector.to('cuda:0', dtype=torch.bfloat16) + + #model.vision_tower = hf_model.vision_tower + #model.multi_modal_projector = hf_model.multi_modal_projector + + #image_features = [model.get_image_features(pixel_values=pixel_values, vision_feature_layer=-1, image_sizes=image_sizes)[0].to('cuda:0', dtype=torch.float32)] + + #print(image_features[0].dtype) + #exit(0) + + #print(f"original image sizes: {image_sizes}") + #print(f"original pixel values: {pixel_values.shape}") + + images = ([image],) + #images = set() + + device_memory_monitor.reset_peak_stats() + + # Run generation + t0 = time.monotonic() + responses = generate( + model, + input_ids, + temperature=temperature, + max_new_tokens=max_new_tokens, + #image_features=image_features, + #images=(pixel_values, image_sizes), + images=images, + top_k=top_k, + seed=seed, + ) + t1 = time.monotonic() + elapsed_sec = t1 - t0 + + # Post process + B, T = responses.size() # B: batch_size, T: total seq length + input_n_tokens = input_ids.size(1) + generated_n_tokens = T - input_n_tokens # == max_new_tokens + + if local_rank == 0: + logger.info(f"Generation completed in {elapsed_sec:.2f} seconds.") + + r, b = color.red, color.blue + + output_data = { + "metadata": {}, + "responses": [], + } + + for i, tokens in enumerate(responses): + inp_tok = tokens[:input_n_tokens].tolist() + out_tok = tokens[input_n_tokens:].tolist() + + #input_text = tokenizer.decode(inp_tok) + #output_text = tokenizer.decode(out_tok) + input_text = processor.tokenizer.decode(inp_tok) + output_text = processor.tokenizer.decode(out_tok) + + _data = { + "response_idx": i, + "input_text": input_text, + "output_text": output_text, + } + output_data["responses"].append(_data) + + 
logger.info(f"{r}\n{input_text}{b}{output_text}\n{color.reset}") + + device_mem_stats = device_memory_monitor.get_peak_stats() + output_data["metadata"] = { + "generated_n_tokens": generated_n_tokens, + "input_n_tokens": input_n_tokens, + "generation_time_sec": elapsed_sec, + "tokens_per_sec": (B * T) / elapsed_sec, + "batch_size": B, + "seed": seed, + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), + "memory/max_active(GiB)": device_mem_stats.max_active_gib, + "memory/max_active(%)": device_mem_stats.max_active_pct, + "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib, + "memory/max_reserved(%)": device_mem_stats.max_reserved_pct, + "memory/num_alloc_retries": device_mem_stats.num_alloc_retries, + "memory/num_ooms": device_mem_stats.num_ooms, + "world_size": world_size, + "torch_version": torch.__version__, + } + + if args.out: + print(json.dumps(output_data, indent=4)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test generation") + parser.add_argument( + "--config", type=str, required=True, help="TOML config file path (required)" + ) + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Checkpoint path to load (required)", + ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="Sampling temperature. Default is 1.0", + ) + parser.add_argument( + "--max_new_tokens", + type=int, + default=32, + help="Max number of tokens to generate. Default is 32", + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="Number of samples to run in batch" + ) + parser.add_argument( + "--top_k", type=int, help="Prune to select from top_k probabilities. 
Optional" + ) + parser.add_argument("--seed", type=int, help="Random seed for reproducibility") + parser.add_argument( + "--deterministic", + action="store_true", + help="Use deterministic algorithms wherever possible, may be slower", + ) + + parser.add_argument("--prompt", type=str, default="", help="Input prompt") + + parser.add_argument( + "--out", + action="store_true", + default=False, + help="If specified, prints the report to stdout. Defaults to no output.", + ) + + args = parser.parse_args() + + test_generate( + config_path=args.config, + checkpoint_path=args.checkpoint, + prompt=args.prompt, + temperature=args.temperature, + max_new_tokens=args.max_new_tokens, + batch_size=args.batch_size, + top_k=args.top_k, + seed=args.seed, + deterministic=args.deterministic, + ) + + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() diff --git a/scripts/merge_subsets.py b/scripts/merge_subsets.py new file mode 100644 index 0000000000..4dd7d9451a --- /dev/null +++ b/scripts/merge_subsets.py @@ -0,0 +1,20 @@ +import re +from datasets import load_dataset, load_from_disk, concatenate_datasets, Features, Value, ClassLabel, List +from datasets.utils.info_utils import VerificationMode + + +new_features = Features({'conversations': List({'content': List({'path': Value('string'), 'text': Value('string'), 'type': Value('string')}), 'role': Value('string')})}) + + +ds1 = load_from_disk("ChartQA_Subset") +ds2 = load_from_disk("H4_Subset").cast(new_features) + + +ds3 = concatenate_datasets([ds1, ds2]) + +print(ds3[0]) + +ds3.save_to_disk("CombinedDataset") + + + diff --git a/scripts/multimodal.py b/scripts/multimodal.py new file mode 100644 index 0000000000..f3895b732c --- /dev/null +++ b/scripts/multimodal.py @@ -0,0 +1,80 @@ +from datetime import datetime, timedelta +import torch +import time + +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from 
huggingface_hub import hf_hub_download +from transformers import Mistral3ForConditionalGeneration + + +def load_system_prompt(repo_id: str, filename: str) -> str: + file_path = hf_hub_download(repo_id=repo_id, filename=filename) + with open(file_path, "r") as file: + system_prompt = file.read() + today = datetime.today().strftime("%Y-%m-%d") + yesterday = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d") + model_name = repo_id.split("/")[-1] + return system_prompt.format(name=model_name, today=today, yesterday=yesterday) + + +model_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506" +SYSTEM_PROMPT = load_system_prompt(model_id, "SYSTEM_PROMPT.txt") + +tokenizer = MistralTokenizer.from_hf_hub(model_id) + +model = Mistral3ForConditionalGeneration.from_pretrained( + model_id, torch_dtype=torch.bfloat16 +) + +image_url = "https://static.wikia.nocookie.net/essentialsdocs/images/7/70/Battle.png/revision/latest?cb=20220523172438" + +messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What action do you think I should take in this situation? 
List all the possible actions and explain why you think they are good or bad.", + }, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + }, +] + +tokenized = tokenizer.encode_chat_completion(ChatCompletionRequest(messages=messages)) + +input_ids = torch.tensor([tokenized.tokens]) +attention_mask = torch.ones_like(input_ids) + +pixel_values = torch.tensor(tokenized.images).to(dtype=torch.bfloat16) +image_sizes = torch.tensor([[pixel_values.shape[-2], pixel_values.shape[-1]]] * len(tokenized.images)) + + +t1 = time.time() + +for i in range(10): + with torch.no_grad(): # For inference efficiency + image_features = model.get_image_features( + pixel_values=pixel_values, + image_sizes=image_sizes + ) + + print(len(image_features)) + print(image_features[0].shape) + + tensor_size_in_bytes = image_features[0].nelement() * image_features[0].element_size() + print(f"Tensor size: {tensor_size_in_bytes / 1024 / 1024} MB") + +t2 = time.time() +print(f"Time taken: {t2 - t1} seconds") diff --git a/scripts/nvidia.py b/scripts/nvidia.py new file mode 100644 index 0000000000..63979014a5 --- /dev/null +++ b/scripts/nvidia.py @@ -0,0 +1,768 @@ +""" + +post of data processing script for multimodal data +that looks like this: +[ + { + "messages": [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."} + ] + }, + { + "role": "user", + "content": [ + {"type": "image", "url": 
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"}, + {"type": "text", "text": "Describe this image in detail."} + ] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "The image is a bee."} + ] + } + ] + } +] +""" +import argparse +import os +import shutil +import multiprocessing +import numpy as np +import pyarrow as pa +import pyarrow.dataset as pa_ds +import random +import json +import base64 +import uuid +from PIL import Image +import io +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + +from typing import List, Optional, Tuple +from torch.nn import functional as F +from torch.utils.data import Dataset +from tqdm import tqdm +from datasets import load_dataset, Dataset as DatasetsDataset +from transformers import AutoTokenizer + +from datetime import datetime, timedelta +import torch + +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from huggingface_hub import hf_hub_download +from transformers import Mistral3ForConditionalGeneration + + +LOCAL_IMAGE_DIR = "./images" + +SCHEMA = pa.schema( + [ + pa.field("inputs", pa.large_list(pa.int32())), + pa.field("labels", pa.large_list(pa.int32())), + pa.field("position_ids", pa.large_list(pa.int32())), + pa.field("sequence_lengths", pa.large_list(pa.int64())), + pa.field("images", pa.large_list(pa.string())), + ] +) + +DATASET_INFO = r"""{ + "citation": "", + "description": "", + "features": { + "inputs": { + "feature": { + "dtype": "int32", + "_type": "Value" + }, + "_type": "LargeList" + }, + "labels": { + "feature": { + "dtype": "int32", + "_type": "Value" + }, + "_type": "LargeList" + }, + "position_ids": { + "feature": { + "dtype": "int32", + "_type": "Value" + }, + "_type": "LargeList" + }, + "sequence_lengths": { + "feature": { + "dtype": "int64", + "_type": 
"Value"
+            },
+            "_type": "LargeList"
+        },
+        "images": {
+            "feature": {
+                "dtype": "string",
+                "_type": "Value"
+            },
+            "_type": "LargeList"
+        }
+    },
+    "homepage": "",
+    "license": ""
+}"""
+
+
+def process_packing_shard(shard, args, tokenizer_pad_id, rank, world_size):
+    packer = MultimodalPackedDataset(
+        shard,
+        max_seq_len=args.pack_to_sequence_length,
+        padding_idx=tokenizer_pad_id,
+        split_across_pack=not args.chat,
+        show_pbar=rank == 0,
+    )
+    # Default so the return below is valid even when not saving to disk
+    example = None
+    if args.save_to_disk:
+        # create a schema that uses int64 for list sizes
+
+        example = (
+            {
+                "inputs": packer.packs[0]["inputs"],
+                "labels": packer.packs[0]["labels"],
+                "position_ids": packer.packs[0]["position_ids"],
+                "sequence_lengths": packer.packs[0]["sequence_lengths"],
+                "images": packer.packs[0]["images"],
+            }
+            if len(packer.packs) > 0
+            else None
+        )
+
+        oriented_data = {
+            "inputs": [pack["inputs"] for pack in packer.packs],
+            "labels": [pack["labels"] for pack in packer.packs],
+            "position_ids": [pack["position_ids"] for pack in packer.packs],
+            "sequence_lengths": [pack["sequence_lengths"] for pack in packer.packs],
+            "images": [pack["images"] for pack in packer.packs],
+        }
+        pa_table = pa.Table.from_pydict(oriented_data, schema=SCHEMA)
+        del oriented_data
+
+        pa_ds.write_dataset(
+            pa_table,
+            os.path.join(args.save_to_disk, str(rank)),
+            format="arrow",
+        )
+
+        filename = f"data-{rank:05d}-of-{world_size:05d}.arrow"
+
+        shutil.move(
+            os.path.join(args.save_to_disk, str(rank), "part-0.arrow"),
+            os.path.join(args.save_to_disk, filename),
+        )
+
+        os.rmdir(os.path.join(args.save_to_disk, str(rank)))
+    else:
+        filename = None
+
+    return packer.total_tokens, packer.packed_tokens, packer.dropped, filename, example
+
+
+# https://github.com/pytorch/torchtune/blob/9d91fe39f08661952da4180b9e7fb2eba5a7a5e7/torchtune/datasets/_packed.py
+class MultimodalPackedDataset(Dataset):
+    """
+    Performs greedy sample packing on a provided dataset. This is done as a single
+    preprocessing step before training begins. 
Shuffling is done outside of this + class on packed samples with a ``Sampler`` as part of the dataloader. Currently, + this only supports in-memory map-style datasets. + + The class loads, tokenizes, and packs examples on initialization - no tokenization is done during training. + + The general flow on initialization is: load tokenized sample -> add to buffer -> + when buffer is long enough, add to ``self.packs``. + + During training, returns self.packs[idx] as input, label, attention mask, and + position ids. The attention mask is a lower triangular block mask to prevent + samples from cross-attending within a pack. The position ids indicate the position + of each token relative to its sample within a pack. These are all padded to max + sequence length, so a batch-wise collator is not needed. + + A packed sample is made up of individual smaller sequence length samples jammed together + within ``max_seq_len``. For example, if max_seq_len is 6 and there are varied + length samples:: + + tokens = [ + [S1, S1, S1, S2, S2, pad], + [S3, S3, S4, S4, pad, pad], + ..., + ] + + To prevent cross-contamination, the following mask would be returned for the + first pack in the example:: + + mask = [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1], + ] + + The position ids would be:: + + input_pos = [ + [0, 1, 2, 0, 1, 2], + [0, 1, 0, 1, 2, 3], + ..., + ] + + The identity matrix is used in the mask for pad tokens instead of a causal mask. + For position ids for pad tokens, we simply continue to increment from the previous + sample normally. + + Args: + ds (Dataset): dataset to sample pack. This should return a dictionary with field + "tokens" and "labels" containing the tokenized and label samples. + max_seq_len (int): Maximum number of tokens to pack + padding_idx (int): padding index for the tokenizer. Default is 0. + max_packs (Optional[int]): Maximum number of packs. 
Default is None, which will create as many + packs as possible. + split_across_pack (bool): if the last sample in a pack does not fit in ``max_seq_len``, + split the sample into the next pack, or move it entirely to the beginning of the next pack. + For pre-training, typically this is set to True for general text completion. For + fine-tuning, typically this is set to False to avoid truncating sentences in instruct + tuning. Default is False. + """ + + def __init__( + self, + ds: Dataset, + *, + max_seq_len: int, + padding_idx: int = 0, + max_packs: Optional[int] = None, + split_across_pack: bool = False, + group_size: int = 5000, + show_pbar=True, + ) -> None: + self.ds = ds + self.max_seq_len = max_seq_len + self.padding_idx = padding_idx + self.max_packs = max_packs + self.split_across_pack = split_across_pack + self.packs = [] + self.previous_sample_boundary: int = 0 + self.packed_tokens: int = 0 + self.total_tokens: int = 0 + self.dropped: int = 0 + self.show_pbar = show_pbar + self.group_size = group_size + if split_across_pack: + self._pack_greedy() + else: + self._pack_ffd() + + def _get_empty_pack(self): + + return { + "inputs": np.empty(0, dtype=np.int32), + "labels": np.empty(0, dtype=np.int32), + "position_ids": np.empty(0, dtype=np.int32), + "sequence_lengths": [], + "images": [], + } + + def _pack_ffd(self) -> None: + ds_iterator = iter(self.ds) + finished_iterating = False + + pbar = ( + tqdm( + total=len(self.ds), + desc="Packing dataset (FFD)", + dynamic_ncols=True, + ) + if self.show_pbar + else None + ) + + while not finished_iterating: + # 1. Fetch a large group of samples into memory. 
+ group = [] + try: + for _ in range(self.group_size): + sample = next(ds_iterator) + seq_len = len(sample["inputs"]) + + + if seq_len > self.max_seq_len: + self.dropped += 1 + continue + # Store sample and its length for sorting + group.append({"sample": sample, "seq_len": seq_len}) + except StopIteration: + finished_iterating = True + + if not group: + break + + + # 2. Sort the group by length in descending order (the "Decreasing" part of FFD). + group.sort(key=lambda x: x["seq_len"], reverse=True) + + # 3. Pack this group using the "First-Fit" heuristic. + # Each bin holds the samples it contains and its remaining space. + bins = [] # List of {"samples": [], "remaining_space": max_seq_len} + + for item in group: + placed = False + # Try to place the item in the first available bin. + for bin in bins: + if bin["remaining_space"] >= item["seq_len"]: + bin["samples"].append(item["sample"]) + bin["remaining_space"] -= item["seq_len"] + placed = True + break + + # If no existing bin could accommodate the item, create a new one. + if not placed: + bins.append( + { + "samples": [item["sample"]], + "remaining_space": self.max_seq_len - item["seq_len"], + } + ) + + + # 4. Convert the completed bins from this group into final, padded packs. 
+ for bin_info in bins: + if self._should_stop_packing(): + break + + current_pack = self._get_empty_pack() + for sample in bin_info["samples"]: + tokens = np.array(sample["inputs"], dtype=np.int32) + labels = np.array(sample["labels"], dtype=np.int32) + images = sample["images"] + seq_len = len(tokens) + + + current_pack["inputs"] = np.concatenate( + (current_pack["inputs"], tokens) + ) + current_pack["labels"] = np.concatenate( + (current_pack["labels"], labels) + ) + current_pack["position_ids"] = np.concatenate( + ( + current_pack["position_ids"], + np.arange(seq_len, dtype=np.int32), + ) + ) + current_pack["sequence_lengths"].append(seq_len) + current_pack["images"].append(images) + + + self._add_pack(current_pack) + + + if pbar: + pbar.update(len(group)) + + if self._should_stop_packing(): + # Ensure the outer loop breaks if max_packs is reached. + break + + if pbar: + # Manually set pbar to total to show 100% at the end + pbar.n = pbar.total + pbar.refresh() + pbar.close() + + def _pack_greedy(self) -> None: + """Iterate through the dataset. Use a buffer to hold samples until max_seq_len, + then append the buffer to self.packs as a single "packed" sample. 
Continue
+        until max_packs or end of dataset."""
+
+        current_pack = self._get_empty_pack()
+
+        pbar = (
+            tqdm(total=len(self.ds), desc="Packing dataset", dynamic_ncols=True)
+            if self.show_pbar
+            else None
+        )
+
+        for sample in self.ds:
+            tokens = np.array(sample["inputs"], dtype=np.int32)
+            labels = np.array(sample["labels"], dtype=np.int32)
+            seq_len = len(tokens)
+            if seq_len > self.max_seq_len and not self.split_across_pack:
+                # print(
+                #     f"Dropping sample that is too long ({seq_len} > {self.max_seq_len})"
+                # )
+                self.dropped += 1
+                continue
+
+            current_pack["inputs"] = np.concatenate((current_pack["inputs"], tokens))
+            current_pack["labels"] = np.concatenate((current_pack["labels"], labels))
+
+            position_ids = np.arange(seq_len, dtype=np.int32)
+            current_pack["position_ids"] = np.concatenate(
+                (current_pack["position_ids"], position_ids)
+            )
+
+            current_pack["sequence_lengths"] += [seq_len]
+
+            while (
+                len(current_pack["inputs"]) > self.max_seq_len
+                and not self._should_stop_packing()
+            ):
+                current_pack = self._split_and_add_pack(current_pack)
+
+            if pbar:
+                pbar.update()
+
+            self.previous_sample_boundary = len(current_pack["inputs"])
+
+            if self._should_stop_packing():
+                break
+
+        if len(current_pack["inputs"]) > 0 and (
+            self.max_packs is None or len(self.packs) < self.max_packs
+        ):
+            self._add_pack(current_pack)
+
+    def _should_stop_packing(self) -> bool:
+        """If max packs is set, stop packing when we reach that number."""
+
+        if self.max_packs is not None and len(self.packs) == self.max_packs:
+            return True
+        return False
+
+    def _split_and_add_pack(self, current_pack):
+        """Splits the current pack at the boundary, processes it, adds it to ``self.packs`` and
+        returns the start of the next pack."""
+
+        if self.split_across_pack:
+            boundary = self.max_seq_len
+            # The last elem in ``sequence_lengths`` ensures that ``sum(sequence_lengths) == self.max_seq_len``
+            leftover_seq_len = self.max_seq_len - sum(current_pack["sequence_lengths"][:-1])
+            seq_len_padding = [leftover_seq_len] if 
leftover_seq_len > 0 else [] + else: + boundary = self.previous_sample_boundary + # If we aren't splitting across packs, we leave out the last sample b/c + # it will go into the next pack + seq_len_padding = [] + + pack = { + "inputs": current_pack["inputs"][:boundary], + "labels": current_pack["labels"][:boundary], + "position_ids": current_pack["position_ids"][:boundary], + "sequence_lengths": current_pack["sequence_lengths"][:-1] + seq_len_padding, + } + + self._add_pack(pack) + + # Return the length of the first sample in next pack if we are splitting across packs, + # otherwise return the length of the last sample in the current pack + next_seq_len = ( + len(current_pack["inputs"][boundary:]) + if self.split_across_pack + else current_pack["sequence_lengths"][-1] + ) + + return { + "inputs": current_pack["inputs"][boundary:], + "labels": current_pack["labels"][boundary:], + "position_ids": current_pack["position_ids"][boundary:], + "sequence_lengths": [next_seq_len], + } + + def _add_pack(self, pack) -> None: + """Processes, pads and adds a pack to ``self.packs``.""" + pack = self._pad_pack(pack, padding_idx=self.padding_idx) + self.packs.append(pack) + + def _pad_pack(self, pack, padding_idx: int): + """Pads a pack to ``self.max_seq_len``.""" + num_tokens = len(pack["inputs"]) + num_padding_tokens = self.max_seq_len - num_tokens + + self.packed_tokens += num_tokens + self.total_tokens += self.max_seq_len + + padded_inputs = np.pad( + pack["inputs"], (0, num_padding_tokens), constant_values=self.padding_idx + ) + padded_labels = np.pad( + pack["labels"], (0, num_padding_tokens), constant_values=-100 + ) + + if num_padding_tokens > 0: + # don't care much about padded position_ids, but create them for consistency + start_pos = int(pack["position_ids"][-1] + 1) if num_tokens > 0 else 0 + pad_positions = np.arange( + start_pos, start_pos + num_padding_tokens, dtype=np.int32 + ) + padded_position_ids = np.concatenate((pack["position_ids"], pad_positions)) + else: + 
padded_position_ids = pack["position_ids"] + + padded_seq_lens = pack["sequence_lengths"] + if num_padding_tokens > 0: + padded_seq_lens.append(num_padding_tokens) + + return { + "inputs": padded_inputs, + "labels": padded_labels, + "position_ids": padded_position_ids, + "sequence_lengths": padded_seq_lens, + "images": pack["images"], + } + + def __len__(self) -> int: + return len(self.packs) + + def __getitem__(self, idx: int) -> dict[str, np.ndarray]: + return self.packs[idx] + + +def main(args): + + from datasets import load_dataset, load_from_disk + + #dataset = load_dataset('json', data_files='/home/artem_nous/cambrian_set/output2.json')['train'].select(range(100)) + + dataset = load_from_disk("CombinedDataset").select(range(1000)) + """ + + if 'json' in args.dataset: + dataset = load_dataset('json', data_files=args.dataset)['train'] + if args.limit is not None: + dataset = dataset.select(range(args.limit)) + else: + dataset = load_dataset(args.dataset, name=args.subset, split=args.split) + """ + + + def remove_none_recursively(obj): + if isinstance(obj, dict): + return {k: remove_none_recursively(v) for k, v in obj.items() if v is not None} + elif isinstance(obj, list): + return [remove_none_recursively(item) for item in obj] + else: + return obj + + from transformers import AutoProcessor + tokenizer = AutoProcessor.from_pretrained(args.preprocessor, use_fast=True) + + + def _tokenize_chat_multimodal(sample): + inputs = [] + labels = [] + images = [] + + + for conversation in sample["conversations"]: + + image = None + conversation = remove_none_recursively(conversation) + + + + for message in conversation: + + for item in message['content']: + if item['type'] != 'text': + print(item) + image = item['path'] + + keys = list(message.keys()) + + if "from" in keys and "value" in keys: + # sharegpt format + message_from = message.pop("from") + if message_from == "gpt": + message["role"] = "assistant" + elif message_from == "human": + message["role"] = "user" + 
else: + message["role"] = message_from + + message["content"] = message.pop("value") + elif "role" in keys and "content" in keys: + pass + else: + raise RuntimeError(f"Unknown chat format, keys are {keys}") + + + tokenized = tokenizer.apply_chat_template(conversation, tokenize=True, return_dict=True, return_tensors="pt") + + + + tokens = tokenized["input_ids"][0] #tokenizer.apply_chat_template(conversation, tokenize=True) + current_len = 0 + label = [] + for i in range(len(conversation)): + if i + 1 == len(conversation): + next_tokens = tokenizer.apply_chat_template(conversation, + tokenize=True, return_dict=True, return_tensors="pt")["input_ids"][0][current_len:] + else: + if "assistant" == conversation[i + 1]["role"]: + next_tokens = tokenizer.apply_chat_template(conversation[: i + 1], + add_generation_prompt=True, tokenize=True, return_dict=True)["input_ids"][0][current_len:] + else: + next_tokens = tokenizer.apply_chat_template(conversation[: i + 1], + tokenize=True, return_dict=True)["input_ids"][0][current_len:] + #next_tokens = tokenizer.encode_chat_completion(ChatCompletionRequest(messages=conversation[: i + 1])).tokens[current_len:] + + if conversation[i]["role"] == "assistant": + label.extend(next_tokens) + else: + label.extend([-100] * len(next_tokens)) + + current_len += len(next_tokens) + + inputs.append(tokens) + labels.append(label) + images.append(image) + + + return { + "inputs": inputs, + "labels": labels, + "images": images, + } + + def _tokenize_mistral_format(sample): + messages = sample["messages"] + cleaned_messages = remove_none_recursively(messages) + tokenized = tokenizer.encode_chat_completion(ChatCompletionRequest(messages=cleaned_messages)) + return tokenized.__dict__ + + dataset = dataset.shuffle(args.seed) + + original_column_names = list(dataset.features.keys()) + + dataset = dataset.map( + _tokenize_chat_multimodal, + batched=True, + batch_size=1, + #batch_size=args.batch_size, + num_proc=128 + ) + + dataset = 
dataset.remove_columns(original_column_names) + #print(dataset[0]['images']) + + efficiency = 1.0 + dropped = 0 + if args.pack_to_sequence_length: + num_shards = 32 # args.num_proc + shards = [ + dataset.shard(num_shards=num_shards, index=i) for i in range(num_shards) + ] + + + with multiprocessing.Pool(processes=num_shards) as pool: + process_args = [ + (shard, args, tokenizer.tokenizer.pad_token_id, index, num_shards) + for index, shard in enumerate(shards) + ] + + results = pool.starmap(process_packing_shard, process_args) + + examples = [] + filenames = [] + total_tokens = 0 + packed_tokens = 0 + + for total, packed, dropped_, filename, example in tqdm(results): + if example: + examples.append(example) + if filename: + filenames.append(filename) + total_tokens += total + packed_tokens += packed + dropped += dropped_ + + if total_tokens > 0: + efficiency = packed_tokens / total_tokens + + example = examples[0] + + if args.save_to_disk: + with open(os.path.join(args.save_to_disk, "dataset_info.json"), "wb") as f: + f.write(DATASET_INFO.encode()) + + # verify we can open and do any conversion needed + dataset = load_dataset(args.save_to_disk, num_proc=args.num_proc) + + else: + if args.drop_larger_than: + len_before = len(dataset) + dataset = dataset.filter( + lambda x: len(x["inputs"]) <= args.drop_larger_than + ) + dropped = len_before - len(dataset) + + if args.save_to_disk: + print(f"Saving to {args.save_to_disk}") + dataset.save_to_disk(args.save_to_disk) + + example = dataset[0] + + if args.show_example: + inputs = example["inputs"] + labels = example["labels"] if "labels" in example else None + position_ids = example["position_ids"] if "position_ids" in example else None + + example_out = "" + for i in range(0, len(inputs)): + token = inputs[i] + label = labels[i] if labels is not None else token + position_id = position_ids[i] if position_ids is not None else None + + decoded = tokenizer.decode(token) + + if label == -100: + example_out += 
f"\033[31m{decoded}\033[0m({token}"
+            else:
+                example_out += f"\033[32m{decoded}\033[0m({token}"
+
+            if position_id is not None:
+                example_out += f"@{position_id})"
+            else:
+                example_out += ")"
+
+        print(example_out)
+
+    if dropped > 0:
+        print(f"Dropped {dropped} too-long samples")
+    print(f"Efficiency: {efficiency * 100:.1f}%")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset", type=str, required=True)
+    parser.add_argument("--subset", type=str)
+    parser.add_argument("--split", type=str, default="train")
+    parser.add_argument("--preprocessor", type=str, required=True)
+    parser.add_argument("--batch-size", type=int, default=1000)
+    parser.add_argument("--num-proc", type=int)
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--limit", type=int, required=False)
+    parser.add_argument("--chat", action="store_true")
+    parser.add_argument("--multiturn-only", action="store_true")
+    parser.add_argument("--pack-to-sequence-length", type=int)
+    parser.add_argument("--drop-larger-than", type=int)
+    parser.add_argument("--save-to-disk", type=str)
+    parser.add_argument("--show-example", action="store_true")
+    args = parser.parse_args()
+
+    main(args)
diff --git a/scripts/preprocess_data.py b/scripts/preprocess_data.py
index 2ada83bd34..040b69ebc1 100644
--- a/scripts/preprocess_data.py
+++ b/scripts/preprocess_data.py
@@ -195,6 +195,7 @@ def __init__(
         self.ds = ds
         self.max_seq_len = max_seq_len
         self.padding_idx = padding_idx
+        # NOTE(review): keep the tokenizer-provided padding_idx; hard-coding 0 here clobbered the pad id
         self.max_packs = max_packs
         self.split_across_pack = split_across_pack
         self.packs = []
@@ -653,4 +654,4 @@
     parser.add_argument("--show-example", action="store_true")
     args = parser.parse_args()
 
-    main(args)
+    main(args)
\ No newline at end of file
diff --git a/scripts/preprocess_interleaved_data.py b/scripts/preprocess_interleaved_data.py
new file mode 100644
index 0000000000..500b186e1c
--- /dev/null
+++ 
b/scripts/preprocess_interleaved_data.py @@ -0,0 +1,781 @@ +""" + +post of data processing script for multimodal data +that looks like this: +[ + { + "messages": [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."} + ] + }, + { + "role": "user", + "content": [ + {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"}, + {"type": "text", "text": "Describe this image in detail."} + ] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "The image is a bee."} + ] + } + ] + } +] +""" +import argparse +import os +import shutil +import multiprocessing +import numpy as np +import pyarrow as pa +import pyarrow.dataset as pa_ds +import random +import json +import base64 +import uuid +from PIL import Image +import io +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + +from typing import List, Optional, Tuple +from torch.nn import functional as F +from torch.utils.data import Dataset +from tqdm import tqdm +from datasets import load_dataset, Dataset as DatasetsDataset +from transformers import AutoTokenizer + +from datetime import datetime, timedelta +import torch + +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from huggingface_hub import hf_hub_download +from transformers import Mistral3ForConditionalGeneration + + +LOCAL_IMAGE_DIR = "./images" + +SCHEMA = pa.schema( + [ + pa.field("inputs", pa.large_list(pa.int32())), + pa.field("labels", pa.large_list(pa.int32())), + pa.field("position_ids", pa.large_list(pa.int32())), + pa.field("sequence_lengths", pa.large_list(pa.int64())), + pa.field("images", pa.large_list(pa.string())), + ] +) + +DATASET_INFO = r"""{ + "citation": "", + "description": "", + "features": { + "inputs": { + "feature": { + "dtype": 
"int32",
+        "_type": "Value"
+      },
+      "_type": "LargeList"
+    },
+    "labels": {
+      "feature": {
+        "dtype": "int32",
+        "_type": "Value"
+      },
+      "_type": "LargeList"
+    },
+    "position_ids": {
+      "feature": {
+        "dtype": "int32",
+        "_type": "Value"
+      },
+      "_type": "LargeList"
+    },
+    "sequence_lengths": {
+      "feature": {
+        "dtype": "int64",
+        "_type": "Value"
+      },
+      "_type": "LargeList"
+    },
+    "images": {
+      "feature": {
+        "dtype": "string",
+        "_type": "Value"
+      },
+      "_type": "LargeList"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}"""
+
+
+def process_packing_shard(shard, args, tokenizer_pad_id, rank, world_size):
+    packer = MultimodalPackedDataset(
+        shard,
+        max_seq_len=args.pack_to_sequence_length,
+        padding_idx=tokenizer_pad_id,
+        split_across_pack=not args.chat,
+        show_pbar=rank == 0,
+    )
+
+
+    if args.save_to_disk:
+        # create a schema that uses int64 for list sizes
+
+        example = (
+            {
+                "inputs": packer.packs[0]["inputs"],
+                "labels": packer.packs[0]["labels"],
+                "position_ids": packer.packs[0]["position_ids"],
+                "sequence_lengths": packer.packs[0]["sequence_lengths"],
+                "images": packer.packs[0]["images"],
+            }
+            if len(packer.packs) > 0
+            else None
+        )
+
+        oriented_data = {
+            "inputs": [pack["inputs"] for pack in packer.packs],
+            "labels": [pack["labels"] for pack in packer.packs],
+            "position_ids": [pack["position_ids"] for pack in packer.packs],
+            "sequence_lengths": [pack["sequence_lengths"] for pack in packer.packs],
+            "images": [pack["images"] for pack in packer.packs],
+        }
+        pa_table = pa.Table.from_pydict(oriented_data, schema=SCHEMA)
+        del oriented_data
+
+        pa_ds.write_dataset(
+            pa_table,
+            os.path.join(args.save_to_disk, str(rank)),
+            format="arrow",
+        )
+
+        filename = f"data-{rank:05d}-of-{world_size:05d}.arrow"
+
+        shutil.move(
+            os.path.join(args.save_to_disk, str(rank), "part-0.arrow"),
+            os.path.join(args.save_to_disk, filename),
+        )
+
+        os.rmdir(os.path.join(args.save_to_disk, str(rank)))
+    else:
+        filename = None
+        example = None
+    return packer.total_tokens, 
packer.packed_tokens, packer.dropped, filename, example + + +# https://github.com/pytorch/torchtune/blob/9d91fe39f08661952da4180b9e7fb2eba5a7a5e7/torchtune/datasets/_packed.py +class MultimodalPackedDataset(Dataset): + """ + Performs greedy sample packing on a provided dataset. This is done as a single + preprocessing step before training begins. Shuffling is done outside of this + class on packed samples with a ``Sampler`` as part of the dataloader. Currently, + this only supports in-memory map-style datasets. + + The class loads, tokenizes, and packs examples on initialization - no tokenization is done during training. + + The general flow on initialization is: load tokenized sample -> add to buffer -> + when buffer is long enough, add to ``self.packs``. + + During training, returns self.packs[idx] as input, label, attention mask, and + position ids. The attention mask is a lower triangular block mask to prevent + samples from cross-attending within a pack. The position ids indicate the position + of each token relative to its sample within a pack. These are all padded to max + sequence length, so a batch-wise collator is not needed. + + A packed sample is made up of individual smaller sequence length samples jammed together + within ``max_seq_len``. For example, if max_seq_len is 6 and there are varied + length samples:: + + tokens = [ + [S1, S1, S1, S2, S2, pad], + [S3, S3, S4, S4, pad, pad], + ..., + ] + + To prevent cross-contamination, the following mask would be returned for the + first pack in the example:: + + mask = [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1], + ] + + The position ids would be:: + + input_pos = [ + [0, 1, 2, 0, 1, 2], + [0, 1, 0, 1, 2, 3], + ..., + ] + + The identity matrix is used in the mask for pad tokens instead of a causal mask. + For position ids for pad tokens, we simply continue to increment from the previous + sample normally. 
+ + Args: + ds (Dataset): dataset to sample pack. This should return a dictionary with field + "tokens" and "labels" containing the tokenized and label samples. + max_seq_len (int): Maximum number of tokens to pack + padding_idx (int): padding index for the tokenizer. Default is 0. + max_packs (Optional[int]): Maximum number of packs. Default is None, which will create as many + packs as possible. + split_across_pack (bool): if the last sample in a pack does not fit in ``max_seq_len``, + split the sample into the next pack, or move it entirely to the beginning of the next pack. + For pre-training, typically this is set to True for general text completion. For + fine-tuning, typically this is set to False to avoid truncating sentences in instruct + tuning. Default is False. + """ + + def __init__( + self, + ds: Dataset, + *, + max_seq_len: int, + padding_idx: int = 0, + max_packs: Optional[int] = None, + split_across_pack: bool = False, + group_size: int = 5000, + show_pbar=True, + ) -> None: + self.ds = ds + self.max_seq_len = max_seq_len + self.padding_idx = padding_idx + self.max_packs = max_packs + self.split_across_pack = split_across_pack + self.packs = [] + self.previous_sample_boundary: int = 0 + self.packed_tokens: int = 0 + self.total_tokens: int = 0 + self.dropped: int = 0 + self.show_pbar = show_pbar + self.group_size = group_size + if split_across_pack: + self._pack_greedy() + else: + self._pack_ffd() + + def _get_empty_pack(self): + + return { + "inputs": np.empty(0, dtype=np.int32), + "labels": np.empty(0, dtype=np.int32), + "position_ids": np.empty(0, dtype=np.int32), + "sequence_lengths": [], + "images": [], + } + + def _pack_ffd(self) -> None: + ds_iterator = iter(self.ds) + finished_iterating = False + + pbar = ( + tqdm( + total=len(self.ds), + desc="Packing dataset (FFD)", + dynamic_ncols=True, + ) + if self.show_pbar + else None + ) + + while not finished_iterating: + # 1. Fetch a large group of samples into memory. 
+ group = [] + try: + for _ in range(self.group_size): + sample = next(ds_iterator) + seq_len = len(sample["inputs"]) + + + if seq_len > self.max_seq_len: + self.dropped += 1 + continue + # Store sample and its length for sorting + group.append({"sample": sample, "seq_len": seq_len}) + except StopIteration: + finished_iterating = True + + if not group: + break + + + # 2. Sort the group by length in descending order (the "Decreasing" part of FFD). + group.sort(key=lambda x: x["seq_len"], reverse=True) + + # 3. Pack this group using the "First-Fit" heuristic. + # Each bin holds the samples it contains and its remaining space. + bins = [] # List of {"samples": [], "remaining_space": max_seq_len} + + for item in group: + placed = False + # Try to place the item in the first available bin. + for bin in bins: + if bin["remaining_space"] >= item["seq_len"]: + bin["samples"].append(item["sample"]) + bin["remaining_space"] -= item["seq_len"] + placed = True + break + + # If no existing bin could accommodate the item, create a new one. + if not placed: + bins.append( + { + "samples": [item["sample"]], + "remaining_space": self.max_seq_len - item["seq_len"], + } + ) + + + # 4. Convert the completed bins from this group into final, padded packs. 
+ for bin_info in bins: + if self._should_stop_packing(): + break + + current_pack = self._get_empty_pack() + for sample in bin_info["samples"]: + tokens = np.array(sample["inputs"], dtype=np.int32) + labels = np.array(sample["labels"], dtype=np.int32) + images = sample["images"] + seq_len = len(tokens) + + + current_pack["inputs"] = np.concatenate( + (current_pack["inputs"], tokens) + ) + current_pack["labels"] = np.concatenate( + (current_pack["labels"], labels) + ) + current_pack["position_ids"] = np.concatenate( + ( + current_pack["position_ids"], + np.arange(seq_len, dtype=np.int32), + ) + ) + current_pack["sequence_lengths"].append(seq_len) + current_pack["images"].append(images) + + + self._add_pack(current_pack) + + + if pbar: + pbar.update(len(group)) + + if self._should_stop_packing(): + # Ensure the outer loop breaks if max_packs is reached. + break + + if pbar: + # Manually set pbar to total to show 100% at the end + pbar.n = pbar.total + pbar.refresh() + pbar.close() + + def _pack_greedy(self) -> None: + """Iterate through the dataset. Use a buffer to hold samples until max_seq_len, + then append the buffer to self.packs as a single "packed" sample. 
Continue
+        until max_packs or end of dataset."""
+
+        current_pack = self._get_empty_pack()
+
+        pbar = (
+            tqdm(total=len(self.ds), desc="Packing dataset", dynamic_ncols=True)
+            if self.show_pbar
+            else None
+        )
+
+        for sample in self.ds:
+            tokens = np.array(sample["inputs"], dtype=np.int32)
+            labels = np.array(sample["labels"], dtype=np.int32)
+            seq_len = len(tokens)
+            if seq_len > self.max_seq_len and not self.split_across_pack:
+                # print(
+                #     f"Dropping sample that is too long ({seq_len} > {self.max_seq_len})"
+                # )
+                self.dropped += 1
+                continue
+
+            current_pack["inputs"] = np.concatenate((current_pack["inputs"], tokens))
+            current_pack["labels"] = np.concatenate((current_pack["labels"], labels))
+
+            position_ids = np.arange(seq_len, dtype=np.int32)
+            current_pack["position_ids"] = np.concatenate(
+                (current_pack["position_ids"], position_ids)
+            )
+
+            current_pack["sequence_lengths"] += [seq_len]
+
+            while (
+                len(current_pack["inputs"]) > self.max_seq_len
+                and not self._should_stop_packing()
+            ):
+                current_pack = self._split_and_add_pack(current_pack)
+
+            if pbar:
+                pbar.update()
+
+            self.previous_sample_boundary = len(current_pack["inputs"])
+
+            if self._should_stop_packing():
+                break
+
+        if len(current_pack["inputs"]) > 0 and (
+            self.max_packs is None or len(self.packs) < self.max_packs
+        ):
+            self._add_pack(current_pack)
+
+    def _should_stop_packing(self) -> bool:
+        """If max packs is set, stop packing when we reach that number."""
+
+        if self.max_packs is not None and len(self.packs) == self.max_packs:
+            return True
+        return False
+
+    def _split_and_add_pack(self, current_pack):
+        """Splits the current pack at the boundary, processes it, adds it to ``self.packs`` and
+        returns the start of the next pack."""
+
+        if self.split_across_pack:
+            boundary = self.max_seq_len
+            # The last elem in ``sequence_lengths`` ensures that ``sum(sequence_lengths) == self.max_seq_len``
+            leftover_seq_len = self.max_seq_len - sum(current_pack["sequence_lengths"][:-1])
+            seq_len_padding = [leftover_seq_len] if 
leftover_seq_len > 0 else [] + else: + boundary = self.previous_sample_boundary + # If we aren't splitting across packs, we leave out the last sample b/c + # it will go into the next pack + seq_len_padding = [] + + pack = { + "inputs": current_pack["inputs"][:boundary], + "labels": current_pack["labels"][:boundary], + "position_ids": current_pack["position_ids"][:boundary], + "sequence_lengths": current_pack["sequence_lengths"][:-1] + seq_len_padding, + } + + self._add_pack(pack) + + # Return the length of the first sample in next pack if we are splitting across packs, + # otherwise return the length of the last sample in the current pack + next_seq_len = ( + len(current_pack["inputs"][boundary:]) + if self.split_across_pack + else current_pack["sequence_lengths"][-1] + ) + + return { + "inputs": current_pack["inputs"][boundary:], + "labels": current_pack["labels"][boundary:], + "position_ids": current_pack["position_ids"][boundary:], + "sequence_lengths": [next_seq_len], + } + + def _add_pack(self, pack) -> None: + """Processes, pads and adds a pack to ``self.packs``.""" + pack = self._pad_pack(pack, padding_idx=self.padding_idx) + self.packs.append(pack) + + def _pad_pack(self, pack, padding_idx: int): + """Pads a pack to ``self.max_seq_len``.""" + num_tokens = len(pack["inputs"]) + num_padding_tokens = self.max_seq_len - num_tokens + + self.packed_tokens += num_tokens + self.total_tokens += self.max_seq_len + + padded_inputs = np.pad( + pack["inputs"], (0, num_padding_tokens), constant_values=self.padding_idx + ) + padded_labels = np.pad( + pack["labels"], (0, num_padding_tokens), constant_values=-100 + ) + + if num_padding_tokens > 0: + # don't care much about padded position_ids, but create them for consistency + start_pos = int(pack["position_ids"][-1] + 1) if num_tokens > 0 else 0 + pad_positions = np.arange( + start_pos, start_pos + num_padding_tokens, dtype=np.int32 + ) + padded_position_ids = np.concatenate((pack["position_ids"], pad_positions)) + else: + 
padded_position_ids = pack["position_ids"] + + padded_seq_lens = pack["sequence_lengths"] + if num_padding_tokens > 0: + padded_seq_lens.append(num_padding_tokens) + + return { + "inputs": padded_inputs, + "labels": padded_labels, + "position_ids": padded_position_ids, + "sequence_lengths": padded_seq_lens, + "images": pack["images"], + } + + def __len__(self) -> int: + return len(self.packs) + + def __getitem__(self, idx: int) -> dict[str, np.ndarray]: + return self.packs[idx] + + +def main(args): + + from datasets import load_dataset + + #dataset = load_dataset('json', data_files='/home/artem_nous/cambrian_set/output2.json')['train'].select(range(100)) + + if 'json' in args.dataset: + dataset = load_dataset('json', data_files=args.dataset)['train'] + if args.limit is not None: + dataset = dataset.select(range(args.limit)) + else: + dataset = load_dataset(args.dataset, name=args.subset, split=args.split) + + + def remove_none_recursively(obj): + if isinstance(obj, dict): + return {k: remove_none_recursively(v) for k, v in obj.items() if v is not None} + elif isinstance(obj, list): + return [remove_none_recursively(item) for item in obj] + else: + return obj + + from transformers import AutoProcessor + tokenizer = AutoProcessor.from_pretrained(args.preprocessor, use_fast=True) + + + def _tokenize_chat_multimodal(sample): + inputs = [] + labels = [] + images = [] + + for conversation in sample["messages"]: + + image = None + conversation = remove_none_recursively(conversation) + + for message in conversation: + + keys = list(message.keys()) + + for item in message['content']: + if 'base64' in item.keys(): + # save image in local folder as PIL image with uuid + # Decode base64 image data + image_data = base64.b64decode(item['base64']) + image = Image.open(io.BytesIO(image_data)) + + # Generate UUID4 filename + image_filename = f"{uuid.uuid4()}.jpg" + image_path = os.path.join(LOCAL_IMAGE_DIR, image_filename) + + # Ensure directory exists + 
os.makedirs(LOCAL_IMAGE_DIR, exist_ok=True) + + # Save image as JPG + image.save(image_path, 'JPEG') + + # remove base64 key, set type image + item.pop('base64') + item['type'] = 'image' + item['path'] = image_path + + # NOTE: possible to have multiple images in one message + images.append(image_path) + + + if "from" in keys and "value" in keys: + # sharegpt format + message_from = message.pop("from") + if message_from == "gpt": + message["role"] = "assistant" + elif message_from == "human": + message["role"] = "user" + else: + message["role"] = message_from + + message["content"] = message.pop("value") + elif "role" in keys and "content" in keys: + pass + else: + raise RuntimeError(f"Unknown chat format, keys are {keys}") + + + tokenized = tokenizer.apply_chat_template(conversation, tokenize=True, return_dict=True, return_tensors="pt") + + tokens = tokenized["input_ids"][0] #tokenizer.apply_chat_template(conversation, tokenize=True) + + # NOTE: if image is None, we keep it as None + + current_len = 0 + label = [] + for i in range(len(conversation)): + if i + 1 == len(conversation): + next_tokens = tokenizer.apply_chat_template(conversation, + tokenize=True, return_dict=True, return_tensors="pt")["input_ids"][0][current_len:] + else: + if "assistant" == conversation[i + 1]["role"]: + next_tokens = tokenizer.apply_chat_template(conversation[: i + 1], + add_generation_prompt=True, tokenize=True, return_dict=True)["input_ids"][0][current_len:] + else: + next_tokens = tokenizer.apply_chat_template(conversation[: i + 1], + tokenize=True, return_dict=True)["input_ids"][0][current_len:] + #next_tokens = tokenizer.encode_chat_completion(ChatCompletionRequest(messages=conversation[: i + 1])).tokens[current_len:] + + if conversation[i]["role"] == "assistant": + label.extend(next_tokens) + else: + label.extend([-100] * len(next_tokens)) + + current_len += len(next_tokens) + + inputs.append(tokens) + labels.append(label) + + + return { + "inputs": inputs, + "labels": labels, + 
"images": images, + } + + def _tokenize_mistral_format(sample): + messages = sample["messages"] + cleaned_messages = remove_none_recursively(messages) + tokenized = tokenizer.encode_chat_completion(ChatCompletionRequest(messages=cleaned_messages)) + return tokenized.__dict__ + + dataset = dataset.shuffle(args.seed) + + original_column_names = list(dataset.features.keys()) + + dataset = dataset.map( + _tokenize_chat_multimodal, + batched=True, + #batch_size=args.batch_size, + ) + + dataset = dataset.remove_columns(original_column_names) + #print(dataset[0]['images']) + + efficiency = 1.0 + dropped = 0 + if args.pack_to_sequence_length: + num_shards = 32 # args.num_proc + shards = [ + dataset.shard(num_shards=num_shards, index=i) for i in range(num_shards) + ] + + + with multiprocessing.Pool(processes=num_shards) as pool: + process_args = [ + (shard, args, tokenizer.tokenizer.pad_token_id, index, num_shards) + for index, shard in enumerate(shards) + ] + + results = pool.starmap(process_packing_shard, process_args) + + examples = [] + filenames = [] + total_tokens = 0 + packed_tokens = 0 + + for total, packed, dropped_, filename, example in tqdm(results): + if example: + examples.append(example) + if filename: + filenames.append(filename) + total_tokens += total + packed_tokens += packed + dropped += dropped_ + + if total_tokens > 0: + efficiency = packed_tokens / total_tokens + + example = examples[0] + + if args.save_to_disk: + with open(os.path.join(args.save_to_disk, "dataset_info.json"), "wb") as f: + f.write(DATASET_INFO.encode()) + + # verify we can open and do any conversion needed + dataset = load_dataset(args.save_to_disk, num_proc=args.num_proc) + + else: + if args.drop_larger_than: + len_before = len(dataset) + dataset = dataset.filter( + lambda x: len(x["inputs"]) <= args.drop_larger_than + ) + dropped = len_before - len(dataset) + + if args.save_to_disk: + print(f"Saving to {args.save_to_disk}") + dataset.save_to_disk(args.save_to_disk) + + example = 
dataset[0]
+
+    if args.show_example:
+        inputs = example["inputs"]
+        labels = example["labels"] if "labels" in example else None
+        position_ids = example["position_ids"] if "position_ids" in example else None
+
+        example_out = ""
+        for i in range(0, len(inputs)):
+            token = inputs[i]
+            label = labels[i] if labels is not None else token
+            position_id = position_ids[i] if position_ids is not None else None
+
+            decoded = tokenizer.decode(token)
+
+            if label == -100:
+                example_out += f"\033[31m{decoded}\033[0m({token}"
+            else:
+                example_out += f"\033[32m{decoded}\033[0m({token}"
+
+            if position_id != None:
+                example_out += f"@{position_id})"
+            else:
+                example_out += ")"
+
+        print(example_out)
+
+    if dropped > 0:
+        print(f"Dropped {dropped} too-long samples")
+    print(f"Efficiency: {efficiency * 100:.1f}%")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset", type=str, required=True)
+    parser.add_argument("--subset", type=str)
+    parser.add_argument("--split", type=str, default="train")
+    parser.add_argument("--preprocessor", type=str, required=True)
+    parser.add_argument("--batch-size", type=int, default=1000)
+    parser.add_argument("--num-proc", type=int)
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--limit", type=int, default=None)
+    parser.add_argument("--chat", action="store_true")
+    parser.add_argument("--multiturn-only", action="store_true")
+    parser.add_argument("--pack-to-sequence-length", type=int)
+    parser.add_argument("--drop-larger-than", type=int)
+    parser.add_argument("--save-to-disk", type=str)
+    parser.add_argument("--show-example", action="store_true")
+    # NOTE(review): a second `parser.add_argument("--limit", type=int, default=None)`
+    # was registered here; argparse raises ArgumentError on duplicate option
+    # strings, so the redundant definition was folded into the single --limit
+    # option above (same type, same effective default of None).
+    args = parser.parse_args()
+
+    main(args)
diff --git a/scripts/preprocess_multimodal_data.py b/scripts/preprocess_multimodal_data.py
new file mode 100644
index 0000000000..6ac5203b18
--- /dev/null
+++ b/scripts/preprocess_multimodal_data.py
@@ -0,0 +1,779 @@
+"""
+
+post of data 
processing script for multimodal data +that looks like this: +[ + { + "messages": [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."} + ] + }, + { + "role": "user", + "content": [ + {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"}, + {"type": "text", "text": "Describe this image in detail."} + ] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "The image is a bee."} + ] + } + ] + } +] +""" +import argparse +import os +import shutil +import multiprocessing +import numpy as np +import pyarrow as pa +import pyarrow.dataset as pa_ds +import random +import json +import base64 +import uuid +from PIL import Image +import io +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + +from typing import List, Optional, Tuple +from torch.nn import functional as F +from torch.utils.data import Dataset +from tqdm import tqdm +from datasets import load_dataset, Dataset as DatasetsDataset +from transformers import AutoTokenizer + +from datetime import datetime, timedelta +import torch + +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from huggingface_hub import hf_hub_download +from transformers import Mistral3ForConditionalGeneration + + +LOCAL_IMAGE_DIR = "./images" + +SCHEMA = pa.schema( + [ + pa.field("inputs", pa.large_list(pa.int32())), + pa.field("labels", pa.large_list(pa.int32())), + pa.field("position_ids", pa.large_list(pa.int32())), + pa.field("sequence_lengths", pa.large_list(pa.int64())), + pa.field("images", pa.large_list(pa.string())), + ] +) + +DATASET_INFO = r"""{ + "citation": "", + "description": "", + "features": { + "inputs": { + "feature": { + "dtype": "int32", + "_type": "Value" + }, + "_type": "LargeList" + }, + "labels": { + 
"feature": {
+        "dtype": "int32",
+        "_type": "Value"
+      },
+      "_type": "LargeList"
+    },
+    "position_ids": {
+      "feature": {
+        "dtype": "int32",
+        "_type": "Value"
+      },
+      "_type": "LargeList"
+    },
+    "sequence_lengths": {
+      "feature": {
+        "dtype": "int64",
+        "_type": "Value"
+      },
+      "_type": "LargeList"
+    },
+    "images": {
+      "feature": {
+        "dtype": "string",
+        "_type": "Value"
+      },
+      "_type": "LargeList"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}"""
+
+
+def process_packing_shard(shard, args, tokenizer_pad_id, rank, world_size):
+    packer = MultimodalPackedDataset(
+        shard,
+        max_seq_len=args.pack_to_sequence_length,
+        padding_idx=tokenizer_pad_id,
+        split_across_pack=not args.chat,
+        show_pbar=rank == 0,
+    )
+
+
+    if args.save_to_disk:
+        # create a schema that uses int64 for list sizes
+
+        example = (
+            {
+                "inputs": packer.packs[0]["inputs"],
+                "labels": packer.packs[0]["labels"],
+                "position_ids": packer.packs[0]["position_ids"],
+                "sequence_lengths": packer.packs[0]["sequence_lengths"],
+                "images": packer.packs[0]["images"],
+            }
+            if len(packer.packs) > 0
+            else None
+        )
+
+        oriented_data = {
+            "inputs": [pack["inputs"] for pack in packer.packs],
+            "labels": [pack["labels"] for pack in packer.packs],
+            "position_ids": [pack["position_ids"] for pack in packer.packs],
+            "sequence_lengths": [pack["sequence_lengths"] for pack in packer.packs],
+            "images": [pack["images"] for pack in packer.packs],
+        }
+        pa_table = pa.Table.from_pydict(oriented_data, schema=SCHEMA)
+        del oriented_data
+
+        pa_ds.write_dataset(
+            pa_table,
+            os.path.join(args.save_to_disk, str(rank)),
+            format="arrow",
+        )
+
+        filename = f"data-{rank:05d}-of-{world_size:05d}.arrow"
+
+        shutil.move(
+            os.path.join(args.save_to_disk, str(rank), "part-0.arrow"),
+            os.path.join(args.save_to_disk, filename),
+        )
+
+        os.rmdir(os.path.join(args.save_to_disk, str(rank)))
+    else:
+        filename = None
+
+    return packer.total_tokens, packer.packed_tokens, packer.dropped, filename, example
+
+
+# 
https://github.com/pytorch/torchtune/blob/9d91fe39f08661952da4180b9e7fb2eba5a7a5e7/torchtune/datasets/_packed.py +class MultimodalPackedDataset(Dataset): + """ + Performs greedy sample packing on a provided dataset. This is done as a single + preprocessing step before training begins. Shuffling is done outside of this + class on packed samples with a ``Sampler`` as part of the dataloader. Currently, + this only supports in-memory map-style datasets. + + The class loads, tokenizes, and packs examples on initialization - no tokenization is done during training. + + The general flow on initialization is: load tokenized sample -> add to buffer -> + when buffer is long enough, add to ``self.packs``. + + During training, returns self.packs[idx] as input, label, attention mask, and + position ids. The attention mask is a lower triangular block mask to prevent + samples from cross-attending within a pack. The position ids indicate the position + of each token relative to its sample within a pack. These are all padded to max + sequence length, so a batch-wise collator is not needed. + + A packed sample is made up of individual smaller sequence length samples jammed together + within ``max_seq_len``. For example, if max_seq_len is 6 and there are varied + length samples:: + + tokens = [ + [S1, S1, S1, S2, S2, pad], + [S3, S3, S4, S4, pad, pad], + ..., + ] + + To prevent cross-contamination, the following mask would be returned for the + first pack in the example:: + + mask = [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1], + ] + + The position ids would be:: + + input_pos = [ + [0, 1, 2, 0, 1, 2], + [0, 1, 0, 1, 2, 3], + ..., + ] + + The identity matrix is used in the mask for pad tokens instead of a causal mask. + For position ids for pad tokens, we simply continue to increment from the previous + sample normally. + + Args: + ds (Dataset): dataset to sample pack. 
This should return a dictionary with field + "tokens" and "labels" containing the tokenized and label samples. + max_seq_len (int): Maximum number of tokens to pack + padding_idx (int): padding index for the tokenizer. Default is 0. + max_packs (Optional[int]): Maximum number of packs. Default is None, which will create as many + packs as possible. + split_across_pack (bool): if the last sample in a pack does not fit in ``max_seq_len``, + split the sample into the next pack, or move it entirely to the beginning of the next pack. + For pre-training, typically this is set to True for general text completion. For + fine-tuning, typically this is set to False to avoid truncating sentences in instruct + tuning. Default is False. + """ + + def __init__( + self, + ds: Dataset, + *, + max_seq_len: int, + padding_idx: int = 0, + max_packs: Optional[int] = None, + split_across_pack: bool = False, + group_size: int = 5000, + show_pbar=True, + ) -> None: + self.ds = ds + self.max_seq_len = max_seq_len + self.padding_idx = padding_idx + self.max_packs = max_packs + self.split_across_pack = split_across_pack + self.packs = [] + self.previous_sample_boundary: int = 0 + self.packed_tokens: int = 0 + self.total_tokens: int = 0 + self.dropped: int = 0 + self.show_pbar = show_pbar + self.group_size = group_size + if split_across_pack: + self._pack_greedy() + else: + self._pack_ffd() + + def _get_empty_pack(self): + + return { + "inputs": np.empty(0, dtype=np.int32), + "labels": np.empty(0, dtype=np.int32), + "position_ids": np.empty(0, dtype=np.int32), + "sequence_lengths": [], + "images": [], + } + + def _pack_ffd(self) -> None: + ds_iterator = iter(self.ds) + finished_iterating = False + + pbar = ( + tqdm( + total=len(self.ds), + desc="Packing dataset (FFD)", + dynamic_ncols=True, + ) + if self.show_pbar + else None + ) + + while not finished_iterating: + # 1. Fetch a large group of samples into memory. 
+ group = [] + try: + for _ in range(self.group_size): + sample = next(ds_iterator) + seq_len = len(sample["inputs"]) + + + if seq_len > self.max_seq_len: + self.dropped += 1 + continue + # Store sample and its length for sorting + group.append({"sample": sample, "seq_len": seq_len}) + except StopIteration: + finished_iterating = True + + if not group: + break + + + # 2. Sort the group by length in descending order (the "Decreasing" part of FFD). + group.sort(key=lambda x: x["seq_len"], reverse=True) + + # 3. Pack this group using the "First-Fit" heuristic. + # Each bin holds the samples it contains and its remaining space. + bins = [] # List of {"samples": [], "remaining_space": max_seq_len} + + for item in group: + placed = False + # Try to place the item in the first available bin. + for bin in bins: + if bin["remaining_space"] >= item["seq_len"]: + bin["samples"].append(item["sample"]) + bin["remaining_space"] -= item["seq_len"] + placed = True + break + + # If no existing bin could accommodate the item, create a new one. + if not placed: + bins.append( + { + "samples": [item["sample"]], + "remaining_space": self.max_seq_len - item["seq_len"], + } + ) + + + # 4. Convert the completed bins from this group into final, padded packs. 
+ for bin_info in bins: + if self._should_stop_packing(): + break + + current_pack = self._get_empty_pack() + for sample in bin_info["samples"]: + tokens = np.array(sample["inputs"], dtype=np.int32) + labels = np.array(sample["labels"], dtype=np.int32) + images = sample["images"] + seq_len = len(tokens) + + + current_pack["inputs"] = np.concatenate( + (current_pack["inputs"], tokens) + ) + current_pack["labels"] = np.concatenate( + (current_pack["labels"], labels) + ) + current_pack["position_ids"] = np.concatenate( + ( + current_pack["position_ids"], + np.arange(seq_len, dtype=np.int32), + ) + ) + current_pack["sequence_lengths"].append(seq_len) + current_pack["images"].append(images) + + + self._add_pack(current_pack) + + + if pbar: + pbar.update(len(group)) + + if self._should_stop_packing(): + # Ensure the outer loop breaks if max_packs is reached. + break + + if pbar: + # Manually set pbar to total to show 100% at the end + pbar.n = pbar.total + pbar.refresh() + pbar.close() + + def _pack_greedy(self) -> None: + """Iterate through the dataset. Use a buffer to hold samples until max_seq_len, + then append the buffer to self.packs as a single "packed" sample. 
Continue
+        until max_packs or end of dataset."""
+
+        current_pack = self._get_empty_pack()
+
+        pbar = (
+            tqdm(total=len(self.ds), desc="Packing dataset", dynamic_ncols=True)
+            if self.show_pbar
+            else None
+        )
+
+        for sample in self.ds:
+            tokens = np.array(sample["inputs"], dtype=np.int32)
+            labels = np.array(sample["labels"], dtype=np.int32)
+            # FIX: seq_len was referenced below without ever being assigned
+            # (NameError on the first sample); compute it from the tokens.
+            seq_len = len(tokens)
+
+            if seq_len > self.max_seq_len and not self.split_across_pack:
+                # print(
+                #     f"Dropping sample that is too long ({seq_len} > {self.max_seq_len})"
+                # )
+                self.dropped += 1
+                continue
+
+            current_pack["inputs"] = np.concatenate((current_pack["inputs"], tokens))
+            current_pack["labels"] = np.concatenate((current_pack["labels"], labels))
+
+            position_ids = np.arange(seq_len, dtype=np.int32)
+            current_pack["position_ids"] = np.concatenate(
+                (current_pack["position_ids"], position_ids)
+            )
+
+            current_pack["sequence_lengths"] += [seq_len]
+
+            while (
+                len(current_pack["inputs"]) > self.max_seq_len
+                and not self._should_stop_packing()
+            ):
+                current_pack = self._split_and_add_pack(current_pack)
+
+            if pbar:
+                pbar.update()
+
+            self.previous_sample_boundary = len(current_pack["inputs"])
+
+            if self._should_stop_packing():
+                break
+
+        if len(current_pack["inputs"]) > 0 and (
+            self.max_packs is None or len(self.packs) < self.max_packs
+        ):
+            self._add_pack(current_pack)
+
+    def _should_stop_packing(self) -> bool:
+        """If max packs is set, stop packing when we reach that number."""
+
+        if self.max_packs is not None and len(self.packs) == self.max_packs:
+            return True
+        return False
+
+    def _split_and_add_pack(self, current_pack):
+        """Splits the current pack at the boundary, processes it, adds it to ``self.packs`` and
+        returns the start of the next pack."""
+
+        if self.split_across_pack:
+            boundary = self.max_seq_len
+            # The last elem in ``seq_lens`` ensures that ``sum(seq_lens) == self.max_seq_len``
+            # FIX: the pack dict uses the key "sequence_lengths"; "seq_lens" raised KeyError.
+            leftover_seq_len = self.max_seq_len - sum(current_pack["sequence_lengths"][:-1])
+            seq_len_padding = [leftover_seq_len] if 
leftover_seq_len > 0 else []
+        else:
+            boundary = self.previous_sample_boundary
+            # If we aren't splitting across packs, we leave out the last sample b/c
+            # it will go into the next pack
+            seq_len_padding = []
+
+        pack = {
+            "inputs": current_pack["inputs"][:boundary],
+            "labels": current_pack["labels"][:boundary],
+            "position_ids": current_pack["position_ids"][:boundary],
+            "sequence_lengths": current_pack["sequence_lengths"][:-1] + seq_len_padding,
+            # FIX: _pad_pack() reads pack["images"]; the key was missing here and
+            # raised KeyError for every split pack. Images are not split across
+            # packs: they stay with the emitted pack (presumably the greedy/text
+            # path carries no images anyway -- TODO confirm).
+            "images": current_pack.get("images", []),
+        }
+
+        self._add_pack(pack)
+
+        # Return the length of the first sample in next pack if we are splitting across packs,
+        # otherwise return the length of the last sample in the current pack
+        next_seq_len = (
+            len(current_pack["inputs"][boundary:])
+            if self.split_across_pack
+            else current_pack["sequence_lengths"][-1]
+        )
+
+        return {
+            "inputs": current_pack["inputs"][boundary:],
+            "labels": current_pack["labels"][boundary:],
+            "position_ids": current_pack["position_ids"][boundary:],
+            "sequence_lengths": [next_seq_len],
+            # FIX: keep the carried-over pack shaped like _get_empty_pack().
+            "images": [],
+        }
+
+    def _add_pack(self, pack) -> None:
+        """Processes, pads and adds a pack to ``self.packs``."""
+        pack = self._pad_pack(pack, padding_idx=self.padding_idx)
+        self.packs.append(pack)
+
+    def _pad_pack(self, pack, padding_idx: int):
+        """Pads a pack to ``self.max_seq_len``."""
+        num_tokens = len(pack["inputs"])
+        num_padding_tokens = self.max_seq_len - num_tokens
+
+        self.packed_tokens += num_tokens
+        self.total_tokens += self.max_seq_len
+
+        padded_inputs = np.pad(
+            pack["inputs"], (0, num_padding_tokens), constant_values=self.padding_idx
+        )
+        padded_labels = np.pad(
+            pack["labels"], (0, num_padding_tokens), constant_values=-100
+        )
+
+        if num_padding_tokens > 0:
+            # don't care much about padded position_ids, but create them for consistency
+            start_pos = int(pack["position_ids"][-1] + 1) if num_tokens > 0 else 0
+            pad_positions = np.arange(
+                start_pos, start_pos + num_padding_tokens, dtype=np.int32
+            )
+            padded_position_ids = np.concatenate((pack["position_ids"], pad_positions))
+        else:
+            
padded_position_ids = pack["position_ids"] + + padded_seq_lens = pack["sequence_lengths"] + if num_padding_tokens > 0: + padded_seq_lens.append(num_padding_tokens) + + return { + "inputs": padded_inputs, + "labels": padded_labels, + "position_ids": padded_position_ids, + "sequence_lengths": padded_seq_lens, + "images": pack["images"], + } + + def __len__(self) -> int: + return len(self.packs) + + def __getitem__(self, idx: int) -> dict[str, np.ndarray]: + return self.packs[idx] + + +def main(args): + + from datasets import load_dataset + + #dataset = load_dataset('json', data_files='/home/artem_nous/cambrian_set/output2.json')['train'].select(range(100)) + + if 'json' in args.dataset: + dataset = load_dataset('json', data_files=args.dataset)['train'] + if args.limit is not None: + dataset = dataset.select(range(args.limit)) + else: + dataset = load_dataset(args.dataset, name=args.subset, split=args.split) + + + def remove_none_recursively(obj): + if isinstance(obj, dict): + return {k: remove_none_recursively(v) for k, v in obj.items() if v is not None} + elif isinstance(obj, list): + return [remove_none_recursively(item) for item in obj] + else: + return obj + + from transformers import AutoProcessor + tokenizer = AutoProcessor.from_pretrained(args.preprocessor, use_fast=True) + + + def _tokenize_chat_multimodal(sample): + inputs = [] + labels = [] + images = [] + + for conversation in sample["messages"]: + + image = None + conversation = remove_none_recursively(conversation) + + for message in conversation: + + keys = list(message.keys()) + + for item in message['content']: + if 'base64' in item.keys(): + # save image in local folder as PIL image with uuid + # Decode base64 image data + image_data = base64.b64decode(item['base64']) + image = Image.open(io.BytesIO(image_data)) + + # Generate UUID4 filename + image_filename = f"{uuid.uuid4()}.jpg" + image_path = os.path.join(LOCAL_IMAGE_DIR, image_filename) + + # Ensure directory exists + 
os.makedirs(LOCAL_IMAGE_DIR, exist_ok=True) + + # Save image as JPG + image.save(image_path, 'JPEG') + + # remove base64 key, set type image + item.pop('base64') + item['type'] = 'image' + item['path'] = image_path + + image = image_path + + + if "from" in keys and "value" in keys: + # sharegpt format + message_from = message.pop("from") + if message_from == "gpt": + message["role"] = "assistant" + elif message_from == "human": + message["role"] = "user" + else: + message["role"] = message_from + + message["content"] = message.pop("value") + elif "role" in keys and "content" in keys: + pass + else: + raise RuntimeError(f"Unknown chat format, keys are {keys}") + + + tokenized = tokenizer.apply_chat_template(conversation, tokenize=True, return_dict=True, return_tensors="pt") + + tokens = tokenized["input_ids"][0] #tokenizer.apply_chat_template(conversation, tokenize=True) + + + current_len = 0 + label = [] + for i in range(len(conversation)): + if i + 1 == len(conversation): + next_tokens = tokenizer.apply_chat_template(conversation, + tokenize=True, return_dict=True, return_tensors="pt")["input_ids"][0][current_len:] + else: + if "assistant" == conversation[i + 1]["role"]: + next_tokens = tokenizer.apply_chat_template(conversation[: i + 1], + add_generation_prompt=True, tokenize=True, return_dict=True)["input_ids"][0][current_len:] + else: + next_tokens = tokenizer.apply_chat_template(conversation[: i + 1], + tokenize=True, return_dict=True)["input_ids"][0][current_len:] + #next_tokens = tokenizer.encode_chat_completion(ChatCompletionRequest(messages=conversation[: i + 1])).tokens[current_len:] + + if conversation[i]["role"] == "assistant": + label.extend(next_tokens) + else: + label.extend([-100] * len(next_tokens)) + + current_len += len(next_tokens) + + inputs.append(tokens) + labels.append(label) + images.append(image) + + + return { + "inputs": inputs, + "labels": labels, + "images": images, + } + + def _tokenize_mistral_format(sample): + messages = 
sample["messages"] + cleaned_messages = remove_none_recursively(messages) + tokenized = tokenizer.encode_chat_completion(ChatCompletionRequest(messages=cleaned_messages)) + return tokenized.__dict__ + + dataset = dataset.shuffle(args.seed) + + original_column_names = list(dataset.features.keys()) + + dataset = dataset.map( + _tokenize_chat_multimodal, + batched=True, + #batch_size=args.batch_size, + ) + + dataset = dataset.remove_columns(original_column_names) + #print(dataset[0]['images']) + + efficiency = 1.0 + dropped = 0 + if args.pack_to_sequence_length: + num_shards = 32 # args.num_proc + shards = [ + dataset.shard(num_shards=num_shards, index=i) for i in range(num_shards) + ] + + + with multiprocessing.Pool(processes=num_shards) as pool: + process_args = [ + (shard, args, tokenizer.tokenizer.pad_token_id, index, num_shards) + for index, shard in enumerate(shards) + ] + + results = pool.starmap(process_packing_shard, process_args) + + examples = [] + filenames = [] + total_tokens = 0 + packed_tokens = 0 + + for total, packed, dropped_, filename, example in tqdm(results): + if example: + examples.append(example) + if filename: + filenames.append(filename) + total_tokens += total + packed_tokens += packed + dropped += dropped_ + + if total_tokens > 0: + efficiency = packed_tokens / total_tokens + + example = examples[0] + + if args.save_to_disk: + with open(os.path.join(args.save_to_disk, "dataset_info.json"), "wb") as f: + f.write(DATASET_INFO.encode()) + + # verify we can open and do any conversion needed + dataset = load_dataset(args.save_to_disk, num_proc=args.num_proc) + + else: + if args.drop_larger_than: + len_before = len(dataset) + dataset = dataset.filter( + lambda x: len(x["inputs"]) <= args.drop_larger_than + ) + dropped = len_before - len(dataset) + + if args.save_to_disk: + print(f"Saving to {args.save_to_disk}") + dataset.save_to_disk(args.save_to_disk) + + example = dataset[0] + + if args.show_example: + inputs = example["inputs"] + labels = 
example["labels"] if "labels" in example else None + position_ids = example["position_ids"] if "position_ids" in example else None + + example_out = "" + for i in range(0, len(inputs)): + token = inputs[i] + label = labels[i] if labels is not None else token + position_id = position_ids[i] if position_ids is not None else None + + decoded = tokenizer.decode(token) + + if label == -100: + example_out += f"\033[31m{decoded}\033[0m({token}" + else: + example_out += f"\033[32m{decoded}\033[0m({token}" + + if position_id != None: + example_out += f"@{position_id})" + else: + example_out += ")" + + print(example_out) + + if dropped > 0: + print(f"Dropped {dropped} too-long samples") + print(f"Efficiency: {efficiency * 100:.1f}%") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str, required=True) + parser.add_argument("--subset", type=str) + parser.add_argument("--split", type=str, default="train") + parser.add_argument("--preprocessor", type=str, required=True) + parser.add_argument("--batch-size", type=int, default=1000) + parser.add_argument("--num-proc", type=int) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--limit", type=int, required=False) + parser.add_argument("--chat", action="store_true") + parser.add_argument("--multiturn-only", action="store_true") + parser.add_argument("--pack-to-sequence-length", type=int) + parser.add_argument("--drop-larger-than", type=int) + parser.add_argument("--save-to-disk", type=str) + parser.add_argument("--show-example", action="store_true") + args = parser.parse_args() + + main(args) diff --git a/scripts/preprocess_multimodal_old.py b/scripts/preprocess_multimodal_old.py new file mode 100644 index 0000000000..44a9548665 --- /dev/null +++ b/scripts/preprocess_multimodal_old.py @@ -0,0 +1,800 @@ +""" + +post of data processing script for multimodal data +that looks like this: +[ + { + "messages": [ + { + "role": "system", + "content": [ + 
{"type": "text", "text": "You are a helpful assistant."} + ] + }, + { + "role": "user", + "content": [ + {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"}, + {"type": "text", "text": "Describe this image in detail."} + ] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "The image is a bee."} + ] + } + ] + } +] +""" +import argparse +import os +import shutil +import multiprocessing +import numpy as np +import pyarrow as pa +import pyarrow.dataset as pa_ds +import random +import json +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + +from typing import List, Optional, Tuple +from torch.nn import functional as F +from torch.utils.data import Dataset +from tqdm import tqdm +from datasets import load_dataset, Dataset as DatasetsDataset +from transformers import AutoTokenizer + +from datetime import datetime, timedelta +import torch + +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer +from huggingface_hub import hf_hub_download +from transformers import Mistral3ForConditionalGeneration + + +SCHEMA = pa.schema( + [ + pa.field("inputs", pa.large_list(pa.int32())), + pa.field("labels", pa.large_list(pa.int32())), + pa.field("position_ids", pa.large_list(pa.int32())), + pa.field("sequence_lengths", pa.large_list(pa.int64())), + ] +) + +DATASET_INFO = r"""{ + "citation": "", + "description": "", + "features": { + "inputs": { + "feature": { + "dtype": "int32", + "_type": "Value" + }, + "_type": "LargeList" + }, + "labels": { + "feature": { + "dtype": "int32", + "_type": "Value" + }, + "_type": "LargeList" + }, + "position_ids": { + "feature": { + "dtype": "int32", + "_type": "Value" + }, + "_type": "LargeList" + }, + "sequence_lengths": { + "feature": { + "dtype": "int64", + "_type": "Value" + }, 
+ "_type": "LargeList" + } + "pixel_values": { + "feature": { + "dtype": "bfloat16", + "_type": "Value" + }, + "_type": "LargeList" + } + "image_sizes": { + "feature": { + "dtype": "bfloat16", + "_type": "Value" + }, + "_type": "LargeList" + } + }, + "homepage": "", + "license": "" +}""" + + +def process_packing_shard(shard, args, tokenizer_pad_id, rank, world_size): + packer = MultimodalPackedDataset( + shard, + max_seq_len=args.pack_to_sequence_length, + padding_idx=tokenizer_pad_id, + split_across_pack=not args.chat, + show_pbar=rank == 0, + ) + + + if args.save_to_disk: + # create a schema that uses int64 for list sizes + + example = ( + { + "inputs": packer.packs[0]["inputs"], + "labels": packer.packs[0]["labels"], + "position_ids": packer.packs[0]["position_ids"], + "sequence_lengths": packer.packs[0]["sequence_lengths"], + "pixel_values": packer.packs[0]["pixel_values"], + "image_sizes": packer.packs[0]["image_sizes"], + } + if len(packer.packs) > 0 + else None + ) + + oriented_data = { + "inputs": [pack["inputs"] for pack in packer.packs], + "labels": [pack["labels"] for pack in packer.packs], + "position_ids": [pack["position_ids"] for pack in packer.packs], + "sequence_lengths": [pack["sequence_lengths"] for pack in packer.packs], + "pixel_values": [pack["pixel_values"] for pack in packer.packs], + "image_sizes": [pack["image_sizes"] for pack in packer.packs], + } + pa_table = pa.Table.from_pydict(oriented_data, schema=SCHEMA) + del oriented_data + + pa_ds.write_dataset( + pa_table, + os.path.join(args.save_to_disk, str(rank)), + format="arrow", + ) + + filename = f"data-{rank:05d}-of-{world_size:05d}.arrow" + + shutil.move( + os.path.join(args.save_to_disk, str(rank), "part-0.arrow"), + os.path.join(args.save_to_disk, filename), + ) + + os.rmdir(os.path.join(args.save_to_disk, str(rank))) + else: + filename = None + + return packer.total_tokens, packer.packed_tokens, packer.dropped, filename, example + + +# 
# Adapted from:
# https://github.com/pytorch/torchtune/blob/9d91fe39f08661952da4180b9e7fb2eba5a7a5e7/torchtune/datasets/_packed.py
class MultimodalPackedDataset(Dataset):
    """
    Performs greedy sample packing on a provided dataset. This is done as a single
    preprocessing step before training begins. Shuffling is done outside of this
    class on packed samples with a ``Sampler`` as part of the dataloader. Currently,
    this only supports in-memory map-style datasets.

    The general flow on initialization is: load tokenized sample -> add to buffer ->
    when buffer is long enough, add to ``self.packs``.

    A packed sample is made up of individual smaller sequence length samples jammed
    together within ``max_seq_len``. For example, if max_seq_len is 6 and there are
    varied length samples::

        tokens = [
            [S1, S1, S1, S2, S2, pad],
            [S3, S3, S4, S4, pad, pad],
            ...,
        ]

    To prevent cross-contamination, the following mask would be used for the
    first pack in the example::

        mask = [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 1],
        ]

    The position ids would be::

        input_pos = [
            [0, 1, 2, 0, 1, 2],
            [0, 1, 0, 1, 2, 3],
            ...,
        ]

    The identity matrix is used in the mask for pad tokens instead of a causal mask.
    For position ids for pad tokens, we simply continue to increment from the previous
    sample normally.

    Args:
        ds (Dataset): dataset to sample pack. Each item must be a dictionary with
            fields "inputs" and "labels" (tokenized sample and its labels); the
            FFD path additionally reads "images" (list of per-sample image arrays).
        max_seq_len (int): Maximum number of tokens to pack.
        padding_idx (int): padding index for the tokenizer. Default is 0.
        max_packs (Optional[int]): Maximum number of packs. Default is None, which
            will create as many packs as possible.
        split_across_pack (bool): if the last sample in a pack does not fit in
            ``max_seq_len``, split the sample into the next pack, or move it
            entirely to the beginning of the next pack. For pre-training, typically
            this is set to True for general text completion. For fine-tuning,
            typically this is set to False to avoid truncating sentences in
            instruct tuning. Default is False.
        group_size (int): number of samples fetched per FFD packing round.
        show_pbar (bool): whether to display a tqdm progress bar.
    """

    def __init__(
        self,
        ds: Dataset,
        *,
        max_seq_len: int,
        padding_idx: int = 0,
        max_packs: Optional[int] = None,
        split_across_pack: bool = False,
        group_size: int = 5000,
        show_pbar: bool = True,
    ) -> None:
        self.ds = ds
        self.max_seq_len = max_seq_len
        self.padding_idx = padding_idx
        self.max_packs = max_packs
        self.split_across_pack = split_across_pack
        self.packs = []
        self.previous_sample_boundary: int = 0
        # Token accounting for the efficiency report in main().
        self.packed_tokens: int = 0
        self.total_tokens: int = 0
        self.dropped: int = 0
        self.show_pbar = show_pbar
        self.group_size = group_size
        # Greedy packing may split samples across pack boundaries (pre-training);
        # FFD keeps every sample whole (chat fine-tuning).
        if split_across_pack:
            self._pack_greedy()
        else:
            self._pack_ffd()

    def _get_empty_pack(self):
        """Return a fresh, empty pack buffer with all expected keys present."""
        return {
            "inputs": np.empty(0, dtype=np.int32),
            "labels": np.empty(0, dtype=np.int32),
            "position_ids": np.empty(0, dtype=np.int32),
            "sequence_lengths": [],
            "pixel_values": [],
            "image_sizes": [],
        }

    def _pack_ffd(self) -> None:
        """Pack with First-Fit-Decreasing: fetch a group, sort by length
        (descending), first-fit into bins, then pad each bin into a pack."""
        ds_iterator = iter(self.ds)
        finished_iterating = False

        pbar = (
            tqdm(
                total=len(self.ds),
                desc="Packing dataset (FFD)",
                dynamic_ncols=True,
            )
            if self.show_pbar
            else None
        )

        while not finished_iterating:
            # 1. Fetch a large group of samples into memory.
            group = []
            try:
                for _ in range(self.group_size):
                    sample = next(ds_iterator)
                    seq_len = len(sample["inputs"])

                    # A sample longer than a whole pack can never fit: drop it.
                    if seq_len > self.max_seq_len:
                        self.dropped += 1
                        continue
                    group.append({"sample": sample, "seq_len": seq_len})
            except StopIteration:
                finished_iterating = True

            if not group:
                break

            # 2. Sort the group by length, descending (the "Decreasing" in FFD).
            group.sort(key=lambda x: x["seq_len"], reverse=True)

            # 3. Pack this group using the "First-Fit" heuristic.
            # Each bin holds the samples it contains and its remaining space.
            bins = []  # List of {"samples": [...], "remaining_space": int}

            for item in group:
                placed = False
                # Try to place the item in the first bin that still has room.
                for bin in bins:
                    if bin["remaining_space"] >= item["seq_len"]:
                        bin["samples"].append(item["sample"])
                        bin["remaining_space"] -= item["seq_len"]
                        placed = True
                        break

                # If no existing bin could accommodate the item, open a new one.
                if not placed:
                    bins.append(
                        {
                            "samples": [item["sample"]],
                            "remaining_space": self.max_seq_len - item["seq_len"],
                        }
                    )

            # 4. Convert the completed bins from this group into padded packs.
            for bin_info in bins:
                if self._should_stop_packing():
                    break

                current_pack = self._get_empty_pack()
                for sample in bin_info["samples"]:
                    tokens = np.array(sample["inputs"], dtype=np.int32)
                    labels = np.array(sample["labels"], dtype=np.int32)
                    images = sample["images"]
                    seq_len = len(tokens)

                    pixel_values = torch.tensor(images).to(dtype=torch.bfloat16)

                    # One (H, W) entry per image; assumes all images of a sample
                    # share the preprocessor's output resolution — TODO confirm.
                    image_sizes = torch.tensor(
                        [[pixel_values.shape[-2], pixel_values.shape[-1]]]
                        * len(images)
                    )

                    current_pack["inputs"] = np.concatenate(
                        (current_pack["inputs"], tokens)
                    )
                    current_pack["labels"] = np.concatenate(
                        (current_pack["labels"], labels)
                    )
                    current_pack["position_ids"] = np.concatenate(
                        (
                            current_pack["position_ids"],
                            np.arange(seq_len, dtype=np.int32),
                        )
                    )
                    current_pack["sequence_lengths"].append(seq_len)
                    current_pack["pixel_values"].extend(pixel_values)
                    current_pack["image_sizes"].extend(image_sizes)

                self._add_pack(current_pack)

            if pbar:
                pbar.update(len(group))

            if self._should_stop_packing():
                # Ensure the outer loop breaks if max_packs is reached.
                break

        if pbar:
            # Manually set pbar to total to show 100% at the end.
            pbar.n = pbar.total
            pbar.refresh()
            pbar.close()

    def _pack_greedy(self) -> None:
        """Iterate through the dataset. Use a buffer to hold samples until
        max_seq_len, then append the buffer to self.packs as a single "packed"
        sample. Continue until max_packs or end of dataset.

        NOTE(review): this path ignores any "images" field — splitting image
        tokens across packs is not supported; use the FFD path for multimodal data.
        """
        current_pack = self._get_empty_pack()

        pbar = (
            tqdm(total=len(self.ds), desc="Packing dataset", dynamic_ncols=True)
            if self.show_pbar
            else None
        )

        for sample in self.ds:
            tokens = np.array(sample["inputs"], dtype=np.int32)
            labels = np.array(sample["labels"], dtype=np.int32)
            # Fix: the original never assigned seq_len here, raising a
            # NameError on the first sample.
            seq_len = len(tokens)

            if seq_len > self.max_seq_len and not self.split_across_pack:
                self.dropped += 1
                continue

            current_pack["inputs"] = np.concatenate((current_pack["inputs"], tokens))
            current_pack["labels"] = np.concatenate((current_pack["labels"], labels))

            position_ids = np.arange(seq_len, dtype=np.int32)
            current_pack["position_ids"] = np.concatenate(
                (current_pack["position_ids"], position_ids)
            )

            current_pack["sequence_lengths"] += [seq_len]

            # Emit full packs; with split_across_pack a long sample may span
            # several consecutive packs.
            while (
                len(current_pack["inputs"]) > self.max_seq_len
                and not self._should_stop_packing()
            ):
                current_pack = self._split_and_add_pack(current_pack)

            if pbar:
                pbar.update()

            self.previous_sample_boundary = len(current_pack["inputs"])

            if self._should_stop_packing():
                break

        # Flush the trailing partial pack.
        if len(current_pack["inputs"]) > 0 and (
            self.max_packs is None or len(self.packs) < self.max_packs
        ):
            self._add_pack(current_pack)

    def _should_stop_packing(self) -> bool:
        """If max packs is set, stop packing when we reach that number."""
        if self.max_packs is not None and len(self.packs) == self.max_packs:
            return True
        return False

    def _split_and_add_pack(self, current_pack):
        """Splits the current pack at the boundary, processes it, adds it to
        ``self.packs`` and returns the start of the next pack."""
        if self.split_across_pack:
            boundary = self.max_seq_len
            # The last elem ensures sum(sequence_lengths) == self.max_seq_len.
            # Fix: the original read the non-existent "seq_lens" key here.
            leftover_seq_len = self.max_seq_len - sum(
                current_pack["sequence_lengths"][:-1]
            )
            seq_len_padding = [leftover_seq_len] if leftover_seq_len > 0 else []
        else:
            boundary = self.previous_sample_boundary
            # If we aren't splitting across packs, we leave out the last sample
            # b/c it will go into the next pack.
            seq_len_padding = []

        pack = {
            "inputs": current_pack["inputs"][:boundary],
            "labels": current_pack["labels"][:boundary],
            "position_ids": current_pack["position_ids"][:boundary],
            "sequence_lengths": current_pack["sequence_lengths"][:-1] + seq_len_padding,
            # Images are never split: they stay with the emitted pack. Fix: the
            # original omitted these keys, making _pad_pack KeyError on splits.
            "pixel_values": current_pack["pixel_values"],
            "image_sizes": current_pack["image_sizes"],
        }

        self._add_pack(pack)

        # Length of the remainder if splitting across packs, otherwise the
        # length of the (deferred) last sample of the current pack.
        next_seq_len = (
            len(current_pack["inputs"][boundary:])
            if self.split_across_pack
            else current_pack["sequence_lengths"][-1]
        )

        return {
            "inputs": current_pack["inputs"][boundary:],
            "labels": current_pack["labels"][boundary:],
            "position_ids": current_pack["position_ids"][boundary:],
            "sequence_lengths": [next_seq_len],
            "pixel_values": [],
            "image_sizes": [],
        }

    def _add_pack(self, pack) -> None:
        """Processes, pads and adds a pack to ``self.packs``."""
        pack = self._pad_pack(pack, padding_idx=self.padding_idx)
        self.packs.append(pack)

    def _pad_pack(self, pack, padding_idx: int):
        """Pads a pack to ``self.max_seq_len`` (inputs with padding_idx,
        labels with -100) and records token-accounting stats."""
        num_tokens = len(pack["inputs"])
        num_padding_tokens = self.max_seq_len - num_tokens

        self.packed_tokens += num_tokens
        self.total_tokens += self.max_seq_len

        padded_inputs = np.pad(
            pack["inputs"], (0, num_padding_tokens), constant_values=self.padding_idx
        )
        padded_labels = np.pad(
            pack["labels"], (0, num_padding_tokens), constant_values=-100
        )

        if num_padding_tokens > 0:
            # Padded position_ids are arbitrary; continue counting up from the
            # last real position for consistency.
            start_pos = int(pack["position_ids"][-1] + 1) if num_tokens > 0 else 0
            pad_positions = np.arange(
                start_pos, start_pos + num_padding_tokens, dtype=np.int32
            )
            padded_position_ids = np.concatenate((pack["position_ids"], pad_positions))
        else:
            padded_position_ids = pack["position_ids"]

        # The padding run is appended as one extra pseudo-sequence so that
        # sum(sequence_lengths) == max_seq_len.
        padded_seq_lens = pack["sequence_lengths"]
        if num_padding_tokens > 0:
            padded_seq_lens.append(num_padding_tokens)

        return {
            "inputs": padded_inputs,
            "labels": padded_labels,
            "position_ids": padded_position_ids,
            "sequence_lengths": padded_seq_lens,
            "pixel_values": pack["pixel_values"],
            "image_sizes": pack["image_sizes"],
        }

    def __len__(self) -> int:
        return len(self.packs)

    def __getitem__(self, idx: int) -> dict[str, np.ndarray]:
        return self.packs[idx]
def main(args):
    """Tokenize a multimodal chat dataset, optionally pack it, and save/inspect it.

    Honors --dataset (local .json file or hub id), --subset, --split, --limit,
    --pack-to-sequence-length, --drop-larger-than, --save-to-disk and
    --show-example. The original hard-coded a local debug path and
    ``.select(range(10))``, ignoring the CLI arguments documented in the README.
    """
    # Accept either a local JSON file or a hub dataset id.
    if args.dataset.endswith(".json") or os.path.isfile(args.dataset):
        dataset = load_dataset("json", data_files=args.dataset, split="train")
    else:
        dataset = load_dataset(args.dataset, name=args.subset, split=args.split)

    if args.limit:
        dataset = dataset.select(range(min(args.limit, len(dataset))))

    def remove_none_recursively(obj):
        """Drop None-valued dict entries at any depth (chat templates reject them)."""
        if isinstance(obj, dict):
            return {
                k: remove_none_recursively(v) for k, v in obj.items() if v is not None
            }
        if isinstance(obj, list):
            return [remove_none_recursively(item) for item in obj]
        return obj

    # AutoProcessor bundles the tokenizer and the image preprocessor.
    from transformers import AutoProcessor

    tokenizer = AutoProcessor.from_pretrained(args.preprocessor, use_fast=True)

    def _tokenize_chat_multimodal(sample):
        """Batched map fn: tokenize conversations, extract image tensors, and
        build labels that supervise assistant turns only (-100 elsewhere)."""
        inputs = []
        labels = []
        images = []

        for conversation in sample["messages"]:
            for message in conversation:
                keys = list(message.keys())

                if "from" in keys and "value" in keys:
                    # sharegpt format -> OpenAI-style role/content
                    message_from = message.pop("from")
                    if message_from == "gpt":
                        message["role"] = "assistant"
                    elif message_from == "human":
                        message["role"] = "user"
                    else:
                        message["role"] = message_from

                    message["content"] = message.pop("value")
                elif "role" in keys and "content" in keys:
                    pass
                else:
                    raise RuntimeError(f"Unknown chat format, keys are {keys}")

            conversation = remove_none_recursively(conversation)

            tokenized = tokenizer.apply_chat_template(
                conversation, tokenize=True, return_dict=True, return_tensors="pt"
            )
            tokens = tokenized["input_ids"][0]
            image = tokenized["pixel_values"]

            # Re-tokenize growing prefixes to attribute token spans to turns.
            current_len = 0
            label = []
            for i in range(len(conversation)):
                if i + 1 == len(conversation):
                    next_tokens = tokenizer.apply_chat_template(
                        conversation,
                        tokenize=True,
                        return_dict=True,
                        return_tensors="pt",
                    )["input_ids"][0][current_len:]
                else:
                    if "assistant" == conversation[i + 1]["role"]:
                        # Include the generation prompt so the assistant span
                        # starts exactly after it.
                        next_tokens = tokenizer.apply_chat_template(
                            conversation[: i + 1],
                            add_generation_prompt=True,
                            tokenize=True,
                            return_dict=True,
                        )["input_ids"][0][current_len:]
                    else:
                        next_tokens = tokenizer.apply_chat_template(
                            conversation[: i + 1], tokenize=True, return_dict=True
                        )["input_ids"][0][current_len:]

                if conversation[i]["role"] == "assistant":
                    label.extend(next_tokens)
                else:
                    label.extend([-100] * len(next_tokens))

                current_len += len(next_tokens)

            inputs.append(tokens)
            labels.append(label)
            images.append(image)

        return {
            "inputs": inputs,
            "labels": labels,
            "images": images,
        }

    dataset = dataset.shuffle(args.seed)

    original_column_names = list(dataset.features.keys())

    dataset = dataset.map(
        _tokenize_chat_multimodal,
        batched=True,
    )

    dataset = dataset.remove_columns(original_column_names)

    efficiency = 1.0
    dropped = 0
    example = None
    if args.pack_to_sequence_length:
        # NOTE(review): packing currently runs in a single shard; bump this to
        # args.num_proc once parallel packing is validated.
        num_shards = 1
        shards = [
            dataset.shard(num_shards=num_shards, index=i) for i in range(num_shards)
        ]

        with multiprocessing.Pool(processes=num_shards) as pool:
            process_args = [
                (shard, args, tokenizer.tokenizer.pad_token_id, index, num_shards)
                for index, shard in enumerate(shards)
            ]

            results = pool.starmap(process_packing_shard, process_args)

        examples = []
        filenames = []
        total_tokens = 0
        packed_tokens = 0

        for total, packed, dropped_, filename, example_ in tqdm(results):
            if example_:
                examples.append(example_)
            if filename:
                filenames.append(filename)
            total_tokens += total
            packed_tokens += packed
            dropped += dropped_

        if total_tokens > 0:
            efficiency = packed_tokens / total_tokens

        # Guard: the original indexed examples[0] unconditionally and crashed
        # when every shard produced zero packs.
        example = examples[0] if examples else None

        if args.save_to_disk:
            with open(os.path.join(args.save_to_disk, "dataset_info.json"), "wb") as f:
                f.write(DATASET_INFO.encode())

            # verify we can open and do any conversion needed
            dataset = load_dataset(args.save_to_disk, num_proc=args.num_proc)
    else:
        if args.drop_larger_than:
            len_before = len(dataset)
            dataset = dataset.filter(
                lambda x: len(x["inputs"]) <= args.drop_larger_than
            )
            dropped = len_before - len(dataset)

        if args.save_to_disk:
            print(f"Saving to {args.save_to_disk}")
            dataset.save_to_disk(args.save_to_disk)

        example = dataset[0]

    if args.show_example and example is not None:
        inputs = example["inputs"]
        labels = example["labels"] if "labels" in example else None
        position_ids = example["position_ids"] if "position_ids" in example else None

        # Render each token: red = masked (-100), green = supervised.
        example_out = ""
        for i in range(0, len(inputs)):
            token = inputs[i]
            label = labels[i] if labels is not None else token
            position_id = position_ids[i] if position_ids is not None else None

            decoded = tokenizer.decode(token)

            if label == -100:
                example_out += f"\033[31m{decoded}\033[0m({token}"
            else:
                example_out += f"\033[32m{decoded}\033[0m({token}"

            if position_id is not None:
                example_out += f"@{position_id})"
            else:
                example_out += ")"

        print(example_out)

    if dropped > 0:
        print(f"Dropped {dropped} too-long samples")
    print(f"Efficiency: {efficiency * 100:.1f}%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--subset", type=str)
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--preprocessor", type=str, required=True)
    parser.add_argument("--batch-size", type=int, default=1000)
    parser.add_argument("--num-proc", type=int)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--limit", type=int)
    parser.add_argument("--chat", action="store_true")
    parser.add_argument("--multiturn-only", action="store_true")
    parser.add_argument("--pack-to-sequence-length", type=int)
    parser.add_argument("--drop-larger-than", type=int)
    parser.add_argument("--save-to-disk", type=str)
    parser.add_argument("--show-example", action="store_true")
    args = parser.parse_args()
"""Convert a Hermes-4 subset into OpenAI-style chat messages and save it locally."""
import re


def process_conversation(row):
    """Convert one sharegpt-style row into OpenAI-style chat messages.

    Prepends a fixed system prompt, maps "human" -> "user" and everything else
    -> "assistant", and wraps each turn's stripped text in a single-element
    content list.

    Args:
        row: dict with a "conversations" list of {"from": ..., "value": ...} turns.

    Returns:
        {"conversations": [...]} with role/content formatted messages.
    """
    original_conv = row["conversations"]

    messages = [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": "You are a helpful assistant."}
            ],
        }
    ]

    for turn in original_conv:
        role = "user" if turn["from"] == "human" else "assistant"
        value = turn["value"]

        # NOTE(review): "path": None is kept for output compatibility with the
        # original script; downstream preprocessing strips None values anyway.
        content = [{"type": "text", "text": value.strip(), "path": None}]

        messages.append({
            "role": role,
            "content": content,
        })

    return {"conversations": messages}


def _main():
    # Heavy deps imported here so process_conversation stays importable and
    # testable without pulling in `datasets`.
    from datasets import load_dataset
    from datasets.utils.info_utils import VerificationMode

    ds = load_dataset(
        "NousResearch/Hermes-4-v4-Final-Nonreasoning-Only",
        verification_mode=VerificationMode.NO_CHECKS,
        split="train[:20000]",
    )

    # Apply the transformation and keep only the new "conversations" column.
    new_ds = ds.map(
        process_conversation,
        remove_columns=ds.column_names,
    )

    print(new_ds[200]["conversations"])
    new_ds.save_to_disk("H4_Subset")


if __name__ == "__main__":
    _main()
class PreprocessedMultimodalDataset(IterableDataset, Stateful):
    """Streams already-tokenized, packed multimodal samples from disk.

    Expects a dataset produced by scripts/preprocess_multimodal_data.py with
    columns "inputs", "labels", "images" and optionally "position_ids" /
    "sequence_lengths". Yields ``(args, labels, images)`` triples; labels are
    shifted for next-token prediction. Supports checkpoint resume via the
    Stateful protocol.
    """

    def __init__(
        self,
        dataset_name: str,
        dataset_path: str | None,
        tokenizer: Tokenizer,
        seq_len: int = 2048,
        dp_rank: int = 0,
        dp_world_size: int = 1,
        infinite: bool = False,
        shuffle_seed: int | None = 42,
    ) -> None:
        # A local dataset_path takes precedence over the hub dataset_name.
        ds = load_dataset(dataset_path if dataset_path else dataset_name, split="train")

        if shuffle_seed is not None:
            ds = ds.shuffle(shuffle_seed)

        logger.info(f"Loaded preprocessed dataset with {len(ds)} samples")

        # Shard across data-parallel ranks.
        self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
        # NOTE(review): tokenizer and seq_len are stored but never read in this
        # class (samples are pre-tokenized and pre-packed) — confirm before removing.
        self._tokenizer = tokenizer
        self.seq_len = seq_len
        self.infinite = infinite
        self.dataset_name = dataset_name
        self.dp_rank = dp_rank

        # Variables for checkpointing: number of samples already consumed.
        self._sample_idx = 0

    def _get_data_iter(self):
        # For map-style datasets, resume by skipping to the correct index
        # For iterable-style datasets, the underlying iterator already points to the correct index
        if isinstance(self._data, Dataset):
            if self._sample_idx == len(self._data):
                return iter([])
            else:
                return iter(self._data.skip(self._sample_idx))

        return iter(self._data)

    def __iter__(self):
        while True:
            for sample in self._get_data_iter():
                self._sample_idx += 1

                keys = list(sample.keys())

                inputs = torch.LongTensor(sample["inputs"])
                # Fall back to the inputs themselves when no labels were precomputed.
                labels = torch.LongTensor(sample["labels"] if "labels" in keys else sample["inputs"])
                # NOTE(review): "images" is read unconditionally, unlike the
                # optional keys below — a text-only row would raise KeyError here.
                images = sample["images"]

                # Shift labels left by one for next-token prediction; the final
                # position has no target.
                labels = torch.roll(labels, shifts=-1, dims=0)
                labels[-1] = -100

                args = {
                    "input": inputs,
                }
                if "position_ids" in keys:
                    args["position_ids"] = torch.LongTensor(sample["position_ids"])
                if "sequence_lengths" in keys:
                    args["sequence_lengths"] = torch.LongTensor(sample["sequence_lengths"])

                yield args, labels, images

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Re-loop from the start; bump the epoch on iterable datasets so
                # any per-epoch shuffling takes effect.
                self._sample_idx = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")
                if not isinstance(self._data, Dataset):
                    if hasattr(self._data, "set_epoch") and hasattr(
                        self._data, "epoch"
                    ):
                        self._data.set_epoch(self._data.epoch + 1)

    def __len__(self):
        return len(self._data)

    def load_state_dict(self, state_dict):
        # Map-style datasets checkpoint by sample index; iterable datasets
        # restore their own internal iterator state.
        if isinstance(self._data, Dataset):
            self._sample_idx = state_dict["sample_idx"]
        else:
            assert "data" in state_dict
            self._data.load_state_dict(state_dict["data"])

    def state_dict(self):
        _state_dict = {}

        if isinstance(self._data, Dataset):
            _state_dict["sample_idx"] = self._sample_idx
        else:
            # Save the iterable dataset's state to later efficiently resume from it
            # https://huggingface.co/docs/datasets/v3.5.0/en/stream#save-a-dataset-checkpoint-and-resume-iteration
            _state_dict["data"] = self._data.state_dict()

        return _state_dict
def collate_fn_multimodal(batch):
    """Collate ``(args, labels, images)`` triples from PreprocessedMultimodalDataset.

    Stacks the fixed-length packed inputs/labels into batch tensors, passes the
    per-sample image payloads through untouched, and validates that every item
    in the batch has the same packed length.

    Args:
        batch: list of ``(args_dict, labels_tensor, images)`` triples.

    Returns:
        ``(args, labels_tensor)`` where args contains "input", "images" and,
        when present on the samples, "position_ids" / "sequence_lengths".

    Raises:
        ValueError: if any input or label length differs from item 0's length.
    """
    inputs, labels, images = zip(*batch)

    # All packs are padded to a common length at preprocessing time; fail
    # loudly if a mismatched sample sneaks in.
    expected_len = len(inputs[0]["input"])
    for i, (input_item, label_item, _image_item) in enumerate(batch):
        input_len = len(input_item["input"])
        label_len = len(label_item)

        if input_len != expected_len or label_len != expected_len:
            raise ValueError(
                f"All tensors in the batch must have the same length. "
                f"Expected length {expected_len} (from item 0), but item {i} has "
                f"an input length of {input_len} and a label length of {label_len}."
            )

    args = {"input": torch.stack([x["input"] for x in inputs]), "images": images}

    if "position_ids" in inputs[0]:
        args["position_ids"] = torch.stack([x["position_ids"] for x in inputs])
    if "sequence_lengths" in inputs[0]:
        args["sequence_lengths"] = [x["sequence_lengths"] for x in inputs]

    labels_tensor = torch.stack(list(labels))

    return args, labels_tensor


def build_preprocessed_multimodal_dataloader(
    dp_world_size: int,
    dp_rank: int,
    tokenizer: Tokenizer,
    job_config: JobConfig,
    infinite: bool = True,
) -> ParallelAwareDataloader:
    """Build a data loader over a preprocessed (packed) multimodal dataset.

    Reads dataset/batch settings from ``job_config.training`` and wraps a
    PreprocessedMultimodalDataset in a ParallelAwareDataloader using
    ``collate_fn_multimodal``.
    """
    dataset_name = job_config.training.dataset
    dataset_path = job_config.training.dataset_path
    batch_size = job_config.training.local_batch_size
    seq_len = job_config.training.seq_len

    ds = PreprocessedMultimodalDataset(
        dataset_name=dataset_name,
        dataset_path=dataset_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        infinite=infinite,
    )

    return ParallelAwareDataloader(
        dataset=ds,
        dp_rank=dp_rank,
        dp_world_size=dp_world_size,
        batch_size=batch_size,
        collate_fn=collate_fn_multimodal,
    )
import torchtitan.models.deepseek_v3 # noqa: F401 import torchtitan.models.llama3 # noqa: F401 +import torchtitan.models.mistral3 import torchtitan.models.qwen3 # noqa: F401 diff --git a/torchtitan/models/mistral3/__init__.py b/torchtitan/models/mistral3/__init__.py new file mode 100644 index 0000000000..437fe9a9cf --- /dev/null +++ b/torchtitan/models/mistral3/__init__.py @@ -0,0 +1,161 @@ +# Copyright (c) 2025, Anthropic Research Labs +# All rights reserved. + +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.datasets.dataloader import build_dataloader + +#from .model.configuration_pixtral import PixtralVisionConfig +from .model.model import VLMArgs, VLM + +from .infra.parallelize import parallelize_mistral3 +from .infra.pipeline import pipeline_mistral3 + + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec + +from .infra.parallelize import parallelize_mistral3 +from .infra.pipeline import pipeline_mistral3 +from .model.model import VLM + + +__all__ = [ + "parallelize_mistral3", + "pipeline_mistral3", + "VLMArgs", + "VLM", + "mistral3_configs", +] + +# Define model configurations +mistral3_configs = { + "24B_samplepacked": VLMArgs( + # vision encoder part + vision_embed_dim=1024, + vision_num_layers=24, + vision_num_heads=16, + vision_feature_layer=-1, + patch_size=14, + image_size=1540, + in_channels=3, + spatial_merge_size=2, + + # projection part + num_layers_projection=8, + projector_hidden_act="gelu", + multimodal_projector_bias=False, + + # decoder part + decoder_embed_dim=5120, + decoder_num_layers=40, + decoder_num_heads=32, + decoder_num_kv_heads=8, + fusion_interval=8, + 
image_token_index=10, + + # common part + vocab_size=131072, + multiple_of=256, + ffn_dim_multiplier=None, + norm_eps=1e-5, + rope_theta=1000000000.0, + max_seq_len=131072, + use_flex_attn=True, + attn_mask_type="block_causal_by_sequence_lengths", + ), + "debug_samplepacked": VLMArgs( + # vision encoder part + vision_embed_dim=1024, + vision_num_layers=24, + vision_num_heads=16, + vision_feature_layer=-2, + patch_size=14, + image_size=1540, + in_channels=3, + spatial_merge_size=2, + + # projection part + num_layers_projection=8, + projector_hidden_act="gelu", + multimodal_projector_bias=False, + + # decoder part + decoder_embed_dim=5120, + decoder_num_layers=1, + decoder_num_heads=32, + decoder_num_kv_heads=8, + fusion_interval=8, + image_token_index=10, + + # common part + vocab_size=131072, + multiple_of=256, + ffn_dim_multiplier=None, + norm_eps=1e-5, + rope_theta=1000000000.0, + max_seq_len=131072, + use_flex_attn=True, + attn_mask_type="block_causal_by_sequence_lengths", + ), + "24B": VLMArgs( + # vision encoder part + vision_feature_layer=-1, + attn_mask_type="block_causal", + ), + "debug": VLMArgs( + # vision encoder part + vision_embed_dim=1024, + vision_num_layers=24, + vision_num_heads=16, + vision_feature_layer=-2, + patch_size=14, + image_size=1540, + in_channels=3, + spatial_merge_size=2, + + # projection part + num_layers_projection=8, + projector_hidden_act="gelu", + multimodal_projector_bias=False, + + # decoder part + decoder_embed_dim=5120, + decoder_num_layers=2, + decoder_num_heads=32, + decoder_num_kv_heads=8, + fusion_interval=8, + image_token_index=10, + + # common part + vocab_size=131072, + multiple_of=256, + ffn_dim_multiplier=None, + norm_eps=1e-5, + rope_theta=1000000000.0, + max_seq_len=131072, + use_flex_attn=True, + attn_mask_type="block_causal_by_sequence_lengths", + ), +} + + +# Register the model +register_train_spec( + TrainSpec( + name="mistral3", + parallelize_fn=parallelize_mistral3, + model_cls=VLM, + 
model_args=mistral3_configs, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + build_validator_fn=build_validator, + pipelining_fn=pipeline_mistral3, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_dataloader, + ) +) \ No newline at end of file diff --git a/torchtitan/models/mistral3/infra/parallelize.py b/torchtitan/models/mistral3/infra/parallelize.py new file mode 100644 index 0000000000..8c9378717f --- /dev/null +++ b/torchtitan/models/mistral3/infra/parallelize.py @@ -0,0 +1,409 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D parallelisms (except pipeline parallelism) and various +# training techniques (e.g. activation checkpointing and compile) to the Llama model. + +from collections import defaultdict + +import torch +import torch.nn as nn + +from torch.distributed import DeviceMesh +from torch.distributed._composable.fsdp import ( + CPUOffloadPolicy, + fully_shard, + MixedPrecisionPolicy, +) +from torch.distributed._composable.replicate import replicate +from torch.distributed._tensor import Replicate, Shard +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper as ptd_checkpoint_wrapper, +) +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) + +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +#from torchtitan.logging import logger + +from torchtitan.tools.logging import logger +from torchtitan.distributed import ParallelDims + + +def parallelize_mistral3( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, 
torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + + """ + + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + + + if parallel_dims.tp_enabled: + """ + if ( + job_config.experimental.enable_async_tensor_parallel + and not job_config.training.compile + ): + raise RuntimeError("Async TP requires --training.compile") + """ + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + # For now, float8 all-gather with TP is only supported for tensorwise + # float8 scaling recipes. For rowwise recipes, we use regular TP and + # all-gather happens in high precision. 
+ enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + + apply_tp( + model, + world_mesh["tp"], + loss_parallel=False, + enable_float8=False, + enable_async_tp=False, + #enable_async_tp=job_config.experimental.enable_async_tensor_parallel, + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + + # turn on per-TransformerBlock compile after AC wrapping and before FSDP + if job_config.training.compile: + apply_compile(model) + + if ( + parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled + ): # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + + apply_fsdp( + model, + world_mesh[tuple(dp_mesh_dim_names)], + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + apply_ddp( + model, + world_mesh, + enable_compile=job_config.training.compile, + enable_compiled_autograd=job_config.experimental.enable_compiled_autograd, + ) + return model + + +def apply_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8: bool, + enable_async_tp: bool, +): + """Apply tensor parallelism.""" + # 1. 
Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + parallelize_module( + model.language_model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + use_local_output=False, + ), + "norm": SequenceParallel(), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + }, + ) + + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears + if enable_float8: + # TODO(vkuzo): once float8 configuration supports delayed scaling, + # add a check here to enforce supported float8 all-gather configurations + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + Float8RowwiseParallel, + Float8ColwiseParallel, + PrepareFloat8ModuleInput, + ) + else: + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block + # NOTE: At the cost of model code change, we can accelerate Sequence Parallel + # by folding (and unfolding) the batch dimension and the sequence dimension. 
+ # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + for layer_id, transformer_block in model.language_model.layers.items(): + layer_plan = { + "ln_attn": SequenceParallel(), + "attn": prepare_module_input( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ), + "attn.wq": colwise_parallel(), + "attn.wk": colwise_parallel(), + "attn.wv": colwise_parallel(), + "attn.wo": rowwise_parallel(output_layouts=Shard(1), use_local_output=False), + "ln_mlp": SequenceParallel(), + "mlp": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "mlp.w1": colwise_parallel(), + "mlp.w2": rowwise_parallel(output_layouts=Shard(1), use_local_output=False), + "mlp.w3": colwise_parallel(), + } + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + if enable_async_tp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + torch._inductor.config._micro_pipeline_tp = True + enable_symm_mem_for_group(tp_mesh.get_group().group_name) + + #logger.info( + # f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}" + # "Tensor Parallelism to the model" + #) + + +# for selective op activation checkpointing +_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, +} + + +def _apply_ac_to_transformer_block(module: nn.Module, ac_config): + valid_ac_modes = ("full", "selective") + if ac_config.mode not in valid_ac_modes: + raise ValueError( + f"Invalid AC mode: {ac_config.mode}. 
Valid modes: {valid_ac_modes}" + ) + + if ac_config.mode == "full": + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + + assert ac_config.mode == "selective", f"{ac_config.mode}" + use_op_sac = ac_config.selective_ac_option == "op" + use_layer_sac = ac_config.selective_ac_option.isdigit() + if not use_op_sac and not use_layer_sac: + raise ValueError( + f"Invalid selective AC option: {ac_config.selective_ac_option}. " + f"Valid options: 'op' or a positive int representing layer frequency" + ) + if use_op_sac: + from torch.utils.checkpoint import ( + CheckpointPolicy, + create_selective_checkpoint_contexts, + ) + + def _get_custom_policy(meta): + def _custom_policy(ctx, func, *args, **kwargs): + mode = "recompute" if ctx.is_recompute else "forward" + mm_count_key = f"{mode}_mm_count" + if func == torch.ops.aten.mm.default: + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + to_save = func in _save_list and not ( + func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 + ) + return ( + CheckpointPolicy.MUST_SAVE + if to_save + else CheckpointPolicy.PREFER_RECOMPUTE + ) + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = defaultdict(int) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + return ptd_checkpoint_wrapper( + module, + context_fn=selective_checkpointing_context_fn, + preserve_rng_state=False, + ) + elif use_layer_sac: + # Checkpoint every `ac_freq` of the modules passed to this function + ac_freq = int(ac_config.selective_ac_option) + ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) + ptd_checkpoint_wrapper._count += 1 + if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + else: + return module + + +def apply_ac(model: nn.Module, ac_config): + """Apply activation checkpointing to the model.""" + for layer_id, transformer_block in 
model.language_model.layers.named_children(): + transformer_block = _apply_ac_to_transformer_block(transformer_block, ac_config) + model.language_model.layers.register_module(layer_id, transformer_block) + + logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") + + +def apply_compile(model: nn.Module): + """ + Apply torch.compile to each TransformerBlock, which makes compilation efficient due to + repeated structure. Alternatively one can compile the whole model (after applying DP). + """ + for layer_id, transformer_block in model.language_model.layers.named_children(): + transformer_block = torch.compile(transformer_block, fullgraph=True) + model.language_model.layers.register_module(layer_id, transformer_block) + + logger.info("Compiling each TransformerBlock with torch.compile") + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", +): + """ + Apply data parallelism (via FSDP2) to the model. + + Args: + model (nn.Module): The model to apply data parallelism to. + dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. 
+ + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + for layer_id, transformer_block in model.language_model.layers.items(): + if reshard_after_forward_policy == "always": + reshard_after_forward = True + elif reshard_after_forward_policy == "never": + reshard_after_forward = False + elif reshard_after_forward_policy == "default": + if pp_enabled: + # For PP, do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = False + else: + # As an optimization, do not reshard after forward for the last + # transformer block since FSDP would prefetch it immediately + reshard_after_forward = int(layer_id) < len(model.language_model.layers) - 1 + else: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) + + +def apply_ddp( + model: nn.Module, + dp_mesh: DeviceMesh, + enable_compile: bool, + enable_compiled_autograd: bool, +): + if enable_compile: + if enable_compiled_autograd: + torch._dynamo.config.optimize_ddp = ( + "python_reducer_without_compiled_forward" + ) + else: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" + + replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + + logger.info("Applied DDP to the model") diff --git a/torchtitan/models/mistral3/infra/pipeline.py b/torchtitan/models/mistral3/infra/pipeline.py new file mode 100644 index 0000000000..f67cc1f64a --- /dev/null +++ b/torchtitan/models/mistral3/infra/pipeline.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D pipeline parallelism to the Llama model. + +import copy +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import _PipelineSchedule, get_schedule_class + +from torchtitan.config_manager import JobConfig +from torchtitan.tools.logging import logger +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.pipeline import ( + build_pipeline_schedule, + generate_split_points, + stage_ids_this_rank, +) + +from ..model.args import VLMArgs + + +DeviceType = Union[int, str, torch.device] + + +def pipeline_mistral3( + model: nn.Module, + pp_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: DeviceType, + model_config: VLMArgs, + loss_fn: Callable[..., torch.Tensor], +) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: + stages, models = pipeline_llama_manual_split( + model.language_model, pp_mesh, parallel_dims, job_config, device, model_config + ) + + pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn) + + # This is used in the train loop to determine whether to pass in the input_ids and labels + has_first_stage = False + has_last_stage = False + for stage in stages: + if stage.is_first: + has_first_stage = True + if stage.is_last: + has_last_stage = True + + return pp_schedule, models, has_first_stage, has_last_stage + + +def pipeline_llama_manual_split( + whole_model: nn.Module, + pp_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: DeviceType, + model_config: VLMArgs, +) -> tuple[list[PipelineStage], list[nn.Module]]: + """ + This API extracts one torch.nn.Module objects for the part of the model configured to run inside 
this stage. + + It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects. + + The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD + parallelism. + """ + pp_rank = pp_mesh.get_local_rank() + pp_size = pp_mesh.size() + + splits = ( + job_config.experimental.pipeline_parallel_split_points + or generate_split_points(job_config, parallel_dims.pp, model_config.n_layers) + ) + + def _build_stage( + stage_idx: int, + start_layer: Optional[str], + stop_layer: Optional[str], + is_first: bool = False, + is_last: bool = False, + ) -> tuple[PipelineStage, nn.Module]: + model = copy.deepcopy(whole_model) + if not is_first: + model.tok_embeddings = None + + drop_layers = start_layer is not None + for name in list(model.layers.keys()): + # we keep layers in a contiguous region between start (inclusive) and stop (exclusive) + if f"layers.{name}" == start_layer: + drop_layers = False + if f"layers.{name}" == stop_layer: + drop_layers = True + if drop_layers: + del model.layers[name] + + if not is_last: + model.norm = None + model.output = None + + stage = PipelineStage( + model, + stage_idx, + num_stages, + device, + group=pp_mesh.get_group("pp"), + ) + return stage, model + + num_stages = len(splits) + 1 + stage_idx = pp_rank + + stages = [] + models = [] + + schedule_class = get_schedule_class( + job_config.experimental.pipeline_parallel_schedule + ) + style = "loop" + + for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style): + start_layer = splits[stage_idx - 1] if stage_idx > 0 else None + stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None + stage, model_chunk = _build_stage( + stage_idx, + start_layer, + stop_layer, + is_first=stage_idx == 0, + is_last=stage_idx == num_stages - 1, + ) + logger.info( + f"PP rank {pp_rank} is building stage_idx {stage_idx}" + f" with start_layer {start_layer}, stop_layer {stop_layer}" + ) + 
stages.append(stage) + models.append(model_chunk) + return stages, models diff --git a/torchtitan/models/mistral3/model/args.py b/torchtitan/models/mistral3/model/args.py new file mode 100644 index 0000000000..bd23edf76f --- /dev/null +++ b/torchtitan/models/mistral3/model/args.py @@ -0,0 +1,115 @@ + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torchtitan.protocols.train_spec import ModelProtocol +from torchtitan.models.attention import build_attention, init_attention_mask + +from torchtitan.protocols.train_spec import BaseModelArgs +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.config_manager import JobConfig + + +from dataclasses import dataclass + +from dataclasses import dataclass + +from torch import nn + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.config_manager import JobConfig +from torchtitan.protocols.train_spec import BaseModelArgs + +@dataclass +class VLMArgs(BaseModelArgs): + # vision encoder part + vision_embed_dim: int = 1024 + vision_num_layers: int = 24 + vision_num_heads: int = 16 + vision_feature_layer: int = -1 + patch_size: int = 14 + image_size: int = 1540 + in_channels: int = 3 + # For merging patches + spatial_merge_size: int = 2 + + # projection part + num_layers_projection: int = 8 + projector_hidden_act: str = "gelu" + multimodal_projector_bias: bool = False + + # decoder part + decoder_embed_dim: int = 5120 + decoder_num_layers: int = 40 + decoder_num_heads: int = 32 + decoder_num_kv_heads: int = 8 + fusion_interval: int = 8 # Interval for fusion of vision features into text model + image_token_index: int = 10 # Token ID representing an image in the text + + # common part + vocab_size: int = 131072 + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 
1e-5 + rope_theta: float = 1000000000.0 + max_seq_len: int = 131072 + activation: nn.Module = nn.SiLU() + depth_init: bool = True + norm_type: str = "rmsnorm" + + n_layers: int = 40 + n_heads: int = 32 + n_embd: int = 5120 + dim: int = 4096 + + use_flex_attn: bool = False + attn_mask_type: str = "block_causal_by_sequence_lengths" + eos_id: int = 0 + image_token_id: int = 10 + + def update_from_config( + self, job_config: JobConfig, tokenizer: BaseTokenizer + ) -> None: + self.vocab_size = tokenizer.get_vocab_size() + self.max_seq_len = job_config.training.seq_len + self.eos_id = tokenizer.eos_id + + if job_config.activation_checkpoint.mode == "selective" and self.use_flex_attn: + raise ValueError( + "FlexAttention is not compatible with selective AC yet. " + "See https://github.com/pytorch/pytorch/issues/147879" + ) + + if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + raise ValueError( + "FlexAttention is not compatible with CP yet. " + "We are still working on this." + ) + + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: + nparams = sum(p.numel() for p in model.parameters()) + nparams_embedding = sum( + sum(p.numel() for p in m.parameters()) + for m in model.children() + if isinstance(m, nn.Embedding) + ) + + l, h, q, t = ( + self.n_layers, + self.n_heads, + self.dim // self.n_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. 
we follow the convention and do not account for sparsity in causal attention + num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t + + return nparams, num_flops_per_token \ No newline at end of file diff --git a/torchtitan/models/mistral3/model/configuration_pixtral.py b/torchtitan/models/mistral3/model/configuration_pixtral.py new file mode 100644 index 0000000000..7b0fad0052 --- /dev/null +++ b/torchtitan/models/mistral3/model/configuration_pixtral.py @@ -0,0 +1,106 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pixtral model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class PixtralVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PixtralVisionModel`]. It is used to instantiate an + Pixtral vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to the vision encoder used by Pixtral-12B. + + e.g. [pixtral-hf/pixtral-9b](https://huggingface.co/pixtral-hf/pixtral-9b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 4096): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of input channels in the input images. + image_size (`int`, *optional*, defaults to 1024): + Max dimension of the input images. + patch_size (`int`, *optional*, defaults to 16): + Size of the image patches. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + Activation function used in the hidden layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for the attention layers. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ + Example: + + ```python + >>> from transformers import PixtralVisionModel, PixtralVisionConfig + + >>> # Initializing a Pixtral-12B style configuration + >>> config = PixtralVisionConfig() + + >>> # Initializing a model (with randomly initialized weights) from the configuration + >>> model = PixtralVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "pixtral" + + def __init__( + self, + hidden_size=1024, + intermediate_size=4096, + num_hidden_layers=24, + num_attention_heads=16, + num_channels=3, + image_size=1024, + patch_size=16, + hidden_act="gelu", + attention_dropout=0.0, + rope_theta=10000.0, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.rope_theta = rope_theta + self.head_dim = hidden_size // num_attention_heads + self.initializer_range = initializer_range + + +__all__ = ["PixtralVisionConfig"] diff --git a/torchtitan/models/mistral3/model/model.py b/torchtitan/models/mistral3/model/model.py new file mode 100644 index 0000000000..3ec8fa857c --- /dev/null +++ b/torchtitan/models/mistral3/model/model.py @@ -0,0 +1,821 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
+ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_tensor +from transformers.image_utils import load_image + +from torchtitan.protocols.train_spec import ModelProtocol +from torchtitan.models.attention import build_attention, init_attention_mask + +from torchtitan.protocols.train_spec import BaseModelArgs +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.config_manager import JobConfig + +from .args import VLMArgs + + +def build_norm(norm_type: str, dim: int, eps: float = 1e-6, device: torch.device = None): + """ + Builds the specified normalization layer based on the norm_type. + + Args: + norm_type (str): The type of normalization layer to build. + Supported types: layernorm, np_layernorm, rmsnorm + dim (int): The dimension of the normalization layer. + eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. + + Returns: + The built normalization layer. + + Raises: + NotImplementedError: If an unknown norm_type is provided. + """ + norm_type = norm_type.lower() # Normalize to lowercase + + if norm_type == "layernorm": + return nn.LayerNorm(dim, eps=eps, bias=False) + elif norm_type == "np_layernorm": + return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) + elif norm_type == "rmsnorm": + return RMSNorm(dim, eps=eps, device=device) + else: + raise NotImplementedError(f"Unknown norm_type: '{norm_type}'") + + +class RMSNorm(nn.Module): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. 
+ + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x: torch.Tensor): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x: torch.Tensor): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) # type: ignore + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. 
+ x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + #logger.info(freqs_cis.shape) + #logger.info(x.shape) + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + position_ids (torch.Tensor, optional): Custom position IDs of shape [batch_size, seq_len]. + If provided, will use these to index into freqs_cis. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + + if position_ids is not None: + gathered_freqs = freqs_cis[position_ids] # [bs, seqlen, head_dim/2] + gathered_freqs = gathered_freqs.unsqueeze(2) # [bs, seqlen, 1, head_dim/2] + + xq_out = torch.view_as_real(xq_ * gathered_freqs).flatten(3) + xk_out = torch.view_as_real(xk_ * gathered_freqs).flatten(3) + + return xq_out.type_as(xq), xk_out.type_as(xk) + else: + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, num_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=num_rep)""" + bsz, seq_len, num_kv_heads, head_dim = x.shape + if num_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bsz, seq_len, num_kv_heads, num_rep, head_dim) + .reshape(bsz, seq_len, num_kv_heads * num_rep, head_dim) + ) + + +class Mistral3PatchMerger(nn.Module): + """ + Learned merging of spatial_merge_size ** 2 patches + """ + + def __init__(self, config: VLMArgs): + super().__init__() + self.config = config + + hidden_size = config.vision_embed_dim + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False) + + + def init_weights(self): + """ + Initialize weights following the Llama3 pattern. 
+ """ + # Initialize merging layer with truncated normal + nn.init.trunc_normal_(self.merging_layer.weight, mean=0.0, std=0.02) + + def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor: + image_sizes = [ + (image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes + ] + + tokens_per_image = [h * w for h, w in image_sizes] + d = image_features.shape[-1] + + permuted_tensor = [] + for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)): + # Reshape image_tokens into a 2D grid + h, w = image_sizes[image_index] + image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0) + grid = torch.nn.functional.unfold( + image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size + ) + grid = grid.view(d * self.spatial_merge_size**2, -1).t() + permuted_tensor.append(grid) + + image_features = torch.cat(permuted_tensor, dim=0) + image_features = self.merging_layer(image_features) + + return image_features.unsqueeze(0) + + + +class Mistral3MultiModalProjector(nn.Module): + def __init__(self, config: VLMArgs): + super().__init__() + self.norm = nn.RMSNorm(config.vision_embed_dim, eps=config.norm_eps, device=torch.cuda.current_device()) + self.patch_merger = Mistral3PatchMerger(config) + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_embed_dim * num_feature_layers, + config.decoder_embed_dim, + bias=config.multimodal_projector_bias, + ) + self.act = config.activation + self.linear_2 = nn.Linear( + config.decoder_embed_dim, config.decoder_embed_dim, bias=config.multimodal_projector_bias + ) + + def init_weights(self): + """ + Initialize weights following the Llama3 pattern. 
+ """ + # Initialize norm layer + if hasattr(self.norm, 'reset_parameters'): + self.norm.reset_parameters() + + # Initialize patch merger + if hasattr(self.patch_merger, 'init_weights'): + self.patch_merger.init_weights() + + # Initialize linear layers with truncated normal + for linear in (self.linear_1, self.linear_2): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + if linear.bias is not None: + nn.init.zeros_(linear.bias) + + def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor): + image_features = self.norm(image_features) + + image_features = self.patch_merger(image_features, image_sizes) + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + + +class Attention(nn.Module): + + def __init__(self, config: VLMArgs, is_vision=True): + super().__init__() + if is_vision: + self.num_heads = config.vision_num_heads + self.num_kv_heads = config.vision_num_heads + self.head_dim = config.vision_embed_dim // config.vision_num_heads + self.embed_dim = config.vision_embed_dim + self.is_causal = False + else: + self.num_heads = config.decoder_num_heads + self.num_kv_heads = ( + config.decoder_num_heads if config.decoder_num_kv_heads is None else config.decoder_num_kv_heads + ) + self.head_dim = config.decoder_embed_dim // config.decoder_num_heads + self.embed_dim = config.decoder_embed_dim + self.is_causal = True + + self.num_rep = self.num_heads // self.num_kv_heads + + + self.wq = nn.Linear(self.embed_dim, int(self.num_heads * self.head_dim * 0.8), bias=False) + self.wk = nn.Linear(self.embed_dim, int(self.num_kv_heads * self.head_dim * 0.8), bias=False) + self.wv = nn.Linear(self.embed_dim, int(self.num_kv_heads * self.head_dim * 0.8), bias=False) + self.wo = nn.Linear(int(self.num_heads * self.head_dim * 0.8), self.embed_dim, bias=False) + + self.sdpa = build_attention(True, config.attn_mask_type) + + def init_weights(self, 
init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward(self, x: torch.Tensor, freqs_cis: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor, optional): Precomputed frequency tensor. + position_ids (torch.Tensor, optional): Custom position ids tensor of shape [batch, seq_len]. + + Returns: + torch.Tensor: Output tensor after attention. + """ + + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `num_heads` (or `num_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. + xq = xq.view(bs, seqlen, -1, 128) + xk = xk.view(bs, seqlen, -1, 128) + xv = xv.view(bs, seqlen, -1, 128) + + if freqs_cis is not None: + # Apply RoPE with position_ids if provided + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, position_ids=position_ids) + + + # repeat k/v heads if num_kv_heads < num_heads + keys = repeat_kv(xk, self.num_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.num_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + #output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv) + #output = self.sdpa(xq, xk, xv) + output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, is_causal=True) + + output = output.transpose(1, 2).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + """ + FeedForward module for the 
decoder. It's different from the one in the encoder. + This is the component which is originally used in Mistral3/Llama3. + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + hidden_dim = 32768 + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + +class TransformerBlock(nn.Module): + def __init__( + self, + config: VLMArgs, + ): + super().__init__() + self.attn = Attention(config, is_vision=False) + #self.ln_attn = build_norm("rmsnorm", config.decoder_embed_dim, config.norm_eps) + self.ln_attn = nn.RMSNorm(config.decoder_embed_dim, config.norm_eps, device=torch.cuda.current_device()) + self.mlp = FeedForward( + dim=config.decoder_embed_dim, + hidden_dim=4 * config.decoder_embed_dim, + multiple_of=config.multiple_of, + ffn_dim_multiplier=config.ffn_dim_multiplier, + ) + #self.ln_mlp = build_norm("rmsnorm", config.decoder_embed_dim, config.norm_eps) + self.ln_mlp = nn.RMSNorm(config.decoder_embed_dim, config.norm_eps, device=torch.cuda.current_device()) + + self.image_token_id = config.image_token_id + + def init_weights(self): + """ + Initialize weights following the Llama3 pattern. 
+ """ + # Initialize attention and feedforward components + self.attn.init_weights(0.02) # Use standard init_std for attention + self.mlp.init_weights(0.02) # Use standard init_std for feedforward + + # Initialize norm layers + for norm in (self.ln_attn, self.ln_mlp): + norm.reset_parameters() + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + **kwargs: Dict, + ): + # Handle custom position_ids if provided + if position_ids is not None: + # Custom handling for position_ids + # We need to index into freqs_cis with the position_ids + # First, we'll do a custom reshape_for_broadcast implementation that uses position_ids + x_norm = self.ln_attn(x) + # Get the appropriate freqs_cis based on position_ids + x = x + self.attn(x_norm, freqs_cis, position_ids=position_ids) + else: + # Standard forwarding without custom position_ids + x = x + self.attn(self.ln_attn(x), freqs_cis) + + x = x + self.mlp(self.ln_mlp(x)) + return x + +class Transformer(nn.Module): + """Decoder multimodal model for Mistral3. + + Args: + config (VLMArgs): configs for the model. + """ + + def __init__(self, config: VLMArgs): + super().__init__() + + self.register_buffer( + "freqs_cis", self._precompute_freqs_cis(config), persistent=True + ) + + self.layers = nn.ModuleDict() + for idx in range(config.decoder_num_layers): + # define a llama3-like decoder layer + decoder_layer = TransformerBlock(config) + self.layers[str(idx)] = decoder_layer + + self.tok_embeddings = nn.Embedding(131072, config.decoder_embed_dim) + self.norm = nn.RMSNorm(config.decoder_embed_dim, eps=config.norm_eps, device=torch.cuda.current_device()) + self.output = nn.Linear( + config.decoder_embed_dim, 131072, bias=False + ) + + self.image_token_id = config.image_token_id + + def init_weights(self): + """ + Initialize weights following the Llama3 pattern. 
+        """
+        # Initialize token embeddings
+        if self.tok_embeddings is not None:
+            nn.init.normal_(self.tok_embeddings.weight)
+
+        # Initialize all layers
+        for layer in self.layers.values():
+            if layer is not None:
+                layer.init_weights()
+
+        # Initialize norm layer
+        if self.norm is not None:
+            self.norm.reset_parameters()
+
+        # Initialize output layer with truncated normal, std scaled by the
+        # inverse square root of the input width, clipped at 3 sigma.
+        if self.output is not None:
+            final_out_std = self.output.in_features**-0.5
+            cutoff_factor = 3
+            nn.init.trunc_normal_(
+                self.output.weight,
+                mean=0.0,
+                std=final_out_std,
+                a=-cutoff_factor * final_out_std,
+                b=cutoff_factor * final_out_std,
+            )
+
+    def _precompute_freqs_cis(self, config) -> torch.Tensor:
+        return precompute_freqs_cis(
+            # Effective rotary dim: per-head width scaled by the same 0.8
+            # factor used to size the attention projections.
+            int(config.decoder_embed_dim // config.decoder_num_heads * 0.8),
+            # Need to compute until at least the max token limit for generation.
+            # NOTE(review): despite the original "use 2x max sequence length to
+            # be safe" comment, this computes exactly config.max_seq_len
+            # positions — there is no safety margin.
+            config.max_seq_len,
+            config.rope_theta,
+        )
+
+    def get_placeholder_mask(
+        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+    ) -> torch.BoolTensor:
+        """
+        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+        equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + def forward( + self, + tokens: torch.Tensor, + *, + encoder_input: Optional[torch.Tensor] = None, + encoder_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + image_features: Optional[list] = None, + ) -> torch.Tensor: + + # input tensor of shape [b, s] + bsz, seq_len = tokens.shape + + # shape: [b, s, d] + if inputs_embeds is None: + h = self.tok_embeddings(tokens) + else: + h = inputs_embeds + + if image_features is not None: + + if isinstance(h, DTensor): + h_full = h.redistribute(h.device_mesh, [Replicate()]).to_local() + new_h_full = h_full.clone() # Create a copy to modify + + for i, i_image_features in enumerate(image_features): + if i_image_features is not None: + #image_feat = i_image_features + image_feat = i_image_features.unsqueeze(0) + special_image_mask = self.get_placeholder_mask( + tokens[i].unsqueeze(0), h_full[i].unsqueeze(0), image_feat + ) + # Use torch.where instead of masked_scatter + new_h_full[i] = h_full[i].masked_scatter(special_image_mask, image_feat) + + + # Convert back to DTensor + new_h_full = new_h_full.to(h.device) + #h_replicated = distribute_tensor(new_h_full, h.device_mesh, [Replicate()]) + h_replicated = 
DTensor.from_local(new_h_full, h.device_mesh, placements=[Replicate()]) + + h = h_replicated.redistribute(h.device_mesh, [Shard(1)]) + else: + for i, i_image_features in enumerate(image_features): + if i_image_features is not None: + + #image_features = i_image_features + image_features = i_image_features.unsqueeze(0) + special_image_mask = self.get_placeholder_mask( + input_ids=tokens[i].unsqueeze(0), inputs_embeds=h[i].unsqueeze(0), image_features=image_features + ) + h[i] = h[i].masked_scatter(special_image_mask, image_features) + + if image_features is None: + print("image features is None") + + + # Setup freqs_cis based on position_ids or sequence length + if position_ids is not None: + # Use custom position_ids to index into freqs_cis + # We still need freqs_cis with the right device/dtype + freqs_cis = self.freqs_cis + else: + # Default: use standard positions based on sequence length + freqs_cis = self.freqs_cis + + for layer in self.layers.values(): + # shape: [b, s, d] + h = layer( + h, + freqs_cis=freqs_cis, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + position_ids=position_ids, + ) + + # shape: [b, s, d] + h = self.norm(h) + output = self.output(h) + + return output + + +class VLM(nn.Module, ModelProtocol): + """ + Mistral3 model which consists of a vision backbone and a language model. + + Args: + config (VLMArgs): Configuration for the model. 
+ """ + + def __init__(self, config: VLMArgs): + super().__init__() + self.config = config + + # Language model decoder + self.language_model = Transformer(config) + + # Special token for representing images in the text + self.image_token_index = config.image_token_index + + self.vision_model_initialized = False + + from .modeling_pixtral import PixtralVisionModel, PixtralVisionConfig + + # Create a PixtralVisionConfig based on the ModelArgs + pixtral_config = PixtralVisionConfig( + hidden_size=config.vision_embed_dim, + intermediate_size=4 * config.vision_embed_dim, # Standard multiplier + num_hidden_layers=config.vision_num_layers, + num_attention_heads=config.vision_num_heads, + num_channels=config.in_channels, + image_size=config.image_size, + patch_size=config.patch_size, + hidden_act="silu", # Standard activation + attention_dropout=0.0, # No dropout by default + rope_theta=config.rope_theta, + initializer_range=0.02 # Standard initialization + ) + + self.vision_tower = PixtralVisionModel(pixtral_config) + + # Add projection to connect to the decoder + self.multi_modal_projector = Mistral3MultiModalProjector(config) + + from transformers import AutoProcessor + self.preprocessor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503", use_fast=True) + + self.initialized_vision=False + + def init_weights( + self, + buffer_device: Optional[torch.device] = None, + ): + + buffer_device = buffer_device or self.language_model.freqs_cis.device + with torch.device(buffer_device): + self.language_model._precompute_freqs_cis(self.config) + + + """ + + # Initialize language model components + if hasattr(self.language_model, 'init_weights'): + self.language_model.init_weights() + + ## Initialize vision tower if it exists + if hasattr(self, 'vision_tower') and self.vision_tower is not None: + if hasattr(self.vision_tower, 'init_weights'): + self.vision_tower.init_weights() + + ## Initialize multimodal projector if it exists + if hasattr(self, 
'multi_modal_projector') and self.multi_modal_projector is not None: + if hasattr(self.multi_modal_projector, 'init_weights'): + self.multi_modal_projector.init_weights() + """ + + if not self.initialized_vision: + from transformers import AutoModelForImageTextToText + hf_model = AutoModelForImageTextToText.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503", device_map='cpu', torch_dtype=torch.bfloat16) + + vision_tower_device = self.vision_tower.device + + print(f"before: {self.vision_tower.dtype}") + + hf_model.vision_tower = self.vision_tower.to(vision_tower_device, dtype=torch.bfloat16) + hf_model.multi_modal_projector = self.multi_modal_projector.to(vision_tower_device, dtype=torch.bfloat16) + + self.vision_tower = hf_model.vision_tower.to(vision_tower_device, dtype=torch.float32) + self.multi_modal_projector = hf_model.multi_modal_projector.to(vision_tower_device, dtype=torch.float32) + + + print("did the thing") + print(f"after: {self.vision_tower.dtype}") + + self.initialized_vision=True + + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Union[int, List[int]], + image_sizes: torch.Tensor, + **kwargs, + ): + kwargs = {k: v for k, v in kwargs.items() if v is not None} + with torch.no_grad(): + image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=False) + # If we have one vision feature layer, return the corresponding hidden states, + # otherwise, select the hidden states of each feature layer and concatenate them + if isinstance(vision_feature_layer, int): + selected_image_feature = image_outputs.last_hidden_state #[vision_feature_layer] + else: + hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer] + selected_image_feature = torch.cat(hs_pool, dim=-1) + + image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes) + return image_features + + def forward( + self, + input_ids: torch.LongTensor = 
None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + sequence_lengths: list[torch.Tensor] | None = None, + image_features=None, + images: Optional[list] = None + ): + + image_features = None + all_image_features = [] + + if images is None: + images= [] + + for i, batch in enumerate(images): + i_image_features = None + + if batch is not None: + with torch.no_grad(): + i_image_features = None + image_features_batch = [] + images = [load_image(im) if isinstance(im, str) else im for im in batch] + + image_inputs = self.preprocessor.image_processor(images, patch_size=self.config.patch_size * 2) + + #image_encoder_outputs = self.get_image_features(image_inputs["pixel_values"].to(self.vision_tower.device, dtype=torch.float16), 2, image_inputs["image_sizes"]) + image_encoder_outputs = self.get_image_features(image_inputs["pixel_values"].to(self.vision_tower.device, dtype=self.vision_tower.dtype), -1, image_inputs["image_sizes"]) + + # Collect image features from all images in the batch + i_image_features = image_encoder_outputs + + + #if i_image_features.shape[0] > 1: + # i_image_features = torch.cat(i_image_features, dim=0) # Shape: (1, sum_of_image_patches_of_all_images) + # i_image_features = i_image_features.unsqueeze(0) + + + all_image_features.append(i_image_features) + + #else: + # return NotImplementedError("Position IDs are required for multimodal input.") + + if self.config.use_flex_attn: + init_attention_mask(input_ids, eos_id=self.config.eos_id, sequence_lengths=sequence_lengths) + + if position_ids is not None: + if all_image_features: + logits = self.language_model( + tokens=input_ids, + encoder_mask=None, + position_ids=position_ids, + image_features=all_image_features, + ) + else: + logits = self.language_model( + tokens=input_ids, + encoder_mask=None, + position_ids=position_ids, + ) + else: + if all_image_features: + logits = self.language_model( + 
tokens=input_ids, + encoder_mask=None, + image_features=all_image_features, + ) + else: + logits = self.language_model( + tokens=input_ids, + encoder_mask=None, + ) + + return logits + + @classmethod + def from_model_args(cls, model_args: VLMArgs) -> "Transformer": + return cls(model_args) \ No newline at end of file diff --git a/torchtitan/models/mistral3/model/modeling_pixtral.py b/torchtitan/models/mistral3/model/modeling_pixtral.py new file mode 100644 index 0000000000..84c877ca5f --- /dev/null +++ b/torchtitan/models/mistral3/model/modeling_pixtral.py @@ -0,0 +1,532 @@ +# coding=utf-8 +# Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Pixtral model.""" + +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_pixtral import PixtralVisionConfig + + +logger = logging.get_logger(__name__) + + +def position_ids_in_meshgrid(patch_embeds_list, max_width): + positions = [] + for patch in patch_embeds_list: + height, width = patch.shape[-2:] + mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_width + v_grid + positions.append(ids[:, 0]) + return torch.cat(positions) + + +class PixtralRotaryEmbedding(nn.Module): + """ + The key with pixtral embedding is just that you have a frequency for each pixel positions. + If you have height x width pixels (or embedding pixels), then the frequency used for ROPE + is given by indexing the pre_computed frequency on the width and height. + + What you output is of dimension (batch, height * width, dim) with dim the embed dim. + + This simply means that for each image hidden state, you are going to add + a corresponding positional embedding, based on its index in the grid. 
+ """ + + def __init__(self, config, device=None): + super().__init__() + self.rope_type = "default" + self.dim = config.head_dim + self.base = config.rope_theta + max_patches_per_side = config.image_size // config.patch_size + freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + + h = torch.arange(max_patches_per_side, device=freqs.device) + w = torch.arange(max_patches_per_side, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape(-1, self.dim // 2) # we reshape to only index on the position indexes, not tuple of indexes + # Different from paper, but it uses a different permutation in order to obtain the same calculation + + # TODO maybe make it torch compatible later on. We can also just slice + self.register_buffer("inv_freq", torch.cat((inv_freq, inv_freq), dim=-1), persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + freqs = self.inv_freq[position_ids] + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + emb = freqs + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = 
torch.max(position_ids) + 1
+        # NOTE(review): this dynamic-RoPE path appears dead for this model —
+        # __init__ fixes self.rope_type to "default", so forward() never calls
+        # this method. If it were ever reached it would raise AttributeError:
+        # max_seq_len_cached, rope_init_fn, rope_kwargs, config,
+        # original_inv_freq and original_max_seq_len are not set in the
+        # visible __init__. Confirm before relying on it.
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len, **self.rope_kwargs
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
+            self.max_seq_len_cached = seq_len
+
+        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class PixtralAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, patches, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0) + + attn_weights = torch.matmul(query_states, 
key_states.transpose(2, 3)) * self.scale + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, patches, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Pixtral +class PixtralMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral +class PixtralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + PixtralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return 
f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class PixtralAttentionLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) + self.feed_forward = PixtralMLP(config) + self.attention = PixtralAttention(config) + self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: Optional[bool] = None, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + + hidden_states = self.attention_norm(hidden_states) + hidden_states, attn_weights = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + return outputs + + +class PixtralTransformer(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layers = torch.nn.ModuleList() + for _ in range(config.num_hidden_layers): + self.layers.append(PixtralAttentionLayer(config)) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embeddings which serve as input to the Transformer. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + position_embeddings, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +PIXTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`PixtralVisionConfig`]): + Model configuration class with all the parameters of the vision encoder. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +class PixtralPreTrainedModel(PreTrainedModel): + config_class = PixtralVisionConfig + base_model_prefix = "model" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["PixtralAttentionLayer"] + + def _init_weights(self, module): + std = ( + self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.initializer_range + ) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +PIXTRAL_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`AutoImageProcessor.__call__`] + for details. + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*): + The sizes of the images in the batch, being (height, width) for each image. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +def generate_block_attention_mask(patch_embeds_list, tensor): + dtype = tensor.dtype + device = tensor.device + seq_len = tensor.shape[1] + d_min = torch.finfo(dtype).min + causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device) + + block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1) + block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1) + for start, end in zip(block_start_idx, block_end_idx): + causal_mask[start:end, start:end] = 0 + + causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1) + return causal_mask + + +@add_start_docstrings( + "The bare Pixtral vision encoder outputting raw hidden-states without any specific head on top.", + PIXTRAL_START_DOCSTRING, +) +class PixtralVisionModel(PixtralPreTrainedModel): + base_model_prefix = "vision_encoder" + + def __init__(self, config): + super().__init__(config) + self.config = config + self.patch_conv = nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + bias=False, + ) + self.patch_size = config.patch_size + self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5) + self.transformer = PixtralTransformer(config) + self.patch_positional_embedding = PixtralRotaryEmbedding(config) + + self.post_init() + + def get_input_embeddings(self): + return self.patch_conv + + @add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + pixel_values: torch.Tensor, + image_sizes: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + *args, + **kwargs, + ) -> Union[Tuple, BaseModelOutput]: + """ + Returns: + pixel_values: tensor of token features for + all tokens of all images of shape (N_toks, D) + """ + # pass 
images through initial convolution independently + patch_embeds = self.patch_conv(pixel_values) + patch_embeds_list = [ + embed[..., : (size[0] // self.patch_size), : (size[1] // self.patch_size)] + for embed, size in zip(patch_embeds, image_sizes) + ] + + # flatten to a single sequence + patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0) + patch_embeds = self.ln_pre(patch_embeds) + + # positional embeddings + position_ids = position_ids_in_meshgrid( + patch_embeds_list, max_width=self.config.image_size // self.config.patch_size + ) + position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids) + + attention_mask = generate_block_attention_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds + ) + + out = self.transformer( + patch_embeds, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + return out + + +__all__ = ["PixtralVisionModel", "PixtralPreTrainedModel"] diff --git a/torchtitan/models/mistral3/train_configs/mistral24b_debug.toml b/torchtitan/models/mistral3/train_configs/mistral24b_debug.toml new file mode 100644 index 0000000000..4166d13259 --- /dev/null +++ b/torchtitan/models/mistral3/train_configs/mistral24b_debug.toml @@ -0,0 +1,66 @@ +# torchtitan Config.toml +# NOTE: this toml config is a preset for 64 A100 GPUs. 
+ +[job] +dump_folder = "./outputs" +description = "Mistral 24B training" + + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 +enable_tensorboard = false +enable_wandb = true +save_tb_folder = "tb" + +[model] +name = "mistral3" +flavor = "debug_samplepacked" +tokenizer_path = "/home/shared/torchtitan-conversions/mistral_small_3.1_instruct/tokenizer" + +[optimizer] +name = "AdamW" +lr = 2e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 0 # lr scheduler warm up + +[training] +local_batch_size = 1 +seq_len = 8000 +max_norm = 1.0 # grad norm clipping +steps = 1000 +compile = false +dataset = "hermes4" +dataset_path = "./multimodal_dataset" +dataset_type = "preprocessed_multimodal" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 1 +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = true +#initial_load_path = "/home/shared/torchtitan-conversions/mistral_small_3.1_instruct" +folder = "checkpoint" +interval = 500 +last_save_model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "none" # ["none", "selective", "full"] +selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] diff --git a/torchtitan/models/mistral3/train_configs/mistral24b_finetuning.toml b/torchtitan/models/mistral3/train_configs/mistral24b_finetuning.toml new file mode 100644 index 0000000000..cba02220f8 --- /dev/null +++ b/torchtitan/models/mistral3/train_configs/mistral24b_finetuning.toml @@ -0,0 +1,66 @@ +# torchtitan Config.toml +# NOTE: this toml config is a preset for 64 A100 GPUs. 
+ +[job] +dump_folder = "./outputs" +description = "Mistral 24B training" + + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 +enable_tensorboard = false +enable_wandb = true +save_tb_folder = "tb" + +[model] +name = "mistral3" +flavor = "24B_samplepacked" +tokenizer_path = "/home/shared/torchtitan-conversions/mistral_small_3.1_instruct/tokenizer" + +[optimizer] +name = "AdamW" +lr = 2e-6 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 0 # lr scheduler warm up + +[training] +local_batch_size = 4 +seq_len = 16384 +max_norm = 1.0 # grad norm clipping +steps = 500 +compile = false +dataset = "hermes4" +dataset_path = "./multimodal_dataset" +dataset_type = "preprocessed_multimodal" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = 1 +tensor_parallel_degree = 8 +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = true +initial_load_path = "/home/shared/torchtitan-conversions/mistral_small_3.1_instruct" +folder = "checkpoint" +interval = 100 +last_save_model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "full" # ["none", "selective", "full"] +selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] diff --git a/torchtitan/models/mistral3/train_configs/mistral24b_finetuning_textonly.toml b/torchtitan/models/mistral3/train_configs/mistral24b_finetuning_textonly.toml new file mode 100644 index 0000000000..a86a9dd04d --- /dev/null +++ b/torchtitan/models/mistral3/train_configs/mistral24b_finetuning_textonly.toml @@ -0,0 +1,66 @@ +# torchtitan Config.toml +# NOTE: this toml config is a preset for 64 A100 GPUs. 
+ +[job] +dump_folder = "./outputs" +description = "Mistral 24B training" + + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 +enable_tensorboard = false +enable_wandb = true +save_tb_folder = "tb" + +[model] +name = "mistral3" +flavor = "24B_samplepacked" +tokenizer_path = "/home/shared/torchtitan-conversions/mistral_small_3.1_instruct/tokenizer" + +[optimizer] +name = "AdamW" +lr = 2e-6 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 0 # lr scheduler warm up + +[training] +local_batch_size = 1 +seq_len = 8000 +max_norm = 1.0 # grad norm clipping +steps = 100 +compile = false +dataset = "hermes4text" +dataset_path = "./dataset" +dataset_type = "preprocessed" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 1 +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = true +initial_load_path = "/home/shared/torchtitan-conversions/mistral_small_3.1_instruct" +folder = "checkpoint" +interval = 500 +last_save_model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "none" # ["none", "selective", "full"] +selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] diff --git a/torchtitan/models/mistral3/train_configs/mistral24b_inference.toml b/torchtitan/models/mistral3/train_configs/mistral24b_inference.toml new file mode 100644 index 0000000000..514c79036d --- /dev/null +++ b/torchtitan/models/mistral3/train_configs/mistral24b_inference.toml @@ -0,0 +1,66 @@ +# torchtitan Config.toml +# NOTE: this toml config is a preset for 64 A100 GPUs. 
+ +[job] +dump_folder = "./outputs" +description = "Mistral 24B training" + + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 10 +enable_tensorboard = false +enable_wandb = true +save_tb_folder = "tb" + +[model] +name = "mistral3" +flavor = "24B" +tokenizer_path = "/home/shared/torchtitan-conversions/mistral_small_3.1_instruct/tokenizer" + +[optimizer] +name = "AdamW" +lr = 2e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 0 # lr scheduler warm up + +[training] +local_batch_size = 1 +seq_len = 16384 +max_norm = 1.0 # grad norm clipping +steps = 100 +compile = false +dataset = "hermes4" +dataset_path = "./multimodal_dataset" +dataset_type = "preprocessed_multimodal" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 1 +pipeline_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = true +initial_load_path = "/home/shared/torchtitan-conversions/mistral_small_3.1_instruct" +folder = "checkpoint" +interval = 500 +last_save_model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "none" # ["none", "selective", "full"] +selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] diff --git a/torchtitan/models/mistral3/train_configs/mistral24b_sft.toml b/torchtitan/models/mistral3/train_configs/mistral24b_sft.toml new file mode 100644 index 0000000000..6942c0396f --- /dev/null +++ b/torchtitan/models/mistral3/train_configs/mistral24b_sft.toml @@ -0,0 +1,74 @@ +# torchtitan Config.toml +# NOTE: this toml config is a preset for 8 H100 GPUs. 
+ +[job] +dump_folder = "./mistral-small-24b-sft" +description = "Mistral 24B training" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = true +wandb_project = "artem_deephermes_24b_simpletrainer" +seed = 42 + +[model] +name = "mistral3" +flavor = "24B" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm +tokenizer_path = "/home/artem_nous/tt/torchtitan/assets/tokenizer/Mistral-Small-3.1-24B-Base-2503" + +[optimizer] +name = "AdamW" +lr = 1e-3 + +[training] +#base_model_dir = "/home/ubuntu/torchtitan/llama_dcp" +#base_model_dir = "/home/artem_nous/tt/torchtitan/mistral-small-base-dcp" + +batch_size = 1 +seq_len = 1000 +gradient_accumulation_steps = 1 +warmup_steps = 200 # lr scheduler warm up +lr_ramp_type = "cosine" +lr_ramp_end_lr_ratio = 0.0 +max_norm = 1.0 # grad norm clipping +steps = 10000 +compile = "false" +local_batch_size = 1 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[experimental] +context_parallel_degree = 1 +pipeline_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = true +initial_load_path = "/home/artem_nous/tt/torchtitan/mistral-small-base-dcp" +last_save_model_weights_only = false +folder = "checkpoint" +interval_type = "steps" +interval = 72000 +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'full' # ['none', 'selective', 'full'] +#selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = true +precompute_float8_dynamic_scale_for_fsdp = true + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 1 +pipeline_parallel_degree = 1 +context_parallel_degree = 1 diff --git a/torchtitan/models/mistral3/train_configs/mistral24b_sft_vision 
b/torchtitan/models/mistral3/train_configs/mistral24b_sft_vision new file mode 100644 index 0000000000..d2f7cc9850 --- /dev/null +++ b/torchtitan/models/mistral3/train_configs/mistral24b_sft_vision @@ -0,0 +1,71 @@ +# torchtitan Config.toml +# NOTE: this toml config is a preset for 8 H100 GPUs. + +[job] +dump_folder = "./vision-test" +description = "Mistral 24B training" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = true +wandb_project = "artem_deephermes_24b_simpletrainer" +seed = 42 + +[model] +name = "mistral3" +flavor = "24B" +norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm +#tokenizer_path = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" +tokenizer_path = "/mnt/weka/home/artem/new_pr_2/simple-trainer/mistral-2503-instruct-hf" +#tokenizer_path = "/mnt/weka/home/artem/new2/simple-trainer/mistral-2503-base-hf" +# converters = "float8" + +[optimizer] +name = "AdamW" +lr = 1e-6 + +[training] +base_model_dir = "/mnt/weka/home/artem/new2/simple-trainer/mistral-dcp-correct" +train_type = "sft_vision" +seed=42 +batch_size = 1 +seq_len = 12000 +gradient_accumulation_steps = 1 +warmup_steps = 200 # lr scheduler warm up +lr_ramp_type = "cosine" +lr_ramp_end_lr_ratio = 0.0 +max_norm = 1.0 # grad norm clipping +steps = 700 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +tensor_parallel_degree = 8 +compile = false +dataset = "vision" + +[experimental] +context_parallel_degree = 1 +pipeline_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = true +folder = "checkpoint" +interval_type = "steps" +interval = 72000 +model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = 'full' # ['none', 'selective', 'full'] +#selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based 
on ops policy + +[float8] +enable_fsdp_float8_all_gather = true +precompute_float8_dynamic_scale_for_fsdp = true diff --git a/torchtitan/models/qwen3/train_configs/qwen3_8b_finetuning.toml b/torchtitan/models/qwen3/train_configs/qwen3_8b_finetuning.toml index 3a23b48bcd..e962c15bf1 100644 --- a/torchtitan/models/qwen3/train_configs/qwen3_8b_finetuning.toml +++ b/torchtitan/models/qwen3/train_configs/qwen3_8b_finetuning.toml @@ -32,11 +32,14 @@ warmup_steps = 200 # lr scheduler warm up [training] local_batch_size = 1 -seq_len = 8192 +seq_len = 8000 max_norm = 1.0 # grad norm clipping steps = 1000 compile = false -dataset = "c4" +dataset = "hermes4" +dataset_path = "./dataset" +dataset_type = "preprocessed" + [parallelism] data_parallel_replicate_degree = 1