32 changes: 21 additions & 11 deletions src/MaxText/layers/moe.py
@@ -368,6 +368,11 @@ def __init__(
else:
self._tensor_parallelism_name = "tensor"

+ if self.config.attention == "vllm_rpa":
+   self._expert_parallelism_name = "attn_dp_expert"
+ else:
+   self._expert_parallelism_name = "expert"
+
self.gate = GateLogit(
in_features_shape=self.config.emb_dim,
out_features_shape=self.num_experts,
@@ -465,7 +470,12 @@ def _logical_to_mesh_axes(self, logical_name):
return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules)

def get_expert_parallelism_size(self):
- return self.mesh.shape.get("expert", 1)
+ if isinstance(self._expert_parallelism_name, tuple):
Review comment (Collaborator): Is the tuple case actually supported? Have you tested it?

+   size = 1
+   for axis in self._expert_parallelism_name:
+     size *= self.mesh.shape.get(axis, 1)
+   return size
+ return self.mesh.shape.get(self._expert_parallelism_name, 1)

def get_tensor_parallelism_size(self):
if isinstance(self._tensor_parallelism_name, tuple):
@@ -1000,7 +1010,7 @@ def gmm(
# batch_size=1 while decode can have batch_size > 1.
try:
is_batch_sharded_by_expert = (
"expert"
self._expert_parallelism_name
in tuple(
filter(
lambda tup: tup[0] == "activation_batch",
@@ -1092,10 +1102,9 @@ def gmm(
)
def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
batch_size, sequence_length, _ = x.shape
- expert_axis_name = "expert"
num_expert_parallelism = self.get_expert_parallelism_size()
if num_expert_parallelism > 1:
- expert_shard_id = jax.lax.axis_index(expert_axis_name)
+ expert_shard_id = jax.lax.axis_index(self._expert_parallelism_name)
else:
expert_shard_id = 0
num_expert_parallelism = self.get_expert_parallelism_size()
@@ -1105,7 +1114,8 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):

# Duplicate inputs to all expert shards.
x, logits, pre_bias_logits = tuple(
- jax.lax.all_gather(z, axis_name=expert_axis_name, tiled=True) for z in (x, logits, pre_bias_logits)
+ jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True)
+ for z in (x, logits, pre_bias_logits)
)

# "Route" tokens within each shard.
@@ -1129,7 +1139,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
)

if num_expert_parallelism > 1:
batch_axis = "expert" if is_batch_sharded_by_expert else "data"
batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data"
# get group sizes for all shards
local_expert_size = self.config.num_experts // num_expert_parallelism
reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
@@ -1161,9 +1171,9 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
send_sizes,
output_offsets,
recv_sizes,
- axis_name=expert_axis_name,
+ axis_name=self._expert_parallelism_name,
)
- global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=expert_axis_name)
+ global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
x,
global_group_sizes,
@@ -1308,7 +1318,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):

# Sum up the partial outputs across the expert shards.
output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim))
- output = jax.lax.psum_scatter(output, expert_axis_name, scatter_dimension=0, tiled=True)
+ output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)

else:
if num_expert_parallelism > 1:
@@ -1341,7 +1351,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
send_sizes,
output_offsets,
recv_sizes,
- axis_name=expert_axis_name,
+ axis_name=self._expert_parallelism_name,
)
else:
# If batch is replicated across EP shards then each shard should send
@@ -1361,7 +1371,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
send_sizes,
output_offsets,
recv_sizes,
- axis_name=expert_axis_name,
+ axis_name=self._expert_parallelism_name,
)

output = self.unpermute(
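For context on the review question above: when the expert-parallel axis name is a tuple of mesh axes (as with ('expert', 'attn_dp_expert') in the vLLM config below), the effective expert parallelism is the product of the individual mesh axis sizes. A minimal standalone sketch of that computation, assuming four devices are available (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4); the helper name here is illustrative, not MaxText code:

import numpy as np
import jax
from jax.sharding import Mesh

# Assumes 4 devices, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=4.
devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, axis_names=("expert", "attn_dp_expert"))

def expert_parallelism_size(axis_name):
  """Mirrors get_expert_parallelism_size in the diff: tuple axes multiply out."""
  if isinstance(axis_name, tuple):
    size = 1
    for axis in axis_name:
      size *= mesh.shape.get(axis, 1)  # Mesh.shape maps axis name -> size
    return size
  return mesh.shape.get(axis_name, 1)

assert expert_parallelism_size(("expert", "attn_dp_expert")) == 4
assert expert_parallelism_size("expert") == 2
assert expert_parallelism_size("missing") == 1  # absent axes default to 1

One caveat the review question likely targets: the wrapper also passes self._expert_parallelism_name to collectives (jax.lax.axis_index, all_gather, psum_scatter), and whether those accept a tuple axis name can depend on the JAX version, so the tuple path deserves a test.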
24 changes: 12 additions & 12 deletions src/MaxText/rl/train_rl.py
@@ -236,6 +236,7 @@ def get_rollout_kwargs_for_data_parallelism(sampler_config, num_sampler_devices):

rollout_kwargs = {}
tp = sampler_config.rollout_tensor_parallelism
+ ep = sampler_config.rollout_expert_parallelism

if tp == -1:
if num_sampler_devices % dp != 0:
@@ -244,17 +245,22 @@ def get_rollout_kwargs_for_data_parallelism(sampler_config, num_sampler_devices):
f"rollout_data_parallelism({dp}) "
f"when rollout_tensor_parallelism is -1."
)
-   tp = num_sampler_devices // dp
- elif tp * dp != num_sampler_devices:
+   tp = num_sampler_devices // dp // ep
+ elif tp * dp * ep != num_sampler_devices:
raise ValueError(
f"rollout_tensor_parallelism({tp}) * "
f"rollout_data_parallelism({dp}) "
f"rollout_data_parallelism({dp}) * "
f"rollout_expert_parallelism({ep}) "
f"!= len(sampler_devices)({num_sampler_devices})"
)
rollout_kwargs["tensor_parallel_size"] = tp
rollout_kwargs["data_parallel_size"] = dp
rollout_kwargs["rollout_vllm_async_scheduling"] = True

+ if ep > 1:
+   rollout_kwargs["expert_parallel_size"] = ep
+   rollout_kwargs["rollout_vllm_enable_expert_parallelism"] = True
+
+ rollout_kwargs["rollout_vllm_async_scheduling"] = True
return rollout_kwargs


@@ -321,10 +327,7 @@ def _filter_long_prompts(x):
train_dataset = train_dataset[:dataset_size]
train_dataset = train_dataset.repeat(trainer_config.num_epoch)

- train_dataset = (
-   train_dataset.to_iter_dataset()
-   .batch(trainer_config.batch_size)
- )
+ train_dataset = train_dataset.to_iter_dataset().batch(trainer_config.batch_size)

eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None)
if not eval_dataset_name:
@@ -342,10 +345,7 @@ def _filter_long_prompts(x):
test_dataset = test_dataset.filter(_filter_long_prompts)
test_dataset = test_dataset[: trainer_config.num_test_batches * trainer_config.batch_size]

- test_dataset = (
-   test_dataset.to_iter_dataset()
-   .batch(trainer_config.batch_size)
- )
+ test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size)

# Load reference model
max_logging.log("Creating reference model and also meshes for reference and rollout")
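A quick sanity check of the device-count arithmetic introduced above: with rollout_tensor_parallelism left at -1, TP is derived from whatever remains after data and expert parallelism; with an explicit TP, the three factors must tile the sampler devices exactly. A sketch with illustrative values:

# Illustrative values; mirrors get_rollout_kwargs_for_data_parallelism above.
num_sampler_devices = 16
dp, ep = 2, 2  # rollout_data_parallelism, rollout_expert_parallelism

tp = num_sampler_devices // dp // ep  # auto-derived TP -> 4
assert tp * dp * ep == num_sampler_devices

tp_explicit = 8  # 8 * 2 * 2 = 32 != 16, so this combination raises ValueError

Note that the tp == -1 branch still validates divisibility against dp alone, so a device count divisible by dp but not by dp * ep would floor-divide silently.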
2 changes: 2 additions & 0 deletions src/maxtext/configs/base.yml
@@ -1083,6 +1083,8 @@ use_jax_splash: false
# vLLM Adapter Configurations
# Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter)
vllm_hf_config_path: ""
+ # Path to the YAML file for loading the vLLM config
+ vllm_config_path: ""
# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
vllm_additional_config: {}
# When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]
23 changes: 12 additions & 11 deletions src/maxtext/configs/inference/vllm.yml
@@ -25,7 +25,7 @@ weight_dtype: bfloat16


# -------------- Logical Axis Rules --------------
- mesh_axes: ['data', 'attn_dp', 'model', 'expert']
+ mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
logical_axis_rules: [
['activation_batch', ['expert']],
['activation_batch_no_exp', []],
@@ -37,37 +37,38 @@ logical_axis_rules: [
['activation_attn_length_no_exp', []],
['activation_length', ['data', 'expert']],
['activation_length_no_exp', 'data'],
- ['activation_q_length', ['expert']],
+ ['activation_q_length', ['expert', 'attn_dp_expert']],
['activation_attn_embed', 'model'],
['activation_embed', ['model', 'attn_dp']],
['activation_mlp', ['model', 'attn_dp']],
['activation_kv', ['model']],
- ['activation_prefill_kv_batch', ['expert']],
- ['activation_kv_batch', ['expert']],
+ ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
+ ['activation_kv_batch', ['expert', 'attn_dp_expert']],
['activation_kv_batch_no_exp', []],
['activation_kv_head_dim', ['model']],
['activation_vocab', ['model', 'attn_dp']],
['activation_norm_length', []],
- ['activation_exp', ['expert']],
- ['decode_batch', ['expert']],
+ ['activation_exp', ['expert', 'attn_dp_expert']],
+ ['decode_batch', ['expert', 'attn_dp_expert']],
['decode_length', []],
['mlp', ['model', 'attn_dp']],
['mlp_no_fsdp', ['model', 'attn_dp']],
['moe_mlp', ['model', 'attn_dp']],
['vocab', ['model', 'attn_dp']],
['heads', ['model']],
+ ['q_heads', ['model']],
['kv_heads', ['model']],
['kv_head_dim', []],
['kv', []],
- ['embed', ['expert']],
+ ['embed', ['expert', 'attn_dp_expert']],
['embed_tensor_transpose', ['attn_dp', 'model']],
['embed_no_exp', []],
- ['q_lora', ['expert']],
- ['kv_lora', ['expert']],
+ ['q_lora', ['expert', 'attn_dp_expert']],
+ ['kv_lora', ['expert', 'attn_dp_expert']],
['norm', []],
['cache_heads', ['model']],
- ['exp', ['expert']],
+ ['exp', ['expert', 'attn_dp_expert']],
['paged_kv_heads', ['model']],
]
- data_sharding: [['data', 'attn_dp', 'model', 'expert']]
+ data_sharding: [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']]
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
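For reference on how the tuple-valued rules above take effect: the MoE layer resolves logical names through logical_to_mesh_axes (wrapped in moe.py earlier in this diff). A minimal sketch, assuming the flax.linen API:

from flax.linen import logical_to_mesh_axes

# A subset of the rules above: one logical axis may map to several mesh axes.
rules = (
    ("activation_exp", ("expert", "attn_dp_expert")),
    ("embed", ("expert", "attn_dp_expert")),
)

# Arrays whose dims carry these logical names get sharded over both mesh axes.
spec = logical_to_mesh_axes(("activation_exp", "embed"), rules=rules)
# Expected: PartitionSpec(('expert', 'attn_dp_expert'), ('expert', 'attn_dp_expert'))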
2 changes: 2 additions & 0 deletions src/maxtext/configs/post_train/rl.yml
@@ -149,6 +149,8 @@ enable_dp_attention: False
# Performance tuning for samplers
max_num_batched_tokens: null
max_num_seqs: null
+ # Path to the YAML file used to initialize the vLLM config
+ vllm_config_path: 'src/maxtext/configs/inference/vllm.yml'

# ====== Checkpoint Configuration ======
enable_checkpointing: True
8 changes: 7 additions & 1 deletion src/maxtext/configs/types.py
@@ -248,7 +248,7 @@ class ProfilerType(str, Enum):
"llama4-17b-16e",
"llama4-17b-128e",
"olmo3-7b",
- 'olmo3-7b-pt',
+ "olmo3-7b-pt",
"olmo3-32b",
]

@@ -1544,6 +1544,7 @@ class RLHardware(BaseModel):
rollout_tensor_parallelism: int = Field(
-1, description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined."
)
+ rollout_expert_parallelism: int = Field(1, description="Expert parallelism per replica for rollout")


class VLLM(BaseModel):
@@ -1557,6 +1558,9 @@ class VLLM(BaseModel):
max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")
vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
+ vllm_config_path: str = Field(
+   "src/maxtext/configs/inference/vllm.yml", description="Path to the YAML file for loading the vLLM config."
+ )


class RL(BaseModel):
@@ -2447,6 +2451,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
"expert": self.ici_expert_parallelism,
"autoregressive": self.ici_autoregressive_parallelism,
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
}
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]

@@ -2466,6 +2471,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
"expert": self.dcn_expert_parallelism,
"autoregressive": self.dcn_autoregressive_parallelism,
"attn_dp": 1, # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
"attn_dp_expert": 1, # initialized to 1, vLLM will auto calculate this value based on EP
}
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]

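The one-line map additions above are load-bearing: every name in mesh_axes is looked up in ici_map/dcn_map, so once 'attn_dp_expert' appears in mesh_axes (as in the vLLM config), a missing entry would raise a KeyError. A condensed sketch with illustrative sizes:

# Condensed from the diff above; axis sizes are illustrative.
mesh_axes = ["data", "attn_dp", "model", "expert", "attn_dp_expert"]
ici_map = {"data": 1, "attn_dp": 1, "model": 4, "expert": 2, "attn_dp_expert": 1}

# 'attn_dp' and 'attn_dp_expert' start at 1; vLLM later resizes them from TP/EP.
ici_parallelism = [ici_map[axis] for axis in mesh_axes]
assert ici_parallelism == [1, 1, 4, 2, 1]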
19 changes: 16 additions & 3 deletions src/maxtext/vllm_decode.py
@@ -80,15 +80,21 @@
flags.DEFINE_integer("max_prefill_length", 512, "Maximum prefill length.")
flags.DEFINE_float("gpu_memory_utilization", 0.72, "Fraction of GPU memory to be used for the model executor.")

+ # vLLM config variables
+ flags.DEFINE_integer("vllm_swap_space", 2, "Per-device swap space in GB.")
+ flags.DEFINE_integer("vllm_async_scheduling", 1, "Enable vLLM async scheduling (1 to enable).")
+
# Decoding
flags.DEFINE_bool("use_tunix", False, "Whether to use Tunix for vLLM decoding.")
flags.DEFINE_string("prompt", "Suggest some famous landmarks in London.", "The prompt to decode.")
flags.DEFINE_integer("decode_sampling_temperature", 0, "Temperature for sampling.")
flags.DEFINE_integer("decode_sampling_nucleus_p", 1, "Nucleus sampling probability.")
flags.DEFINE_integer("decode_sampling_top_k", 1, "Top-k sampling probability.")

# Mark required flags
flags.mark_flag_as_required("hf_config_path")
+ flags.DEFINE_string(
+   "vllm_config_path",
+   "src/MaxText/configs/vllm.yml",
+   "Path to vLLM config file. Defaults to MAXTEXT_PKG_DIR/configs/vllm.yml.",
+ )


def decode_with_vllm(
@@ -103,6 +109,8 @@ def decode_with_vllm(
max_prefill_length: int,
max_target_length: int,
gpu_memory_utilization: float,
+ vllm_swap_space: int,
+ vllm_async_scheduling: int,
enable_expert_parallel: bool,
prompt: str,
decode_sampling_temperature: float,
@@ -145,6 +153,8 @@ def decode_with_vllm(
vllm_args["enable_expert_parallel"] = enable_expert_parallel
vllm_args["hf_config_path"] = hf_config_path
vllm_args["gpu_memory_utilization"] = gpu_memory_utilization
vllm_args["swap_space"] = vllm_swap_space
vllm_args["async_scheduling"] = vllm_async_scheduling

# Prepare MaxText and sharding configs (Parallelism is dynamic)
vllm_args["additional_config"]["maxtext_config"] = {
@@ -291,12 +301,15 @@ def main(argv: Sequence[str]) -> None:
max_target_length=FLAGS.max_target_length,
max_prefill_length=FLAGS.max_prefill_length,
gpu_memory_utilization=FLAGS.gpu_memory_utilization,
+ vllm_swap_space=FLAGS.vllm_swap_space,
+ vllm_async_scheduling=FLAGS.vllm_async_scheduling,
enable_expert_parallel=FLAGS.enable_expert_parallel,
prompt=FLAGS.prompt,
decode_sampling_temperature=FLAGS.decode_sampling_temperature,
decode_sampling_nucleus_p=FLAGS.decode_sampling_nucleus_p,
decode_sampling_top_k=FLAGS.decode_sampling_top_k,
debug_sharding=FLAGS.debug_sharding,
+ vllm_config_path=FLAGS.vllm_config_path,
)


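Putting the new flags together, a hypothetical invocation of the decode script (paths and values illustrative, not taken from this PR):

python src/maxtext/vllm_decode.py \
  --hf_config_path=src/MaxText/integration/vllm/maxtext_vllm_adapter \
  --vllm_config_path=src/MaxText/configs/vllm.yml \
  --vllm_swap_space=2 \
  --vllm_async_scheduling=1 \
  --prompt="Suggest some famous landmarks in London."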