Skip to content

Commit c8bd7ba

Browse files
committed
Update images, add annotation. Update sample config YAML to use Python image.
1 parent 9a9c97b commit c8bd7ba

8 files changed

Lines changed: 58 additions & 96 deletions

File tree

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ The user workload is typically on a Vertex AI notebook, so users can connect to
1919
- docker version 17.03+.
2020
- kubectl version v1.11.3+.
2121
- Access to a Kubernetes v1.11.3+ cluster.
22-
- JobSet //ToDo(roshanin) install JobSet
2322

2423
### To Deploy on the cluster
2524
**Build and push your image to the location specified by `IMG`:**

api/v1/pathwaysjob_types.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,13 @@ type PathwaysJobSpec struct {
6161
// Maximum number of times the JobSet is restarted.
6262
MaxRestarts int32 `json:"maxRestarts,omitempty"`
6363

64-
// PathwaysDir is a persistent location like GCS at which temporary
64+
// PathwaysDir is a persistent GCS location at which temporary
6565
// Pathways artifacts can be stored like HBM state during interruptions.
6666
// Currently, Pathways supports a precreated GCS directory only.
6767
PathwaysDir string `json:"pathwaysDir,omitempty"`
6868

69-
// PathwaysVersion is the version of the Pathways client.
69+
// PathwaysVersion is the version of the Pathways cluster.
70+
// This indicates the version of the Pathways RM, Proxy and Workers.
7071
PathwaysVersion string `json:"pathwaysVersion,omitempty"`
7172

7273
// The list of worker types created for the Pathways Job. Currently only

config/crd/bases/pathways-job.pathways.domain_pathwaysjobs.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8058,12 +8058,14 @@ spec:
80588058
type: integer
80598059
pathwaysDir:
80608060
description: |-
8061-
PathwaysDir is a persistent location like GCS at which temporary
8061+
PathwaysDir is a persistent GCS location at which temporary
80628062
Pathways artifacts can be stored like HBM state during interruptions.
80638063
Currently, Pathways supports a precreated GCS directory only.
80648064
type: string
80658065
pathwaysVersion:
8066-
description: PathwaysVersion is the version of the Pathways client.
8066+
description: |-
8067+
PathwaysVersion is the version of the Pathways cluster.
8068+
This indicates the version of the Pathways RM, Proxy and Workers.
80678069
type: string
80688070
workers:
80698071
description: |-

config/manager/kustomization.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,5 @@ apiVersion: kustomize.config.k8s.io/v1beta1
1818
kind: Kustomization
1919
images:
2020
- name: controller
21-
newName: us-docker.pkg.dev/cloud-tpu-multipod-dev/pathways/pathwaysjob
21+
newName: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/pathwaysjob
2222
newTag: latest

config/samples/jobset_example.yaml

Whitespace-only changes.

config/samples/kustomization.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@
1414

1515
## Append samples of your project ##
1616
resources:
17-
- pathways-api_v1_pathwaysjob.yaml
17+
- pathways-job_v1_pathwaysjob.yaml
1818
# +kubebuilder:scaffold:manifestskustomizesamples

config/samples/pathways-job_v1_pathwaysjob.yaml

Lines changed: 10 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
apiVersion: pathways-job.pathways.domain/v1
1616
kind: PathwaysJob
1717
metadata:
18-
name: pathways-trial62
18+
name: pathways-trial
1919
spec:
2020
maxRestarts: 10
2121
workers:
@@ -37,85 +37,14 @@ spec:
3737
- name: JAX_PLATFORMS
3838
value: proxy
3939
- name: JAX_BACKEND_TARGET
40-
value: grpc://pathways-trial62-proxy-0-0.pathways-trial62:29008
41-
image: us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest
40+
value: grpc://pathways-trial-proxy-0-0.pathways-trial:29008
41+
image: python:3.13
4242
imagePullPolicy: Always
4343
command:
44-
- bash
45-
- -c
46-
- |
47-
(python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=gs://cloud-pathways-staging dataset_path=gs://maxtext-dataset/ steps=10 run_name=roshanin-pathways1 enable_single_controller=true attention=dot_product monitor_goodput=False enable_tensorboard=True enable_checkpointing=False);
48-
volumeMounts:
49-
- mountPath: /tmp
50-
name: shared-tmp
51-
resources:
52-
limits:
53-
cpu: "20"
54-
memory: 90G
55-
56-
57-
58-
# #Pod template for inference, colocate mode.
59-
60-
61-
# deploymentMode: colocate
62-
# template: # UserPodTemplate
63-
# spec:
64-
# containers:
65-
# - name: jetstream
66-
# image: us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest
67-
# imagePullPolicy: Always
68-
# ports:
69-
# - containerPort: 9000
70-
# env:
71-
# - name: XCLOUD_ENVIRONMENT
72-
# value: GCP
73-
# - name: JAX_PLATFORMS
74-
# value: proxy
75-
# - name: JAX_BACKEND_TARGET
76-
# value: grpc://pathways-trial61-leader-0-0.pathways-trial61:29008
77-
# command:
78-
# - bash
79-
# - -c
80-
# - 'echo Start: $(date);
81-
# _sigterm() ( kill -SIGTERM $! 2>/dev/null;);
82-
# trap _sigterm SIGTERM;
83-
# (JAX_TRACEBACK_FILTERING=off python3 MaxText/maxengine_server.py
84-
# MaxText/configs/inference_jetstream.yml tokenizer_path=assets/tokenizer.llama2
85-
# load_parameters_path=gs://runner-maxtext-logs/2024-05-07-23-34/unscanned_chkpt/checkpoints/0/items
86-
# max_prefill_predict_length=1024 max_target_length=2048 async_checkpointing=false
87-
# model_name=''llama2-70b'' steps=1 ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1
88-
# ici_tensor_parallelism=1 scan_layers=false weight_dtype=bfloat16
89-
# per_device_batch_size=2) & PID=$!;
90-
# while kill -0 $PID 2>/dev/null;
91-
# do sleep 5;
92-
# done;
93-
# wait $PID;
94-
# EXIT_CODE=$?
95-
# echo EXIT_CODE=$EXIT_CODE;
96-
# echo End sleep: $(date);
97-
# sleep infinity;'
98-
# - name: tester
99-
# image: us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest
100-
# imagePullPolicy: Always
101-
# command:
102-
# - bash
103-
# - -c
104-
# - 'echo Start: $(date);
105-
# _sigterm() ( kill -SIGTERM $! 2>/dev/null;);
106-
# trap _sigterm SIGTERM;
107-
# for i in {1..5}; do
108-
# echo Sending request $i;
109-
# time python3 JetStream/jetstream/tools/requester.py --tokenizer assets/tokenizer.llama2 --max_tokens=16 --server=0.0.0.0 --text=\"why earth is round\";
110-
# EXIT_CODE=$?;
111-
# echo Completed request;
112-
# echo EXIT_CODE=$EXIT_CODE;
113-
# if [[ $EXIT_CODE -ne 0 ]]; then
114-
# break;
115-
# fi;
116-
# done;
117-
# echo Last EXIT_CODE=$EXIT_CODE;
118-
# echo End sleep: $(date);
119-
# sleep infinity;'
120-
# securityContext:
121-
# privileged: true
44+
- /bin/sh
45+
- -c
46+
- |
47+
pip install --upgrade pip
48+
pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
49+
pip install pathwaysutils
50+
python -c "import jax; import pathwaysutils; print(\"Number of JAX devices is\", len(jax.devices()))"

internal/controller/pathwaysjob_controller.go

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ var (
5858
var WorkerTypeToTPUVersionMap = map[string]string{
5959
"tpu-v6e-slice": "tpuv6e",
6060
"tpu-v5p-slice": "tpuv5",
61-
"tpu-v5-lite-podslice": "tpuv5",
61+
"tpu-v5-lite-podslice": "tpuv5e",
6262
"tpu-v4-podslice": "tpuv4",
6363
}
6464

@@ -278,17 +278,15 @@ func validateTPUTopologyWithWorkerType(ctx context.Context, tpuGKEAcceleratorTyp
278278
// Calculate the number of VMs based on the Topology (- used in completions/parallelisms)
279279
func calculateVMsFromTopology(topology string) int32 {
280280
parts := strings.Split(topology, "x") // Examples - 2x2x4 or 4x4
281-
if len(parts) < 2 {
282-
return 0
283-
}
284281
// Calculate the number of chips based on the Topology.
282+
// The topology must have already been validated with the worker type.
285283
chips := 1
286284
for _, part := range parts {
287285
num, _ := strconv.Atoi(part)
288286
chips *= num
289287
}
290288
vms := 1
291-
chipsperVM := 4
289+
chipsperVM := 4 // ToDo (roshanin): Add support for VMs with 8 chips per host.
292290
if chips >= chipsperVM {
293291
vms = chips / chipsperVM
294292
}
@@ -308,13 +306,24 @@ func calculateTPUInfo(ctx context.Context, pw *pathwaysjob.PathwaysJob) error {
308306
return nil
309307
}
310308

309+
// Construct image tag based on Pathways version
310+
func makeImageTagUsingPathwaysVersion(pw *pathwaysjob.PathwaysJob) string {
311+
var tag string
312+
if pw.Spec.PathwaysVersion != "" {
313+
tag = string(pw.Spec.PathwaysVersion)
314+
} else {
315+
tag = "latest"
316+
}
317+
return tag
318+
}
319+
311320
// Constructs the Pathways resource manager container spec for the underlying JobSet
312321
func MakeResourceManagerContainer(pw *pathwaysjob.PathwaysJob, rmJobName string) (*corev1.Container, error) {
313322
truth := true
314323

315324
rmContainerSpec := corev1.Container{
316325
Name: "pathways-rm",
317-
Image: "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/sanitized_server:latest",
326+
Image: fmt.Sprintf("us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:%s", makeImageTagUsingPathwaysVersion(pw)),
318327
ImagePullPolicy: "Always",
319328
SecurityContext: &corev1.SecurityContext{Privileged: &truth},
320329
Args: []string{
@@ -342,7 +351,7 @@ func MakeProxyContainer(pw *pathwaysjob.PathwaysJob, rmJobName string) (*corev1.
342351

343352
proxyContainerSpec := corev1.Container{
344353
Name: "pathways-proxy",
345-
Image: "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/sanitized_proxy_server:latest",
354+
Image: fmt.Sprintf("us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:%s", makeImageTagUsingPathwaysVersion(pw)),
346355
ImagePullPolicy: "Always",
347356
SecurityContext: &corev1.SecurityContext{Privileged: &truth},
348357
Args: []string{
@@ -360,6 +369,15 @@ func MakeProxyContainer(pw *pathwaysjob.PathwaysJob, rmJobName string) (*corev1.
360369
func MakeWorkerJob(ctx context.Context, pw *pathwaysjob.PathwaysJob, rmJobName string) (jobsetv1alpha2.ReplicatedJob, error) {
361370
truth := true
362371
volumeSourceType := corev1.HostPathDirectoryOrCreate
372+
objectMeta := metav1.ObjectMeta{}
373+
374+
if pw.Spec.Controller.DeploymentMode == pathwaysjob.Default {
375+
objectMeta = metav1.ObjectMeta{
376+
Annotations: map[string]string{
377+
"alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool",
378+
},
379+
}
380+
}
363381

364382
workerJob := jobsetv1alpha2.ReplicatedJob{
365383
Name: "worker",
@@ -370,11 +388,12 @@ func MakeWorkerJob(ctx context.Context, pw *pathwaysjob.PathwaysJob, rmJobName s
370388
Completions: ptr.To(int32(NumVMs)), // number of workers remember to change
371389
Parallelism: ptr.To(int32(NumVMs)), // number of workers remember to change
372390
Template: corev1.PodTemplateSpec{
391+
ObjectMeta: objectMeta,
373392
Spec: corev1.PodSpec{
374393
Containers: []corev1.Container{
375394
{
376395
Name: "pathways-worker",
377-
Image: "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/sanitized_server:latest",
396+
Image: fmt.Sprintf("us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:%s", makeImageTagUsingPathwaysVersion(pw)),
378397
ImagePullPolicy: "Always",
379398
SecurityContext: &corev1.SecurityContext{Privileged: &truth},
380399
Args: []string{
@@ -549,6 +568,10 @@ func MakeJobsForDefaultDeployment(ctx context.Context, pw *pathwaysjob.PathwaysJ
549568
Parallelism: ptr.To(int32(1)),
550569
Template: corev1.PodTemplateSpec{
551570
Spec: corev1.PodSpec{
571+
NodeSelector: map[string]string{ // predictably place RM on CPUs
572+
"cloud.google.com/machine-family": "n2",
573+
"node.kubernetes.io/instance-type": "n2-standard-64",
574+
},
552575
HostNetwork: true, // For performance == McJAX
553576
DNSPolicy: corev1.DNSClusterFirstWithHostNet, // For performance == McJAX
554577
Tolerations: []corev1.Toleration{
@@ -586,6 +609,10 @@ func MakeJobsForDefaultDeployment(ctx context.Context, pw *pathwaysjob.PathwaysJ
586609
Parallelism: ptr.To(int32(1)),
587610
Template: corev1.PodTemplateSpec{
588611
Spec: corev1.PodSpec{
612+
NodeSelector: map[string]string{ // predictably place RM on CPUs
613+
"cloud.google.com/machine-family": "n2",
614+
"node.kubernetes.io/instance-type": "n2-standard-64",
615+
},
589616
HostNetwork: true, // For performance == McJAX
590617
DNSPolicy: corev1.DNSClusterFirstWithHostNet, // For performance == McJAX
591618
Tolerations: []corev1.Toleration{
@@ -629,6 +656,10 @@ func MakeJobsForDefaultDeployment(ctx context.Context, pw *pathwaysjob.PathwaysJ
629656
// },
630657
Template: corev1.PodTemplateSpec{
631658
Spec: corev1.PodSpec{
659+
NodeSelector: map[string]string{ // predictably place RM on CPUs
660+
"cloud.google.com/machine-family": "n2",
661+
"node.kubernetes.io/instance-type": "n2-standard-64",
662+
},
632663
HostNetwork: true, // For performance == McJAX
633664
DNSPolicy: corev1.DNSClusterFirstWithHostNet, // For performance == McJAX
634665
Tolerations: []corev1.Toleration{ // tolerations are important here to not run this job on TPUs

0 commit comments

Comments
 (0)