diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py
index 7c54faf5a0..9b813e1070 100644
--- a/src/MaxText/layers/moe.py
+++ b/src/MaxText/layers/moe.py
@@ -368,6 +368,11 @@ def __init__(
     else:
       self._tensor_parallelism_name = "tensor"
 
+    if self.config.attention == "vllm_rpa":
+      self._expert_parallelism_name = "attn_dp_expert"
+    else:
+      self._expert_parallelism_name = "expert"
+
     self.gate = GateLogit(
         in_features_shape=self.config.emb_dim,
         out_features_shape=self.num_experts,
@@ -465,7 +470,12 @@ def _logical_to_mesh_axes(self, logical_name):
     return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=self.config.logical_axis_rules)
 
   def get_expert_parallelism_size(self):
-    return self.mesh.shape.get("expert", 1)
+    if isinstance(self._expert_parallelism_name, tuple):
+      size = 1
+      for axis in self._expert_parallelism_name:
+        size *= self.mesh.shape.get(axis, 1)
+      return size
+    return self.mesh.shape.get(self._expert_parallelism_name, 1)
 
   def get_tensor_parallelism_size(self):
     if isinstance(self._tensor_parallelism_name, tuple):
@@ -1000,7 +1010,7 @@ def gmm(
     # batch_size=1 while decode can have batch_size > 1.
     try:
       is_batch_sharded_by_expert = (
-          "expert"
+          self._expert_parallelism_name
           in tuple(
               filter(
                   lambda tup: tup[0] == "activation_batch",
@@ -1092,10 +1102,9 @@ def gmm(
     )
     def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
       batch_size, sequence_length, _ = x.shape
-      expert_axis_name = "expert"
       num_expert_parallelism = self.get_expert_parallelism_size()
       if num_expert_parallelism > 1:
-        expert_shard_id = jax.lax.axis_index(expert_axis_name)
+        expert_shard_id = jax.lax.axis_index(self._expert_parallelism_name)
       else:
         expert_shard_id = 0
       num_expert_parallelism = self.get_expert_parallelism_size()
@@ -1105,7 +1114,8 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
 
       # Duplicate inputs to all expert shards.
       x, logits, pre_bias_logits = tuple(
-          jax.lax.all_gather(z, axis_name=expert_axis_name, tiled=True) for z in (x, logits, pre_bias_logits)
+          jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True)
+          for z in (x, logits, pre_bias_logits)
       )
 
       # "Route" tokens within each shard.
@@ -1129,7 +1139,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
       )
 
       if num_expert_parallelism > 1:
-        batch_axis = "expert" if is_batch_sharded_by_expert else "data"
+        batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data"
         # get group sizes for all shards
         local_expert_size = self.config.num_experts // num_expert_parallelism
         reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
@@ -1161,9 +1171,9 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
             send_sizes,
             output_offsets,
             recv_sizes,
-            axis_name=expert_axis_name,
+            axis_name=self._expert_parallelism_name,
         )
-        global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=expert_axis_name)
+        global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name)
         x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute(
             x,
             global_group_sizes,
@@ -1308,7 +1318,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
 
           # Sum up the partial outputs across the expert shards.
           output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim))
-          output = jax.lax.psum_scatter(output, expert_axis_name, scatter_dimension=0, tiled=True)
+          output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)
 
       else:
         if num_expert_parallelism > 1:
@@ -1341,7 +1351,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
               send_sizes,
               output_offsets,
               recv_sizes,
-              axis_name=expert_axis_name,
+              axis_name=self._expert_parallelism_name,
           )
         else:
           # If bach is replicated across EP shards then each shard should send
@@ -1361,7 +1371,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
               send_sizes,
               output_offsets,
               recv_sizes,
-              axis_name=expert_axis_name,
+              axis_name=self._expert_parallelism_name,
           )
 
       output = self.unpermute(
diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py
index 00752aa3c7..0b66a7763b 100644
--- a/src/MaxText/rl/train_rl.py
+++ b/src/MaxText/rl/train_rl.py
@@ -236,6 +236,7 @@ def get_rollout_kwargs_for_data_parallelism(sampler_config, num_sampler_devices)
   rollout_kwargs = {}
 
   tp = sampler_config.rollout_tensor_parallelism
+  ep = sampler_config.rollout_expert_parallelism
 
   if tp == -1:
     if num_sampler_devices % dp != 0:
@@ -244,17 +245,22 @@ def get_rollout_kwargs_for_data_parallelism(sampler_config, num_sampler_devices)
          f"rollout_data_parallelism({dp}) "
          f"when rollout_tensor_parallelism is -1."
      )
-    tp = num_sampler_devices // dp
-  elif tp * dp != num_sampler_devices:
+    tp = num_sampler_devices // dp // ep
+  elif tp * dp * ep != num_sampler_devices:
     raise ValueError(
         f"rollout_tensor_parallelism({tp}) * "
-        f"rollout_data_parallelism({dp}) "
+        f"rollout_data_parallelism({dp}) * "
+        f"rollout_expert_parallelism({ep}) "
         f"!= len(sampler_devices)({num_sampler_devices})"
     )
 
   rollout_kwargs["tensor_parallel_size"] = tp
   rollout_kwargs["data_parallel_size"] = dp
-  rollout_kwargs["rollout_vllm_async_scheduling"] = True
+  if ep > 1:
+    rollout_kwargs["expert_parallel_size"] = ep
+    rollout_kwargs["rollout_vllm_enable_expert_parallelism"] = True
+
+  rollout_kwargs["rollout_vllm_async_scheduling"] = True
 
   return rollout_kwargs
 
@@ -321,10 +327,7 @@ def _filter_long_prompts(x):
   train_dataset = train_dataset[:dataset_size]
   train_dataset = train_dataset.repeat(trainer_config.num_epoch)
 
-  train_dataset = (
-      train_dataset.to_iter_dataset()
-      .batch(trainer_config.batch_size)
-  )
+  train_dataset = train_dataset.to_iter_dataset().batch(trainer_config.batch_size)
 
   eval_dataset_name = getattr(trainer_config, "eval_dataset_name", None)
   if not eval_dataset_name:
@@ -342,10 +345,7 @@ def _filter_long_prompts(x):
   test_dataset = test_dataset.filter(_filter_long_prompts)
   test_dataset = test_dataset[: trainer_config.num_test_batches * trainer_config.batch_size]
 
-  test_dataset = (
-      test_dataset.to_iter_dataset()
-      .batch(trainer_config.batch_size)
-  )
+  test_dataset = test_dataset.to_iter_dataset().batch(trainer_config.batch_size)
 
   # Load reference model
   max_logging.log("Creating reference model and also meshes for reference and rollout")
diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index 632f3345b7..e71b241550 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -1083,6 +1083,8 @@ use_jax_splash: false
 # vLLM Adapter Configurations
 # Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter)
 vllm_hf_config_path: ""
+# Path to YAML file for loading the vLLM config
+vllm_config_path: ""
 # JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
 vllm_additional_config: {}
 # When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]
diff --git a/src/maxtext/configs/inference/vllm.yml b/src/maxtext/configs/inference/vllm.yml
index 21ca47410e..30af0631ca 100644
--- a/src/maxtext/configs/inference/vllm.yml
+++ b/src/maxtext/configs/inference/vllm.yml
@@ -25,7 +25,7 @@ weight_dtype: bfloat16
 
 # -------------- Logical Axis Rules --------------
 
-mesh_axes: ['data', 'attn_dp', 'model', 'expert']
+mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']
 logical_axis_rules: [
   ['activation_batch', ['expert']],
   ['activation_batch_no_exp', []],
@@ -37,37 +37,38 @@ logical_axis_rules: [
   ['activation_attn_length_no_exp', []],
   ['activation_length', ['data', 'expert']],
   ['activation_length_no_exp', 'data'],
-  ['activation_q_length', ['expert']],
+  ['activation_q_length', ['expert', 'attn_dp_expert']],
   ['activation_attn_embed', 'model'],
   ['activation_embed', ['model', 'attn_dp']],
   ['activation_mlp', ['model', 'attn_dp']],
   ['activation_kv', ['model']],
-  ['activation_prefill_kv_batch', ['expert']],
-  ['activation_kv_batch', ['expert']],
+  ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']],
+  ['activation_kv_batch', ['expert', 'attn_dp_expert']],
   ['activation_kv_batch_no_exp', []],
   ['activation_kv_head_dim', ['model']],
   ['activation_vocab', ['model', 'attn_dp']],
   ['activation_norm_length', []],
-  ['activation_exp', ['expert']],
-  ['decode_batch', ['expert']],
+  ['activation_exp', ['expert', 'attn_dp_expert']],
+  ['decode_batch', ['expert', 'attn_dp_expert']],
   ['decode_length', []],
   ['mlp', ['model', 'attn_dp']],
   ['mlp_no_fsdp', ['model', 'attn_dp']],
+  ['moe_mlp', ['model', 'attn_dp']],
   ['vocab', ['model', 'attn_dp']],
   ['heads', ['model']],
   ['q_heads', ['model']],
   ['kv_heads', ['model']],
   ['kv_head_dim', []],
   ['kv', []],
-  ['embed', ['expert']],
+  ['embed', ['expert', 'attn_dp_expert']],
   ['embed_tensor_transpose', ['attn_dp', 'model']],
   ['embed_no_exp', []],
-  ['q_lora', ['expert']],
-  ['kv_lora', ['expert']],
+  ['q_lora', ['expert', 'attn_dp_expert']],
+  ['kv_lora', ['expert', 'attn_dp_expert']],
   ['norm', []],
   ['cache_heads', ['model']],
-  ['exp', ['expert']],
+  ['exp', ['expert', 'attn_dp_expert']],
   ['paged_kv_heads', ['model']],
 ]
-data_sharding: [['data', 'attn_dp', 'model', 'expert']]
+data_sharding: [['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert']]
 input_data_sharding_logical_axes: ['activation_embed_and_logits_batch']
diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml
index 9d741e7a8c..a32fff2376 100644
--- a/src/maxtext/configs/post_train/rl.yml
+++ b/src/maxtext/configs/post_train/rl.yml
@@ -149,6 +149,8 @@ enable_dp_attention: False
 # Performance tuning for samplers
 max_num_batched_tokens: null
 max_num_seqs: null
+# Path to the YAML file used to initialize the vLLM config
+vllm_config_path: 'src/maxtext/configs/inference/vllm.yml'
 
 # ====== Checkpoint Configuration ======
 enable_checkpointing: True
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index e651293a19..f4ef37a099 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -248,7 +248,7 @@ class ProfilerType(str, Enum):
     "llama4-17b-16e",
     "llama4-17b-128e",
     "olmo3-7b",
-    'olmo3-7b-pt',
+    "olmo3-7b-pt",
     "olmo3-32b",
 ]
 
@@ -1544,6 +1544,7 @@ class RLHardware(BaseModel):
   rollout_tensor_parallelism: int = Field(
       -1, description="Tensor parallelism per replica for rollout. If not specified, it will be auto-determined."
   )
+  rollout_expert_parallelism: int = Field(1, description="Expert parallelism per replica for rollout.")
 
 
 class VLLM(BaseModel):
@@ -1557,6 +1558,9 @@ class VLLM(BaseModel):
   max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")
   vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
   vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")
+  vllm_config_path: str = Field(
+      "src/maxtext/configs/inference/vllm.yml", description="Path to YAML file for loading the vLLM config."
+  )
 
 
 class RL(BaseModel):
@@ -2447,6 +2451,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         "expert": self.ici_expert_parallelism,
         "autoregressive": self.ici_autoregressive_parallelism,
         "attn_dp": 1,  # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
+        "attn_dp_expert": 1,  # initialized to 1, vLLM will auto calculate this value based on EP
     }
     self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]
 
@@ -2466,6 +2471,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         "expert": self.dcn_expert_parallelism,
         "autoregressive": self.dcn_autoregressive_parallelism,
         "attn_dp": 1,  # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
+        "attn_dp_expert": 1,  # initialized to 1, vLLM will auto calculate this value based on EP
     }
     self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]
 
diff --git a/src/maxtext/vllm_decode.py b/src/maxtext/vllm_decode.py
index 2e532d63af..00f023943b 100644
--- a/src/maxtext/vllm_decode.py
+++ b/src/maxtext/vllm_decode.py
@@ -80,15 +80,21 @@
 flags.DEFINE_integer("max_prefill_length", 512, "Maximum prefill length.")
 flags.DEFINE_float("gpu_memory_utilization", 0.72, "Fraction of GPU memory to be used for the model executor.")
 
+# vLLM config variables
+flags.DEFINE_integer("vllm_swap_space", 2, "Per-device swap space in GB.")
+flags.DEFINE_integer("vllm_async_scheduling", 1, "Enable the async DP scheduler in vLLM.")
+
 # Decoding
 flags.DEFINE_bool("use_tunix", False, "Whether to use Tunix for vLLM decoding.")
 flags.DEFINE_string("prompt", "Suggest some famous landmarks in London.", "The prompt to decode.")
 flags.DEFINE_integer("decode_sampling_temperature", 0, "Temperature for sampling.")
 flags.DEFINE_integer("decode_sampling_nucleus_p", 1, "Nucleus sampling probability.")
 flags.DEFINE_integer("decode_sampling_top_k", 1, "Top-k sampling probability.")
-
-# Mark required flags
-flags.mark_flag_as_required("hf_config_path")
+flags.DEFINE_string(
+    "vllm_config_path",
+    "src/MaxText/configs/vllm.yml",
+    "Path to vLLM config file. Defaults to MAXTEXT_PKG_DIR/configs/vllm.yml.",
+)
 
 
 def decode_with_vllm(
@@ -103,6 +109,8 @@
     max_prefill_length: int,
     max_target_length: int,
     gpu_memory_utilization: float,
+    vllm_swap_space: int,
+    vllm_async_scheduling: int,
     enable_expert_parallel: bool,
     prompt: str,
     decode_sampling_temperature: float,
@@ -145,6 +153,8 @@
   vllm_args["enable_expert_parallel"] = enable_expert_parallel
   vllm_args["hf_config_path"] = hf_config_path
   vllm_args["gpu_memory_utilization"] = gpu_memory_utilization
+  vllm_args["swap_space"] = vllm_swap_space
+  vllm_args["async_scheduling"] = vllm_async_scheduling
 
   # Prepare MaxText and sharding configs (Parallelism is dynamic)
   vllm_args["additional_config"]["maxtext_config"] = {
@@ -291,12 +301,15 @@
       max_target_length=FLAGS.max_target_length,
       max_prefill_length=FLAGS.max_prefill_length,
       gpu_memory_utilization=FLAGS.gpu_memory_utilization,
+      vllm_swap_space=FLAGS.vllm_swap_space,
+      vllm_async_scheduling=FLAGS.vllm_async_scheduling,
       enable_expert_parallel=FLAGS.enable_expert_parallel,
       prompt=FLAGS.prompt,
       decode_sampling_temperature=FLAGS.decode_sampling_temperature,
       decode_sampling_nucleus_p=FLAGS.decode_sampling_nucleus_p,
      decode_sampling_top_k=FLAGS.decode_sampling_top_k,
       debug_sharding=FLAGS.debug_sharding,
+      vllm_config_path=FLAGS.vllm_config_path,
   )
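
A minimal sketch (not part of the patch) of the tuple-aware expert-parallel size lookup that RoutedMoE.get_expert_parallelism_size() now performs: the size is the product of the mesh-axis sizes behind self._expert_parallelism_name, whether that name is a single axis or a tuple of axes. The standalone helper name and the single-device mesh below are hypothetical and only for illustration; the axis names come from configs/inference/vllm.yml.

import numpy as np
import jax
from jax.sharding import Mesh


def expert_parallelism_size(mesh: Mesh, expert_axis_name) -> int:
  """Product of the mesh-axis sizes named by expert_axis_name (a str or a tuple of str)."""
  if isinstance(expert_axis_name, tuple):
    size = 1
    for axis in expert_axis_name:
      size *= mesh.shape.get(axis, 1)  # axes missing from the mesh contribute a factor of 1
    return size
  return mesh.shape.get(expert_axis_name, 1)


if __name__ == "__main__":
  # Hypothetical one-device mesh over the vLLM mesh axes, purely for illustration.
  devices = np.array(jax.devices()[:1]).reshape(1, 1, 1, 1, 1)
  mesh = Mesh(devices, ("data", "attn_dp", "model", "expert", "attn_dp_expert"))
  print(expert_parallelism_size(mesh, "expert"))
  print(expert_parallelism_size(mesh, ("expert", "attn_dp_expert")))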