1 change: 1 addition & 0 deletions README.md
@@ -103,6 +103,7 @@ MaxText aims to provide you with the best OSS models, whether as a reference imp
  * Gemma 2 (2B, 9B, 27B)
  * Gemma 1 (2B, 7B)
* Alibaba
  * Qwen 2.5 (7B, 14B)
  * Qwen 3 MoE 2507 (235B, 480B)
  * Qwen 3 MoE (30B, 235B)
  * Qwen 3 Dense (0.6B, 1.7B, 4B, 8B, 14B, 32B)
@@ -11,6 +11,7 @@ The following models are supported:
| **Gemma2** | 2B, 9B, 27B | √ | √ | √ | √ |
| **Gemma3** (Multimodal) | 4B, 12B, 27B | - | √ | - | √ |
| **Llama3.1** | 8B, 70B, 405B | √ | √ | √ | √ |
| **Qwen2.5** | 7B, 14B | √ | √ | √ | √ |
| **Qwen3** | 0.6B, 4B, 8B, 14B, 32B | √ | √ | √ | √ |
| **Qwen3 MoE** | 30B, 235B, 480B | √ | √ | √ | √ |
| **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ |
1 change: 1 addition & 0 deletions src/MaxText/common_types.py
@@ -92,6 +92,7 @@ class DecoderBlockType(enum.Enum):
GEMMA = "gemma"
GEMMA2 = "gemma2"
GEMMA3 = "gemma3"
QWEN2 = "qwen2"
QWEN3 = "qwen3"
QWEN3_MOE = "qwen3_moe"
QWEN3_NEXT = "qwen3_next"
3 changes: 3 additions & 0 deletions src/MaxText/integration/tunix/weight_mapping/__init__.py
@@ -21,6 +21,7 @@
from MaxText.integration.tunix.weight_mapping.deepseek3 import DEEPSEEK_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.gpt_oss import GPT_OSS_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.llama3 import LLAMA3_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.qwen2 import QWEN2_VLLM_MAPPING
from MaxText.integration.tunix.weight_mapping.qwen3 import QWEN3_VLLM_MAPPING


@@ -30,6 +31,8 @@ class StandaloneVllmWeightMapping:
  def __getattr__(self, name):
    if name.startswith("llama3.1"):
      return LLAMA3_VLLM_MAPPING
    elif name.startswith("qwen2"):
      return QWEN2_VLLM_MAPPING
    elif name.startswith("qwen3"):
      return QWEN3_VLLM_MAPPING
    elif name.startswith("deepseek3"):
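The `__getattr__` above dispatches on the model-name prefix, so any name beginning with "qwen2" resolves to the new Qwen2 mapping. A minimal runnable sketch of that lookup, with plain strings standing in for the mapping dataclasses and "qwen2.5-7b" used only as an assumed model name, not one defined in this PR:

# Sketch of the prefix dispatch; strings stand in for the mapping objects.
class StandaloneVllmWeightMapping:

  def __getattr__(self, name):
    if name.startswith("qwen2"):
      return "QWEN2_VLLM_MAPPING"
    elif name.startswith("qwen3"):
      return "QWEN3_VLLM_MAPPING"
    raise AttributeError(name)


print(getattr(StandaloneVllmWeightMapping(), "qwen2.5-7b"))  # -> QWEN2_VLLM_MAPPING
print(getattr(StandaloneVllmWeightMapping(), "qwen3-8b"))    # -> QWEN3_VLLM_MAPPING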
136 changes: 136 additions & 0 deletions src/MaxText/integration/tunix/weight_mapping/qwen2.py
@@ -0,0 +1,136 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the weight mapping from MaxText's Qwen2 model to a vLLM-compatible format.

This module provides the `QWEN2_VLLM_MAPPING` dataclass, which contains the
configuration needed to convert MaxText's Qwen2 model weights into a
vLLM-compatible format using Hugging Face parameter names. This includes:
- A direct mapping of parameter names.
- Sharding specifications for distributed environments.
"""

from dataclasses import dataclass


@dataclass
class QWEN2_VLLM_MAPPING:
  """Mapping MaxText Qwen2 weights to vLLM's Qwen2 weights."""

  @staticmethod
  def to_hf_hook_fns():
    """Returns a dictionary of hook functions to be applied to MaxText weights.

    Returns:
      An empty dictionary, as no hook functions are needed for this mapping.
    """
    return {}

  @staticmethod
  def to_hf_transpose_keys():
    """Returns a dictionary of keys for weights that need to be transposed.

    Returns:
      An empty dictionary, as no keys require transposition for this mapping.
    """
    return {}

  @staticmethod
  def lora_to_hf_mappings():
    """Provides the mapping for LoRA (Low-Rank Adaptation) weights.

    Returns:
      None, as LoRA mappings are not defined for this model.
    """
    return None

  @staticmethod
  def to_hf_mapping():
    """Mapping from the MaxText model to the HuggingFace vLLM model.

    Currently, the param mapping conforms to the Tunix API, which combines the
    param name & sharding in one dictionary. This may change in the future so
    that the two can be decoupled.
    """
    return {
        # Token embeddings - shard vocab dimension
        "base.token_embedder.embedding": (
            "model.embed.embedding",
            ("model", None),
        ),
        # Final layer norm - no sharding needed
        "base.decoder.decoder_norm.scale": (
            "model.norm.scale",
            (None,),
        ),
        # LM head (logits projection) - shard vocab dimension
        "base.decoder.logits_dense.kernel": (
            "model.lm_head",
            (None, "model"),
        ),
        # Layer-specific mappings (scanned -> unscanned)
        # MLP components - shard hidden dimensions
        "base.decoder.layers.mlp.wi_0.kernel": (
            "model.layers.*.mlp.gate_proj.kernel",
            (None, "layer", "model"),
        ),
        "base.decoder.layers.mlp.wi_1.kernel": (
            "model.layers.*.mlp.up_proj.kernel",
            (None, "layer", "model"),
        ),
        "base.decoder.layers.mlp.wo.kernel": (
            "model.layers.*.mlp.down_proj.kernel",
            ("model", "layer", None),
        ),
        # Layer norms - no sharding needed
        "base.decoder.layers.pre_self_attention_layer_norm.scale": (
            "model.layers.*.input_layernorm.scale",
            (None, "layer"),
        ),
        "base.decoder.layers.post_self_attention_layer_norm.scale": (
            "model.layers.*.post_attention_layernorm.scale",
            (None, "layer"),
        ),
        # Attention components - shard head dimensions
        "base.decoder.layers.self_attention.query.kernel": (
            "model.layers.*.self_attn.q_proj.kernel",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.key.kernel": (
            "model.layers.*.self_attn.k_proj.kernel",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.value.kernel": (
            "model.layers.*.self_attn.v_proj.kernel",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.out.kernel": (
            "model.layers.*.self_attn.o_proj.kernel",
            ("model", "layer", None, None),
        ),
        # Attention biases
        "base.decoder.layers.self_attention.query.bias": (
            "model.layers.*.self_attn.q_proj.bias",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.key.bias": (
            "model.layers.*.self_attn.k_proj.bias",
            (None, "layer", "model", None),
        ),
        "base.decoder.layers.self_attention.value.bias": (
            "model.layers.*.self_attn.v_proj.bias",
            (None, "layer", "model", None),
        ),
    }
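Each entry in `to_hf_mapping` pairs a MaxText parameter path with a vLLM-style target name (where `*` stands for the layer index of the scanned weight) and a tuple of per-axis sharding hints. A minimal sketch of how one such entry could be expanded; `expand_entry` and the layer count are hypothetical helpers for illustration, not Tunix's actual API:

# Hedged sketch: expand one mapping entry into per-layer target names plus its
# sharding spec. The real consumption happens inside Tunix, which is not shown
# in this PR.
def expand_entry(maxtext_key, hf_pattern, sharding, num_layers):
  targets = [hf_pattern.replace("*", str(i)) for i in range(num_layers)]
  return maxtext_key, targets, sharding


key, targets, sharding = expand_entry(
    "base.decoder.layers.mlp.wi_0.kernel",
    "model.layers.*.mlp.gate_proj.kernel",
    (None, "layer", "model"),
    num_layers=2,
)
print(targets)   # ['model.layers.0.mlp.gate_proj.kernel', 'model.layers.1.mlp.gate_proj.kernel']
print(sharding)  # (None, 'layer', 'model') -> axis names for the scanned MaxText weight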
3 changes: 2 additions & 1 deletion src/MaxText/layers/attentions.py
@@ -432,6 +432,7 @@ def __init__(
    # Use the rope type specified in the arguments if provided, otherwise fall back to the one in the config.
    self.rope_type = (rope_type or self.config.rope_type).lower()

    self.is_qwen2 = self.config.decoder_block == DecoderBlockType.QWEN2
    self.is_qwen3_next = self.config.decoder_block == DecoderBlockType.QWEN3_NEXT

    # Module attribute names must match names previously passed to Linen for checkpointing
@@ -715,7 +716,7 @@ def init_out_w(self, output_dim: int) -> nnx.Module:
        quant=self.quant,
        shard_mode=self.config.shard_mode,
        matmul_precision=self.config.matmul_precision,
-       use_bias=self.use_bias_in_projections,
+       use_bias=False if self.is_qwen2 else self.use_bias_in_projections,
        rngs=self.rngs,
    )

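The `use_bias` override reflects Qwen2's attention layout: the query/key/value projections carry biases (matching the bias entries in the mapping file above), while the output projection does not. A tiny sketch of the resulting flags, assuming `use_bias_in_projections` is enabled for Qwen2 elsewhere in the config (not shown in this diff):

# Illustration only: per-projection bias flags for a Qwen2 attention block.
# `use_bias_in_projections = True` is an assumption about the Qwen2 config.
is_qwen2 = True
use_bias_in_projections = True

qkv_bias = use_bias_in_projections                          # q_proj / k_proj / v_proj keep biases
out_bias = False if is_qwen2 else use_bias_in_projections   # o_proj has no bias for Qwen2

print(qkv_bias, out_bias)  # True False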
4 changes: 4 additions & 0 deletions src/MaxText/layers/decoders.py
@@ -51,6 +51,7 @@
    llama4,
    mistral,
    mixtral,
    qwen2,
    qwen3,
    simple_layer,
    olmo3,
@@ -420,6 +421,8 @@ def get_decoder_layers(self):
        return [gpt3.Gpt3DecoderLayerToLinen]
      case DecoderBlockType.GPT_OSS:
        return [gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen]
      case DecoderBlockType.QWEN2:
        return [qwen2.Qwen2DecoderLayerToLinen]
      case DecoderBlockType.QWEN3:
        return [qwen3.Qwen3DecoderLayerToLinen]
      case DecoderBlockType.QWEN3_MOE:
@@ -478,6 +481,7 @@ def get_norm_layer(self, num_features: int):
        DecoderBlockType.GEMMA,
        DecoderBlockType.GEMMA2,
        DecoderBlockType.GEMMA3,
        DecoderBlockType.QWEN2,
        DecoderBlockType.QWEN3,
        DecoderBlockType.QWEN3_MOE,
        DecoderBlockType.GPT_OSS,