Skip to content

Commit 2421a74

Browse files
guptaakacopybara-github
authored andcommitted
Add a script to deploy Pathways service as a JobSet
PiperOrigin-RevId: 896092607
1 parent 2cb53bb commit 2421a74

File tree

2 files changed

+411
-16
lines changed

2 files changed

+411
-16
lines changed
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
"""Deploys Pathways service to a Kubernetes cluster using a JobSet template."""
2+
3+
from collections.abc import Callable, Sequence
4+
import dataclasses
5+
import logging
6+
import math
7+
import os
8+
import string
9+
from typing import Any
10+
from absl import app
11+
from absl import flags
12+
from kubernetes import client
13+
from kubernetes import config
14+
import yaml
15+
16+
_logger = logging.getLogger(__name__)
17+
18+
# Flag definitions
19+
FLAGS = flags.FLAGS
20+
_JOBSET_NAME = flags.DEFINE_string(
21+
"jobset_name", "pathways-service", "Name of the JobSet"
22+
)
23+
_JAX_VERSION = flags.DEFINE_string(
24+
"jax_version", "0.9.0", "JAX version (e.g., 0.9.0)"
25+
)
26+
_SERVER_IMAGE = flags.DEFINE_string(
27+
"server_image", None, "Full path to the server Docker image"
28+
)
29+
_TPU_TYPE = flags.DEFINE_enum(
30+
"tpu_type", "v6e", ["v5e", "v5p", "v6e", "tpu7x"], "TPU type"
31+
)
32+
_TOPOLOGY = flags.DEFINE_string(
33+
"topology", "2x2", "TPU topology (e.g., 4x8, 2x2x2)"
34+
)
35+
_NUM_SLICES = flags.DEFINE_integer(
36+
"num_slices", 2, "Number of TPU slices"
37+
)
38+
_GCS_BUCKET = flags.DEFINE_string(
39+
"gcs_bucket",
40+
"gs://pathways-test-bucket",
41+
"GCS bucket name for scratch space",
42+
)
43+
_TEMPLATE_FILE = flags.DEFINE_string(
44+
"template_file",
45+
os.path.join(
46+
os.path.dirname(__file__), "yamls/pw-service-example.yaml",
47+
),
48+
"Path to the JobSet YAML template file",
49+
)
50+
_DRY_RUN = flags.DEFINE_boolean(
51+
"dry_run",
52+
False,
53+
"If true, only print the generated YAML without deploying.",
54+
)
55+
56+
57+
@dataclasses.dataclass(frozen=True)
58+
class TPUConfig:
59+
"""Holds configuration details for a specific TPU type."""
60+
machine_type: str
61+
chips_per_vm: int
62+
accelerator_label: str
63+
instance_prefix: str
64+
65+
66+
def _validate_topology(topology):
67+
"""Validates the topology flag format."""
68+
try:
69+
dims = topology.split("x")
70+
if not (2 <= len(dims) <= 3):
71+
return False
72+
for dim in dims:
73+
if not dim.isdigit():
74+
return False
75+
if int(dim) <= 0:
76+
return False
77+
return True
78+
except ValueError:
79+
return False
80+
81+
82+
flags.register_validator(
83+
"topology",
84+
_validate_topology,
85+
message=(
86+
"--topology must be in the format like 'AxB' or 'AxBxC', where A, B, C"
87+
" are positive integers."
88+
),
89+
)
90+
91+
92+
def get_tpu_config(tpu_type: str) -> TPUConfig:
93+
"""Returns a TPUConfig object containing TPU configuration details."""
94+
tpu_configs = {
95+
"v5e": TPUConfig(
96+
machine_type="ct5lp-hightpu-4t",
97+
chips_per_vm=4,
98+
accelerator_label="tpu-v5-lite-podslice",
99+
instance_prefix="tpuv5e",
100+
),
101+
"v5p": TPUConfig(
102+
machine_type="ct5p-hightpu-4t",
103+
chips_per_vm=4,
104+
accelerator_label="tpu-v5p-slice",
105+
instance_prefix="tpuv5p",
106+
),
107+
"v6e": TPUConfig(
108+
machine_type="ct6e-standard-4t",
109+
chips_per_vm=4,
110+
accelerator_label="tpu-v6e-slice",
111+
instance_prefix="tpuv6e",
112+
),
113+
"tpu7x": TPUConfig(
114+
machine_type="tpu7x-standard-4t",
115+
chips_per_vm=4,
116+
accelerator_label="tpu-v7-slice",
117+
instance_prefix="tpu7x",
118+
),
119+
}
120+
if tpu_type not in tpu_configs:
121+
raise ValueError(
122+
f"Unsupported TPU type: {tpu_type}. Supported types are:"
123+
f" {list(tpu_configs.keys())}"
124+
)
125+
return tpu_configs[tpu_type]
126+
127+
128+
def calculate_vms_per_slice(topology: str, chips_per_vm: int) -> int:
129+
"""Calculates the number of VMs per slice based on the topology."""
130+
try:
131+
dims = [int(d) for d in topology.split("x")]
132+
total_chips = math.prod(dims)
133+
if total_chips % chips_per_vm != 0:
134+
raise ValueError(
135+
f"Total chips ({total_chips}) in topology {topology} is not divisible"
136+
f" by chips_per_vm ({chips_per_vm})"
137+
)
138+
return total_chips // chips_per_vm
139+
except ValueError as e:
140+
raise ValueError(
141+
f"Invalid topology format: {topology}. Expected format like 'AxB' or"
142+
f" 'AxBxC'. {e}"
143+
) from e
144+
145+
146+
def load_and_substitute_template(
147+
template_path: str, context: dict[str, Any]
148+
) -> dict[str, Any]:
149+
"""Loads and substitutes the string.Template from the given path."""
150+
try:
151+
with open(template_path, "r") as f:
152+
template_str = f.read()
153+
except OSError as err:
154+
raise ValueError(
155+
f"Could not read template file: {template_path}: {err}"
156+
) from err
157+
158+
_logger.info("Template file: %s", template_path)
159+
_logger.info("Context: %s", context)
160+
template = string.Template(template_str)
161+
_logger.info("Template: %s", template)
162+
substituted_yaml = template.substitute(context)
163+
return yaml.safe_load(substituted_yaml)
164+
165+
166+
def deploy_jobset(jobset_yaml: dict[str, Any]) -> None:
167+
"""Deploys the JobSet to the current Kubernetes cluster."""
168+
try:
169+
config.load_kube_config()
170+
api = client.CustomObjectsApi()
171+
api.create_namespaced_custom_object(
172+
group="jobset.x-k8s.io",
173+
version="v1alpha2",
174+
namespace=jobset_yaml["metadata"]["namespace"],
175+
body=jobset_yaml,
176+
plural="jobsets",
177+
)
178+
_logger.info(
179+
"JobSet '%s' created successfully.", jobset_yaml["metadata"]["name"]
180+
)
181+
except client.rest.ApiException:
182+
_logger.exception("Error creating JobSet")
183+
except config.ConfigException:
184+
_logger.exception("Error loading Kubernetes configuration")
185+
186+
187+
def run_deployment(
188+
tpu_type,
189+
topology,
190+
num_slices,
191+
jobset_name,
192+
gcs_bucket,
193+
server_image,
194+
template_file,
195+
dry_run,
196+
deploy_func: Callable[[dict[str, Any]], None] = deploy_jobset,
197+
) -> None:
198+
"""Executes the deployment logic."""
199+
tpu_config = get_tpu_config(tpu_type)
200+
vms_per_slice = calculate_vms_per_slice(topology, tpu_config.chips_per_vm)
201+
202+
context = {
203+
"JOBSET_NAME": jobset_name,
204+
"SERVER_IMAGE": server_image,
205+
"GCS_SCRATCH_LOCATION": gcs_bucket,
206+
"NUM_SLICES": num_slices,
207+
"INSTANCE_TYPE": f"{tpu_config.instance_prefix}:{topology}",
208+
"VMS_PER_SLICE": vms_per_slice,
209+
"CHIPS_PER_VM": tpu_config.chips_per_vm,
210+
"ACCELERATOR_LABEL": tpu_config.accelerator_label,
211+
"TOPOLOGY": topology,
212+
}
213+
214+
jobset_config = load_and_substitute_template(template_file, context)
215+
216+
_logger.info("--- Generated JobSet YAML ---")
217+
_logger.info("\n%s", yaml.dump(jobset_config))
218+
219+
if not dry_run:
220+
_logger.info("Deploying JobSet...")
221+
deploy_func(jobset_config)
222+
else:
223+
_logger.info("Dry run mode, not deploying.")
224+
225+
226+
def main(argv: Sequence[str]) -> None:
227+
if len(argv) > 1:
228+
raise app.UsageError("Too many command-line arguments.")
229+
230+
try:
231+
if (
232+
flags.FLAGS["jax_version"].present
233+
and flags.FLAGS["server_image"].present
234+
):
235+
raise ValueError("Cannot provide both --jax_version and --server_image")
236+
237+
if _SERVER_IMAGE.value:
238+
server_image = _SERVER_IMAGE.value
239+
else:
240+
server_image = f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-{_JAX_VERSION.value}"
241+
242+
run_deployment(
243+
tpu_type=_TPU_TYPE.value,
244+
topology=_TOPOLOGY.value,
245+
num_slices=_NUM_SLICES.value,
246+
jobset_name=_JOBSET_NAME.value,
247+
gcs_bucket=_GCS_BUCKET.value,
248+
server_image=server_image,
249+
template_file=_TEMPLATE_FILE.value,
250+
dry_run=_DRY_RUN.value,
251+
)
252+
except ValueError as e:
253+
_logger.exception("Error: %s", e)
254+
except FileNotFoundError:
255+
_logger.exception(
256+
"Error: Template file not found at %s", _TEMPLATE_FILE.value
257+
)
258+
259+
260+
if __name__ == "__main__":
261+
app.run(main)

0 commit comments

Comments
 (0)