diff --git a/README.md b/README.md index 2b751070f0..99de4d93c7 100644 --- a/README.md +++ b/README.md @@ -103,6 +103,7 @@ MaxText aims to provide you with the best OSS models, whether as a reference imp * Gemma 2 (2B, 9B, 27B) * Gemma 1 (2B, 7B) * Alibaba + * Qwen 2.5 (7B, 14B) * Qwen 3 MoE 2507 (235B, 480B) * Qwen 3 MoE (30B, 235B) * Qwen 3 Dense (0.6B, 1.7B, 4B, 8B, 14B, 32B) diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md index 31196685dc..d171e1545f 100644 --- a/docs/guides/checkpointing_solutions/convert_checkpoint.md +++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md @@ -11,6 +11,7 @@ The following models are supported: | **Gemma2** | 2B, 9B, 27B | √ | √ | √ | √ | | **Gemma3** (Multimodal) | 4B, 12B, 27B | - | √ | - | √ | | **Llama3.1** | 8B, 70B, 450B | √ | √ | √ | √ | +| **Qwen2.5** | 7B, 14B | √ | √ | √ | √ | | **Qwen3** | 0.6B, 4B, 8B, 14B, 32B | √ | √ | √ | √ | | **Qwen3 MoE** | 30B, 235B, 480B | √ | √ | √ | √ | | **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ | diff --git a/src/MaxText/common_types.py b/src/MaxText/common_types.py index f36b991cef..80f1d907da 100644 --- a/src/MaxText/common_types.py +++ b/src/MaxText/common_types.py @@ -92,6 +92,7 @@ class DecoderBlockType(enum.Enum): GEMMA = "gemma" GEMMA2 = "gemma2" GEMMA3 = "gemma3" + QWEN2 = "qwen2" QWEN3 = "qwen3" QWEN3_MOE = "qwen3_moe" QWEN3_NEXT = "qwen3_next" diff --git a/src/MaxText/integration/tunix/weight_mapping/__init__.py b/src/MaxText/integration/tunix/weight_mapping/__init__.py index 7f7a0dc534..2c218acc56 100644 --- a/src/MaxText/integration/tunix/weight_mapping/__init__.py +++ b/src/MaxText/integration/tunix/weight_mapping/__init__.py @@ -21,6 +21,7 @@ from MaxText.integration.tunix.weight_mapping.deepseek3 import DEEPSEEK_VLLM_MAPPING from MaxText.integration.tunix.weight_mapping.gpt_oss import GPT_OSS_VLLM_MAPPING from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING +from MaxText.integration.tunix.weight_mapping.qwen2 import QWEN2_VLLM_MAPPING from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING @@ -30,6 +31,8 @@ class StandaloneVllmWeightMapping: def __getattr__(self, name): if name.startswith("llama3.1"): return LLAMA3_VLLM_MAPPING + elif name.startswith("qwen2"): + return QWEN2_VLLM_MAPPING elif name.startswith("qwen3"): return QWEN3_VLLM_MAPPING elif name.startswith("deepseek3"): diff --git a/src/MaxText/integration/tunix/weight_mapping/qwen2.py b/src/MaxText/integration/tunix/weight_mapping/qwen2.py new file mode 100644 index 0000000000..b129f23c6a --- /dev/null +++ b/src/MaxText/integration/tunix/weight_mapping/qwen2.py @@ -0,0 +1,136 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines the weight mapping from MaxText's Qwen2 model to a vLLM-compatible format. 
+
+This module provides the `QWEN2_VLLM_MAPPING` dataclass, which contains all the
+necessary configurations to convert MaxText's Qwen2 model weights into a
+format that can be loaded by vLLM. This includes:
+- A direct mapping of parameter names.
+- Sharding specifications for distributed environments.
+"""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class QWEN2_VLLM_MAPPING:
+  """Mapping MaxText Qwen2 weights to vLLM's Qwen2 weights."""
+
+  @staticmethod
+  def to_hf_hook_fns():
+    """Returns a dictionary of hook functions to be applied to MaxText weights.
+
+    Returns:
+      An empty dictionary, as no hook functions are needed for this mapping.
+    """
+    return {}
+
+  @staticmethod
+  def to_hf_transpose_keys():
+    """Returns the keys for weights that need to be transposed.
+
+    Returns:
+      An empty dictionary, as no keys require transposition for this mapping.
+    """
+    return {}
+
+  @staticmethod
+  def lora_to_hf_mappings():
+    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.
+
+    Returns:
+      None, as LoRA mappings are not defined for this model.
+    """
+    return None
+
+  @staticmethod
+  def to_hf_mapping():
+    """Mapping from MaxText model to HuggingFace vLLM model.
+
+    Currently, the param mapping conforms to the Tunix API, which combines the
+    param name & sharding in one dictionary.
+    This is subject to change in the future where we can decouple the two.
+    """
+    return {
+        # Token embeddings - shard vocab dimension
+        "base.token_embedder.embedding": (
+            "model.embed.embedding",
+            ("model", None),
+        ),
+        # Final layer norm - no sharding needed
+        "base.decoder.decoder_norm.scale": (
+            "model.norm.scale",
+            (None,),
+        ),
+        # LM head (logits projection) - shard vocab dimension
+        "base.decoder.logits_dense.kernel": (
+            "model.lm_head",
+            (None, "model"),
+        ),
+        # Layer-specific mappings (scanned -> unscanned)
+        # MLP components - shard hidden dimensions
+        "base.decoder.layers.mlp.wi_0.kernel": (
+            "model.layers.*.mlp.gate_proj.kernel",
+            (None, "layer", "model"),
+        ),
+        "base.decoder.layers.mlp.wi_1.kernel": (
+            "model.layers.*.mlp.up_proj.kernel",
+            (None, "layer", "model"),
+        ),
+        "base.decoder.layers.mlp.wo.kernel": (
+            "model.layers.*.mlp.down_proj.kernel",
+            ("model", "layer", None),
+        ),
+        # Layer norms - no sharding needed
+        "base.decoder.layers.pre_self_attention_layer_norm.scale": (
+            "model.layers.*.input_layernorm.scale",
+            (None, "layer"),
+        ),
+        "base.decoder.layers.post_self_attention_layer_norm.scale": (
+            "model.layers.*.post_attention_layernorm.scale",
+            (None, "layer"),
+        ),
+        # Attention components - shard head dimensions
+        "base.decoder.layers.self_attention.query.kernel": (
+            "model.layers.*.self_attn.q_proj.kernel",
+            (None, "layer", "model", None),
+        ),
+        "base.decoder.layers.self_attention.key.kernel": (
+            "model.layers.*.self_attn.k_proj.kernel",
+            (None, "layer", "model", None),
+        ),
+        "base.decoder.layers.self_attention.value.kernel": (
+            "model.layers.*.self_attn.v_proj.kernel",
+            (None, "layer", "model", None),
+        ),
+        "base.decoder.layers.self_attention.out.kernel": (
+            "model.layers.*.self_attn.o_proj.kernel",
+            ("model", "layer", None, None),
+        ),
+        # Attention biases
+        "base.decoder.layers.self_attention.query.bias": (
+            "model.layers.*.self_attn.q_proj.bias",
+            (None, "layer", "model", None),
+        ),
+        "base.decoder.layers.self_attention.key.bias": (
+            "model.layers.*.self_attn.k_proj.bias",
+            (None, "layer", "model", None),
+        ),
+        "base.decoder.layers.self_attention.value.bias": (
+            "model.layers.*.self_attn.v_proj.bias",
(None, "layer", "model", None), + ), + } diff --git a/src/MaxText/layers/attentions.py b/src/MaxText/layers/attentions.py index 42e646fa5f..c400b6fcdd 100644 --- a/src/MaxText/layers/attentions.py +++ b/src/MaxText/layers/attentions.py @@ -432,6 +432,7 @@ def __init__( # Use the rope type specified in the arguments if provided, otherwise fall back to the one in the config. self.rope_type = (rope_type or self.config.rope_type).lower() + self.is_qwen2 = self.config.decoder_block == DecoderBlockType.QWEN2 self.is_qwen3_next = self.config.decoder_block == DecoderBlockType.QWEN3_NEXT # Module attribute names must match names previously passed to Linen for checkpointing @@ -715,7 +716,7 @@ def init_out_w(self, output_dim: int) -> nnx.Module: quant=self.quant, shard_mode=self.config.shard_mode, matmul_precision=self.config.matmul_precision, - use_bias=self.use_bias_in_projections, + use_bias=False if self.is_qwen2 else self.use_bias_in_projections, rngs=self.rngs, ) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index d82bc065ca..af4cd0f842 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -51,6 +51,7 @@ llama4, mistral, mixtral, + qwen2, qwen3, simple_layer, olmo3, @@ -420,6 +421,8 @@ def get_decoder_layers(self): return [gpt3.Gpt3DecoderLayerToLinen] case DecoderBlockType.GPT_OSS: return [gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen] + case DecoderBlockType.QWEN2: + return [qwen2.Qwen2DecoderLayerToLinen] case DecoderBlockType.QWEN3: return [qwen3.Qwen3DecoderLayerToLinen] case DecoderBlockType.QWEN3_MOE: @@ -478,6 +481,7 @@ def get_norm_layer(self, num_features: int): DecoderBlockType.GEMMA, DecoderBlockType.GEMMA2, DecoderBlockType.GEMMA3, + DecoderBlockType.QWEN2, DecoderBlockType.QWEN3, DecoderBlockType.QWEN3_MOE, DecoderBlockType.GPT_OSS, diff --git a/src/MaxText/layers/qwen2.py b/src/MaxText/layers/qwen2.py new file mode 100644 index 0000000000..b66ef448fb --- /dev/null +++ b/src/MaxText/layers/qwen2.py @@ -0,0 +1,219 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Qwen2 family of model decoder layers.""" +# pylint: disable=arguments-differ +# pylint: disable=no-name-in-module + +from typing import Any + +from jax.ad_checkpoint import checkpoint_name +from jax.sharding import Mesh +import jax.numpy as jnp + +from flax import linen as nn +from flax import nnx + +from MaxText.common_types import Config +from MaxText.layers import initializers as max_initializers +from MaxText.layers import nnx_wrappers +from MaxText.layers import quantizations +from MaxText.layers.normalizations import RMSNorm +from MaxText.layers.quantizations import AqtQuantization as Quant +from MaxText.layers.attentions import Attention +from MaxText.layers.linears import MlpBlock +from maxtext.inference import page_manager +from maxtext.utils import max_utils + + +# ----------------------------------------- +# The Base Decoder Layer for Qwen2 +# ----------------------------------------- +class AttentionWithNorm(nnx.Module): + """Base class with shared common components: self-attention block with normalization.""" + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + quant: None | Quant, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.quant = quant + + batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) + dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) + self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") + + # Corresponds to Qwen2's `input_layernorm` + self.pre_self_attention_layer_norm = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + epsilon=config.normalization_layer_epsilon, + rngs=rngs, + ) + + # Self-attention block + query_pre_attn_scalar = config.head_dim**-0.5 # Qwen2 specific scaling + self.self_attention = Attention( + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + attention_kernel=config.attention, + inputs_q_shape=dummy_inputs_shape, + inputs_kv_shape=dummy_inputs_shape, + mesh=mesh, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + dropout_rate=config.dropout_rate, + float32_qk_product=config.float32_qk_product, + float32_logits=config.float32_logits, + quant=quant, + kv_quant=quantizations.configure_kv_quant(config), + use_ragged_attention=config.use_ragged_attention, + ragged_block_size=config.ragged_block_size, + use_qk_norm=config.use_qk_norm, + use_bias_in_projections=config.attention_bias, + query_pre_attn_scalar=query_pre_attn_scalar, + model_mode=model_mode, + use_mrope=config.use_mrope, + mrope_section=config.mrope_section, + rngs=rngs, + ) + + # Post Attention LayerNorm (corresponds to Qwen2's `post_attention_layernorm`) + self.post_self_attention_layer_norm = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + epsilon=config.normalization_layer_epsilon, + rngs=rngs, + ) + + def apply_attention_with_norm( + self, + inputs: jnp.ndarray, + decoder_segment_ids: None | jnp.ndarray, + decoder_positions: None | jnp.ndarray, + deterministic: bool, + model_mode: str, + kv_cache: None | jnp.ndarray = None, + attention_metadata: None | dict[str, Any] = None, + ): + """Applies self-attention with pre and post-layer normalization.""" + inputs = nn.with_logical_constraint(inputs, 
self.activation_axis_names) + inputs = checkpoint_name(inputs, "decoder_layer_input") + # Pre attention norm + lnx = self.pre_self_attention_layer_norm(inputs) + lnx = nn.with_logical_constraint(lnx, self.activation_axis_names) + # Self attention + attention_lnx, kv_cache = self.self_attention( + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names) + # Residual connection after attention + intermediate_inputs = inputs + attention_lnx + # Post attention norm + hidden_states = self.post_self_attention_layer_norm(intermediate_inputs) + hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names) + return hidden_states, intermediate_inputs, kv_cache + + +# ----------------------------------------- +# The Dense Decoder Layer for Qwen2 +# ----------------------------------------- +class Qwen2DecoderLayer(AttentionWithNorm): + """Qwen2 Transformer decoder layer (dense).""" + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + quant: None | Quant, + rngs: nnx.Rngs, + ): + super().__init__(config, mesh, model_mode, quant, rngs) + self.mlp = MlpBlock( + in_features=config.emb_dim, + intermediate_dim=config.mlp_dim, + activations=config.mlp_activations, + intermediate_dropout_rate=config.dropout_rate, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + config=config, + mesh=mesh, + quant=quant, + model_mode=model_mode, + rngs=rngs, + ) + + def __call__( + self, + inputs: jnp.ndarray, + decoder_segment_ids: None | jnp.ndarray, + decoder_positions: None | jnp.ndarray, + deterministic: bool, + model_mode: str, + previous_chunk=None, + page_state: None | page_manager.PageState = None, + slot: None | int = None, + kv_cache: None | jnp.ndarray = None, + attention_metadata: None | dict[str, Any] = None, + ): + # Unpack inputs if it's a tuple (e.g. 
from a previous layer returning (hidden_states, kv_cache)) + if isinstance(inputs, tuple): + inputs = inputs[0] + hidden_states, intermediate_inputs, kv_cache = self.apply_attention_with_norm( + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + + mlp_lnx = self.mlp(hidden_states, deterministic=deterministic) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names) + + layer_output = intermediate_inputs + mlp_lnx + layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names) + + if self.config.scan_layers: + return layer_output, None + else: + return layer_output, kv_cache + + +Qwen2DecoderLayerToLinen = nnx_wrappers.to_linen_class( + Qwen2DecoderLayer, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py index c60c8f8cc9..579cae78bc 100644 --- a/src/MaxText/pyconfig.py +++ b/src/MaxText/pyconfig.py @@ -91,9 +91,9 @@ def _load_config(config_name: str) -> omegaconf.DictConfig: base_path = cfg[_BASE_CONFIG_ATTR] if not os.path.isabs(base_path): # Search relative to current config, then in the default configs folder - loaded_parent_config_filename = os.path.join(os.path.dirname(config_name), base_path) + loaded_parent_config_filename = resolve_config_path(os.path.join(os.path.dirname(config_name), base_path)) if not os.path.isfile(loaded_parent_config_filename): - loaded_parent_config_filename = os.path.join(MAXTEXT_CONFIGS_DIR, base_path) + loaded_parent_config_filename = resolve_config_path(os.path.join(MAXTEXT_CONFIGS_DIR, base_path)) else: loaded_parent_config_filename = base_path @@ -235,8 +235,7 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig: model_config_path = os.path.join(os.path.dirname(config_path), "models", f"{model_name}.yml") if not os.path.isfile(model_config_path): # Fallback to default location within package - dir_path = os.path.dirname(os.path.realpath(__file__)) - model_config_path = os.path.join(dir_path, "configs", "models", f"{model_name}.yml") + model_config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "models", f"{model_name}.yml") if os.path.exists(model_config_path): model_loaded_cfg = omegaconf.OmegaConf.load(model_config_path) diff --git a/src/MaxText/pyconfig_deprecated.py b/src/MaxText/pyconfig_deprecated.py index 582d7a122f..34b53d1579 100644 --- a/src/MaxText/pyconfig_deprecated.py +++ b/src/MaxText/pyconfig_deprecated.py @@ -460,6 +460,8 @@ def validate_model_name(s: str) -> bool: "gemma3-4b", "gemma3-12b", "gemma3-27b", + "qwen2.5-7b", + "qwen2.5-14b", "qwen3-0.6b", "qwen3-4b", "qwen3-4b-thinking-2507", diff --git a/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py b/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py index d91b7987ca..ddec2e336e 100644 --- a/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py +++ b/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py @@ -210,6 +210,41 @@ query_pre_attn_scalar=144, ) +qwen25_7b_config = transformers.Qwen2Config( + vocab_size=152064, + hidden_size=3584, + intermediate_size=18944, + num_hidden_layers=28, + num_attention_heads=28, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-06, + use_cache=True, + rope_theta=1000000.0, + tie_word_embeddings=False, + torch_dtype="bfloat16", + attention_bias=True, +) + +qwen25_14b_config = transformers.Qwen2Config( + 
vocab_size=152064, + hidden_size=5120, + intermediate_size=13824, + num_hidden_layers=48, + num_attention_heads=40, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + rms_norm_eps=1e-06, + rope_theta=1000000.0, + tie_word_embeddings=False, + torch_dtype="bfloat16", + attention_bias=True, +) + + qwen3_0_6b_config = transformers.Qwen3Config( vocab_size=151936, hidden_size=1024, @@ -772,6 +807,8 @@ "gemma3-4b": gemma3_4b_config, "gemma3-12b": gemma3_12b_config, "gemma3-27b": gemma3_27b_config, + "qwen2.5-7b": qwen25_7b_config, + "qwen2.5-14b": qwen25_14b_config, "qwen3-0.6b": qwen3_0_6b_config, "qwen3-4b": qwen3_4b_config, "qwen3-4b-thinking-2507": qwen3_4b_config, diff --git a/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py b/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py index 081017dd96..74fa1c4ed6 100644 --- a/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py +++ b/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py @@ -433,8 +433,8 @@ def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config): return mapping -def QWEN3_HF_WEIGHTS_TO_SHAPE(config): - """Returns mapping between HuggingFace Qwen3 weights path and the HuggingFace weights shape. +def QWEN_HF_WEIGHTS_TO_SHAPE(config): + """Returns mapping between HuggingFace Qwen weights path and the HuggingFace weights shape. To check this mapping, dump the huggingface model shapes: from transformers import AutoModelForCausalLM @@ -459,6 +459,7 @@ def QWEN3_HF_WEIGHTS_TO_SHAPE(config): head_dim = config.get( "head_dim", config["hidden_size"] // config["num_attention_heads"] ) # head_dim might not always be present + attention_bias = config.get("attention_bias", False) mapping = { "model.embed_tokens.weight": [config["vocab_size"], hidden_size], @@ -484,6 +485,15 @@ def QWEN3_HF_WEIGHTS_TO_SHAPE(config): f"{layer_prefix}.self_attn.k_norm.weight": [head_dim], } + if attention_bias: + layer_mapping.update( + { + f"{layer_prefix}.self_attn.q_proj.bias": [num_attention_heads * head_dim], + f"{layer_prefix}.self_attn.k_proj.bias": [num_key_value_heads * head_dim], + f"{layer_prefix}.self_attn.v_proj.bias": [num_key_value_heads * head_dim], + } + ) + if num_experts > 1: # MoE MLP layers moe_ffn_intermediate_size = config.get("moe_intermediate_size") @@ -660,18 +670,20 @@ def MIXTRAL_HF_WEIGHTS_TO_SHAPE(config): "gemma3-4b": GEMMA3_HF_WEIGHTS_TO_SHAPE, "gemma3-12b": GEMMA3_HF_WEIGHTS_TO_SHAPE, "gemma3-27b": GEMMA3_HF_WEIGHTS_TO_SHAPE, - "qwen3-0.6b": QWEN3_HF_WEIGHTS_TO_SHAPE, - "qwen3-4b": QWEN3_HF_WEIGHTS_TO_SHAPE, - "qwen3-4b-thinking-2507": QWEN3_HF_WEIGHTS_TO_SHAPE, - "qwen3-8b": QWEN3_HF_WEIGHTS_TO_SHAPE, - "qwen3-14b": QWEN3_HF_WEIGHTS_TO_SHAPE, - "qwen3-32b": QWEN3_HF_WEIGHTS_TO_SHAPE, + "qwen2.5-7b": QWEN_HF_WEIGHTS_TO_SHAPE, + "qwen2.5-14b": QWEN_HF_WEIGHTS_TO_SHAPE, + "qwen3-0.6b": QWEN_HF_WEIGHTS_TO_SHAPE, + "qwen3-4b": QWEN_HF_WEIGHTS_TO_SHAPE, + "qwen3-4b-thinking-2507": QWEN_HF_WEIGHTS_TO_SHAPE, + "qwen3-8b": QWEN_HF_WEIGHTS_TO_SHAPE, + "qwen3-14b": QWEN_HF_WEIGHTS_TO_SHAPE, + "qwen3-32b": QWEN_HF_WEIGHTS_TO_SHAPE, "llama3.1-8b": LLAMA31_HF_WEIGHTS_TO_SHAPE, "llama3.1-70b": LLAMA31_HF_WEIGHTS_TO_SHAPE, "llama3.1-405b": LLAMA31_HF_WEIGHTS_TO_SHAPE, - "qwen3-30b-a3b": QWEN3_HF_WEIGHTS_TO_SHAPE, - "qwen3-235b-a22b": QWEN3_HF_WEIGHTS_TO_SHAPE, - "qwen3-480b-a35b": QWEN3_HF_WEIGHTS_TO_SHAPE, + "qwen3-30b-a3b": QWEN_HF_WEIGHTS_TO_SHAPE, + "qwen3-235b-a22b": QWEN_HF_WEIGHTS_TO_SHAPE, + "qwen3-480b-a35b": QWEN_HF_WEIGHTS_TO_SHAPE, "deepseek3-671b": DEEPSEEK_HF_WEIGHTS_TO_SHAPE, "gpt-oss-20b": GPT_OSS_HF_WEIGHTS_TO_SHAPE, 
"gpt-oss-120b": GPT_OSS_HF_WEIGHTS_TO_SHAPE, diff --git a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py index d4f7317969..5b270e4401 100644 --- a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py +++ b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py @@ -573,11 +573,11 @@ def scale_query_layer(input_tensor, target_shape): return mapping -def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False): - """Returns mapping from MaxText to HuggingFace Qwen3 weight paths. +def QWEN_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False): + """Returns mapping from MaxText to HuggingFace Qwen weight paths. This function generates a dictionary that maps parameter names from a MaxText - Qwen3 checkpoint to their corresponding names in the Hugging Face format. + Qwen checkpoint to their corresponding names in the Hugging Face format. It handles both dense and Mixture-of-Experts (MoE) model variants. Args: @@ -617,6 +617,15 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False) "params-decoder-layers-self_attention-value-kernel": [ f"model.layers.{i}.self_attn.v_proj.weight" for i in range(n_layers) ], + "params-decoder-layers-self_attention-query-bias": [ + f"model.layers.{i}.self_attn.q_proj.bias" for i in range(n_layers) + ], + "params-decoder-layers-self_attention-key-bias": [ + f"model.layers.{i}.self_attn.k_proj.bias" for i in range(n_layers) + ], + "params-decoder-layers-self_attention-value-bias": [ + f"model.layers.{i}.self_attn.v_proj.bias" for i in range(n_layers) + ], "params-decoder-layers-self_attention-out-kernel": [ f"model.layers.{i}.self_attn.o_proj.weight" for i in range(n_layers) ], @@ -674,6 +683,9 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False) f"params-decoder-layers_{i}-self_attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight", f"params-decoder-layers_{i}-self_attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight", f"params-decoder-layers_{i}-self_attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight", + f"params-decoder-layers_{i}-self_attention-query-bias": f"model.layers.{i}.self_attn.q_proj.bias", + f"params-decoder-layers_{i}-self_attention-key-bias": f"model.layers.{i}.self_attn.k_proj.bias", + f"params-decoder-layers_{i}-self_attention-value-bias": f"model.layers.{i}.self_attn.v_proj.bias", f"params-decoder-layers_{i}-self_attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight", f"params-decoder-layers_{i}-self_attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight", f"params-decoder-layers_{i}-post_self_attention_layer_norm-scale": f"model.layers.{i}.post_attention_layernorm.weight", @@ -707,8 +719,8 @@ def QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False) return mapping -def QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False): - """Creates parameter transformation functions for Qwen3. +def QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False): + """Creates parameter transformation functions for Qwen. This function provides a dictionary of transformation functions (hooks) for converting Qwen3 model parameters between MaxText and Hugging Face formats. 
@@ -752,6 +764,15 @@ def reshape_kernel(input_tensor, target_shape): else: return input_tensor.T.reshape(target_shape) + def reshape_bias(input_tensor, target_shape=None): + """Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden).""" + if saving_to_hf: + # MaxText [heads, head_dim] -> HF [hidden_dim] (flatten) + return input_tensor.reshape(target_shape) + else: + # HF [hidden_dim] -> MaxText [heads, head_dim] + return input_tensor.reshape(target_shape) + mapping = { "params-token_embedder-embedding": pad_embedding_layer, "params-decoder-logits_dense-kernel": reshape_kernel, @@ -766,6 +787,11 @@ def reshape_kernel(input_tensor, target_shape): "mlp-wi_1-kernel", "mlp-wo-kernel", ] + bias_hooks = [ + "self_attention-query-bias", + "self_attention-key-bias", + "self_attention-value-bias", + ] moe_kernel_hooks = [ "moe_block-gate-kernel", "moe_block-wi_0-kernel", @@ -779,6 +805,8 @@ def reshape_kernel(input_tensor, target_shape): if scan_layers: for key in kernel_hooks: mapping[f"params-decoder-layers-{key}"] = reshape_kernel + for key in bias_hooks: + mapping[f"params-decoder-layers-{key}"] = reshape_bias if num_experts > 1: for key in moe_kernel_hooks: mapping[f"params-decoder-layers-{key}"] = reshape_kernel @@ -786,6 +814,8 @@ def reshape_kernel(input_tensor, target_shape): for i in range(n_layers): for key in kernel_hooks: mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel + for key in bias_hooks: + mapping[f"params-decoder-layers_{i}-{key}"] = reshape_bias if num_experts > 1: for key in moe_kernel_hooks: mapping[f"params-decoder-layers_{i}-{key}"] = reshape_kernel @@ -1126,7 +1156,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_laye # Text mapping with "thinker." prefix, reusing QWEN3-MOE mapping function num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0) n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"] - text_mapping = QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING( + text_mapping = QWEN_MAXTEXT_TO_HF_PARAM_MAPPING( config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text}, maxtext_config=maxtext_config, scan_layers=scan_layers, @@ -1294,7 +1324,7 @@ def QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_laye # Text hooks, reusing QWEN3-MOE hook function num_experts_text = config["thinker_config"]["text_config"].get("num_experts", 0) n_layers_text = config["thinker_config"]["text_config"]["num_hidden_layers"] - text_hooks = QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN( + text_hooks = QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN( config={"num_hidden_layers": n_layers_text, "num_experts": num_experts_text}, maxtext_config=maxtext_config, scan_layers=scan_layers, @@ -2082,18 +2112,23 @@ def pad_hf_embedding_layer(input_tensor, target_shape): "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING, "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING, "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen2.5-0.5b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen2.5-1.5b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen2.5-3b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-0.6b": 
QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING, "llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING, "llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, - "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING, "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING, "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING, @@ -2113,18 +2148,23 @@ def pad_hf_embedding_layer(input_tensor, target_shape): "gemma3-4b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "gemma3-12b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN, "gemma3-27b": GEMMA3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-0.6b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-4b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-4b-thinking-2507": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-8b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-14b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-32b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen2.5-0.5b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen2.5-1.5b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen2.5-3b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen2.5-7b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen2.5-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-0.6b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-4b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-4b-thinking-2507": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-8b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-14b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-32b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, "llama3.1-8b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN, "llama3.1-70b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN, "llama3.1-405b": LLAMA31_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-30b-a3b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-235b-a22b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, - "qwen3-coder-480b-a35b": QWEN3_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-30b-a3b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN, "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN, "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN, diff --git a/src/MaxText/utils/ckpt_conversion/utils/utils.py b/src/MaxText/utils/ckpt_conversion/utils/utils.py index 5a0ecfe940..5cce69c2cf 100644 --- a/src/MaxText/utils/ckpt_conversion/utils/utils.py +++ b/src/MaxText/utils/ckpt_conversion/utils/utils.py @@ -64,6 +64,8 @@ "gemma3-4b": "google/gemma-3-4b-it", # hf multi-modal should also support the pure-text "gemma3-12b": "google/gemma-3-12b-it", "gemma3-27b": "google/gemma-3-27b-it", + "qwen2.5-7b": "Qwen/Qwen2.5-7B-Instruct", + "qwen2.5-14b": "Qwen/Qwen2.5-14B-Instruct", "qwen3-0.6b": "Qwen/Qwen3-0.6B", "qwen3-4b": "Qwen/Qwen3-4B", "qwen3-4b-thinking-2507": "Qwen/Qwen3-4B-Thinking-2507", diff --git a/src/maxtext/configs/models/qwen2.5-14b.yml b/src/maxtext/configs/models/qwen2.5-14b.yml new file mode 100644 index 0000000000..92392d1ad7 --- /dev/null +++ 
b/src/maxtext/configs/models/qwen2.5-14b.yml @@ -0,0 +1,38 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# model config for qwen2.5-14b + +base_emb_dim: 5120 +base_num_query_heads: 40 +base_num_kv_heads: 8 +base_mlp_dim: 13824 +base_num_decoder_layers: 48 +head_dim: 128 +mlp_activations: ["silu", "linear"] # "hidden_act": "silu" implies SwiGLU +vocab_size: 152064 + +decoder_block: "qwen2" + +normalization_layer_epsilon: 1.0e-6 +rope_max_timescale: 1000000 + +use_qk_norm: False +attention_bias: True + +logits_via_embedding: False +normalize_embedding_logits: False + +tokenizer_type: "huggingface" + diff --git a/src/maxtext/configs/models/qwen2.5-7b.yml b/src/maxtext/configs/models/qwen2.5-7b.yml new file mode 100644 index 0000000000..0876baf721 --- /dev/null +++ b/src/maxtext/configs/models/qwen2.5-7b.yml @@ -0,0 +1,33 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# model config for qwen2.5-7b + +base_emb_dim: 3584 +base_num_query_heads: 28 +base_num_kv_heads: 4 +base_mlp_dim: 18944 +base_num_decoder_layers: 28 +head_dim: 128 +mlp_activations: ["silu", "linear"] +vocab_size: 152064 +decoder_block: "qwen2" +normalization_layer_epsilon: 1e-06 +rope_max_timescale: 1000000.0 +use_qk_norm: False +attention_bias: True +logits_via_embedding: False +normalize_embedding_logits: False +tokenizer_type: "huggingface" + diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 6457b49041..321b7389c9 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -228,6 +228,8 @@ class ProfilerType(str, Enum): "gemma3-4b", "gemma3-12b", "gemma3-27b", + "qwen2.5-7b", + "qwen2.5-14b", "qwen3-0.6b", "qwen3-4b", "qwen3-4b-thinking-2507", @@ -248,7 +250,7 @@ class ProfilerType(str, Enum): "llama4-17b-16e", "llama4-17b-128e", "olmo3-7b", - 'olmo3-7b-pt', + "olmo3-7b-pt", "olmo3-32b", ] diff --git a/tests/end_to_end/tpu/qwen/dense/qwen2.5-14b/test_qwen2.5-14b.sh b/tests/end_to_end/tpu/qwen/dense/qwen2.5-14b/test_qwen2.5-14b.sh new file mode 100644 index 0000000000..3cc2e58eb8 --- /dev/null +++ b/tests/end_to_end/tpu/qwen/dense/qwen2.5-14b/test_qwen2.5-14b.sh @@ -0,0 +1,108 @@ +#!/bin/bash + +# This script runs end-to-end tests for qwen2.5-14b on MaxText. +# The flow of this file is as follows: +# 1. Convert the HuggingFace checkpoint to MaxText-compatible checkpoint (scanned and unscanned). +# 2. Run logit check against the HuggingFace model. +# 3. Run SFT. 
+
+# Example Usage: export BASE_OUTPUT_PATH=<your_gcs_bucket_path>; bash test_qwen2.5-14b.sh
+
+set -ex
+
+export MODEL_NAME='qwen2.5-14b'
+export HF_MODEL_ID='Qwen/Qwen2.5-14B-Instruct'
+export TOKENIZER_PATH=${HF_MODEL_ID}
+
+if [ -z "${BASE_OUTPUT_PATH}" ]; then
+    export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
+    echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
+fi
+
+export BASE_OUTPUT_PATH_PREFIX="${BASE_OUTPUT_PATH}/${MODEL_NAME}/$(date +%Y-%m-%d-%H-%M)"
+
+# Installing torch for deps in forward_pass_logit_checker.py
+#python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+# Step 1: Checkpoint conversion
+echo "--- Starting Checkpoint Conversion ---"
+
+# 1.1 Convert checkpoint to `scanned` format
+JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_conversion.to_maxtext \
+    "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
+    model_name=${MODEL_NAME} \
+    base_output_directory=${BASE_OUTPUT_PATH_PREFIX}/scanned \
+    run_name=scanned_conversion \
+    tokenizer_path=${TOKENIZER_PATH} \
+    async_checkpointing=false \
+    scan_layers=true
+
+export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH_PREFIX}/scanned/0/items
+
+# 1.2 Convert checkpoint to `unscanned` format
+JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_conversion.to_maxtext \
+    "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
+    model_name=${MODEL_NAME} \
+    base_output_directory=${BASE_OUTPUT_PATH_PREFIX}/unscanned \
+    run_name=unscanned_conversion \
+    tokenizer_path=${TOKENIZER_PATH} \
+    async_checkpointing=false \
+    scan_layers=false
+
+export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH_PREFIX}/unscanned/0/items
+
+# Step 2: Forward pass logit checker
+echo "--- Starting Forward Pass Logit Checker ---"
+# 2.1 Check unscanned checkpoint
+JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker \
+    "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
+    run_name=forward_pass_test_unscanned \
+    model_name=${MODEL_NAME} \
+    tokenizer_path=${TOKENIZER_PATH} \
+    load_parameters_path=${UNSCANNED_CKPT_PATH} \
+    max_prefill_predict_length=4 \
+    max_target_length=4 \
+    dataset_type=synthetic \
+    scan_layers=false \
+    per_device_batch_size=1 \
+    skip_jax_distributed_system=True \
+    weight_dtype=bfloat16 \
+    --max_kl_div=0.015 \
+    --run_hf_model=True \
+    --hf_model_path=${HF_MODEL_ID}
+
+# 2.2 Check scanned checkpoint
+JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker \
+    "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
+    run_name=forward_pass_test_scanned \
+    model_name=${MODEL_NAME} \
+    tokenizer_path=${TOKENIZER_PATH} \
+    load_parameters_path=${SCANNED_CKPT_PATH} \
+    max_prefill_predict_length=4 \
+    max_target_length=4 \
+    dataset_type=synthetic \
+    scan_layers=true \
+    per_device_batch_size=1 \
+    skip_jax_distributed_system=True \
+    weight_dtype=bfloat16 \
+    --max_kl_div=0.015 \
+    --run_hf_model=True \
+    --hf_model_path=${HF_MODEL_ID}
+
+# Step 3: SFT
+echo "--- Starting SFT ---"
+python3 -m maxtext.trainers.post_train.sft.train_sft \
+    "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/post_train/sft.yml \
+    base_output_directory=${BASE_OUTPUT_PATH_PREFIX}/finetuned \
+    run_name=sft_test \
+    model_name=${MODEL_NAME} \
+    tokenizer_path=${TOKENIZER_PATH} \
+    tokenizer_type=huggingface \
+    load_parameters_path=${SCANNED_CKPT_PATH} \
+    dataset_type=hf \
+    scan_layers=true \
+    per_device_batch_size=4 \
+    learning_rate=1.3e-5 \
+    steps=5 \
+    max_target_length=1024 \
+    weight_dtype=bfloat16
diff --git a/tests/end_to_end/tpu/qwen/dense/qwen2.5-7b/test_qwen2.5-7b.sh b/tests/end_to_end/tpu/qwen/dense/qwen2.5-7b/test_qwen2.5-7b.sh
new file mode 100644
index 0000000000..bebea72021
--- /dev/null
+++ b/tests/end_to_end/tpu/qwen/dense/qwen2.5-7b/test_qwen2.5-7b.sh
@@ -0,0 +1,108 @@
+#!/bin/bash
+
+# This script runs end-to-end tests for qwen2.5-7b on MaxText.
+# The flow of this file is as follows:
+# 1. Convert the HuggingFace checkpoint to MaxText-compatible checkpoint (scanned and unscanned).
+# 2. Run logit check against the HuggingFace model.
+# 3. Run SFT.
+
+# Example Usage: export BASE_OUTPUT_PATH=<your_gcs_bucket_path>; bash test_qwen2.5-7b.sh
+
+set -ex
+
+export MODEL_NAME='qwen2.5-7b'
+export HF_MODEL_ID='Qwen/Qwen2.5-7B-Instruct'
+export TOKENIZER_PATH=${HF_MODEL_ID}
+
+if [ -z "${BASE_OUTPUT_PATH}" ]; then
+    export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
+    echo "BASE_OUTPUT_PATH is not set, using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"
+fi
+
+export BASE_OUTPUT_PATH_PREFIX="${BASE_OUTPUT_PATH}/${MODEL_NAME}/$(date +%Y-%m-%d-%H-%M)"
+
+# Installing torch for deps in forward_pass_logit_checker.py
+#python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+# Step 1: Checkpoint conversion
+echo "--- Starting Checkpoint Conversion ---"
+
+# 1.1 Convert checkpoint to `scanned` format
+JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_conversion.to_maxtext \
+    "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
+    model_name=${MODEL_NAME} \
+    base_output_directory=${BASE_OUTPUT_PATH_PREFIX}/scanned \
+    run_name=scanned_conversion \
+    tokenizer_path=${TOKENIZER_PATH} \
+    async_checkpointing=false \
+    scan_layers=true
+
+export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH_PREFIX}/scanned/0/items
+
+# 1.2 Convert checkpoint to `unscanned` format
+JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_conversion.to_maxtext \
+    "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
+    model_name=${MODEL_NAME} \
+    base_output_directory=${BASE_OUTPUT_PATH_PREFIX}/unscanned \
+    run_name=unscanned_conversion \
+    tokenizer_path=${TOKENIZER_PATH} \
+    async_checkpointing=false \
+    scan_layers=false
+
+export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH_PREFIX}/unscanned/0/items
+
+# Step 2: Forward pass logit checker
+echo "--- Starting Forward Pass Logit Checker ---"
+# 2.1 Check unscanned checkpoint
+JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker \
+    "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
+    run_name=forward_pass_test_unscanned \
+    model_name=${MODEL_NAME} \
+    tokenizer_path=${TOKENIZER_PATH} \
+    load_parameters_path=${UNSCANNED_CKPT_PATH} \
+    max_prefill_predict_length=4 \
+    max_target_length=4 \
+    dataset_type=synthetic \
+    scan_layers=false \
+    per_device_batch_size=1 \
+    skip_jax_distributed_system=True \
+    weight_dtype=bfloat16 \
+    --max_kl_div=0.017 \
+    --run_hf_model=True \
+    --hf_model_path=${HF_MODEL_ID}
+
+# 2.2 Check scanned checkpoint
+JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker \
+    "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
+    run_name=forward_pass_test_scanned \
+    model_name=${MODEL_NAME} \
+    tokenizer_path=${TOKENIZER_PATH} \
+    load_parameters_path=${SCANNED_CKPT_PATH} \
+    max_prefill_predict_length=4 \
+    max_target_length=4 \
+    dataset_type=synthetic \
+    scan_layers=true \
+    per_device_batch_size=1 \
+    skip_jax_distributed_system=True \
+    weight_dtype=bfloat16 \
+    --max_kl_div=0.017 \
+    --run_hf_model=True \
+    --hf_model_path=${HF_MODEL_ID}
+
+# Step 3: SFT
+echo "--- Starting SFT ---"
+python3 -m maxtext.trainers.post_train.sft.train_sft \
+    "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/post_train/sft.yml \
+    base_output_directory=${BASE_OUTPUT_PATH_PREFIX}/finetuned \
+    run_name=sft_test \
+    model_name=${MODEL_NAME} \
+    tokenizer_path=${TOKENIZER_PATH} \
+    tokenizer_type=huggingface \
+    load_parameters_path=${SCANNED_CKPT_PATH} \
+    dataset_type=hf \
+    scan_layers=true \
+    per_device_batch_size=4 \
+    learning_rate=1.3e-5 \
+    steps=5 \
+    max_target_length=1024 \
+    weight_dtype=bfloat16
diff --git a/tests/end_to_end/tpu/qwen/dense/run_qwen2.5_dense.md b/tests/end_to_end/tpu/qwen/dense/run_qwen2.5_dense.md
new file mode 100644
index 0000000000..e26106e199
--- /dev/null
+++ b/tests/end_to_end/tpu/qwen/dense/run_qwen2.5_dense.md
@@ -0,0 +1,44 @@
+# Qwen2.5 Dense
+
+Qwen2.5 is a series of large language models from the Qwen team, released in September 2024. The models use a dense
+transformer architecture. You can find more information in the [blog](https://qwenlm.github.io/blog/qwen2.5/) and
+the [model card](https://huggingface.co/Qwen/Qwen2.5-7B). The currently supported models
+are [qwen2.5-7b](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct)
+and [qwen2.5-14b](https://huggingface.co/Qwen/Qwen2.5-14B-Instruct).
+
+## Running the End-to-End Test
+
+The `test_qwen2.5-14b.sh` script automates the following steps:
+
+1. **Checkpoint Conversion**: Converts the Hugging Face checkpoint to MaxText-compatible format (scanned and unscanned).
+1. **Logit Check**: Verifies the forward pass logits against the Hugging Face model.
+1. **SFT**: Runs a Supervised Fine-Tuning (SFT) job.
+
+### Prerequisites
+
+- Ensure you have write access to a GCS bucket for output logs and checkpoints.
+
+### Usage
+
+To run the test for Qwen 2.5 14B:
+
+```bash
+export BASE_OUTPUT_PATH=gs://your-gcs-bucket/
+bash tests/end_to_end/tpu/qwen/dense/qwen2.5-14b/test_qwen2.5-14b.sh
+```
+
+This will:
+
+- Download the `Qwen/Qwen2.5-14B-Instruct` model.
+- Convert it to MaxText format.
+- Run validation and training tests.
+- Store artifacts in `${BASE_OUTPUT_PATH}/qwen2.5-14b/`.
+
+### Qwen 2.5 7B
+
+Similarly, for Qwen 2.5 7B:
+
+```bash
+export BASE_OUTPUT_PATH=gs://your-gcs-bucket/
+bash tests/end_to_end/tpu/qwen/dense/qwen2.5-7b/test_qwen2.5-7b.sh
+```
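+
+### Manual checkpoint conversion (optional)
+
+If you only need the MaxText checkpoint, the conversion step can be run on its own. The sketch below mirrors the
+invocation in `test_qwen2.5-7b.sh` (module path and flags are taken from that script); substitute your own GCS bucket
+and run it from the repository root.
+
+```bash
+export MODEL_NAME='qwen2.5-7b'
+JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_conversion.to_maxtext \
+    src/maxtext/configs/base.yml \
+    model_name=${MODEL_NAME} \
+    base_output_directory=gs://your-gcs-bucket/${MODEL_NAME}/scanned \
+    run_name=scanned_conversion \
+    tokenizer_path=Qwen/Qwen2.5-7B-Instruct \
+    async_checkpointing=false \
+    scan_layers=true
+```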