Multislice and elastic training on TPUs using Ray Train on GKE

This tutorial shows you how to train large language models (LLMs) like Llama 3 70B on Google Kubernetes Engine (GKE) using MaxText, Ray Train, and Multislice Trillium TPUs. This tutorial provides a complete, end-to-end walkthrough, from configuring the necessary secondary data center networking to submitting and successfully running a distributed training workload across 32 physical TPU chips.

This tutorial is for Platform admins, operators, and AI specialists who want to learn how to overcome the memory and networking challenges of training 70-billion-parameter models on distributed, multi-host TPU slices.

Background

The combination of GKE, KubeRay, MaxText, and TPUs provides a powerful and scalable platform for large-scale model training. This section describes the key technologies used in this guide:

JAX

JAX is a Python library for accelerator-oriented array computation and program transformation, utilizing the XLA compiler to create highly optimized code that scales efficiently on accelerators.

MaxText

MaxText is a high-performance, open-source LLM framework designed for scalability and customizability. MaxText is built on top of JAX and is optimized to run efficiently on Cloud TPUs.

TPUs

Tensor Processing Units (TPUs) are custom-designed accelerators created by Google to optimize machine learning workloads. Unlike general-purpose CPUs or parallel-processing GPUs, TPUs are highly specialized for the massive matrix and tensor computations at the foundation of deep learning, making them efficient at this specific task. The primary advantage of TPUs is performance at scale.

This tutorial uses TPU Trillium, the sixth generation of TPUs, in a Multislice deployment pattern. In a Cloud TPU Multislice deployment, two or more Cloud TPU slices communicate over the data center network (DCN). Multislice enables full-stack, cost-effective, large-scale training with near-linear scaling up to tens of thousands of TPU chips. For more information about Multislice, see Cloud TPU Multislice Overview.

KubeRay

KubeRay is a Kubernetes operator that provides a unified way to deploy, manage, and monitor Ray applications on Kubernetes. The KubeRay operator is installed and managed through the Ray on GKE add-on, which is the recommended way to deploy and manage Ray clusters on GKE.

GKE Dynamic Resource Allocation Network (DRANET)

GKE DRANET (Dynamic Resource Allocation Network) is a feature that dynamically attaches high-performance network devices to Pods, bypassing standard Kubernetes networking and enabling high performance over the DCN.

Objectives

This tutorial shows you how to do the following:

  1. Set up a GKE cluster with two multi-host TPU node pools.
  2. Configure a secondary DCN for cross-slice TPU communication.
  3. Configure KubeRay to manage the distributed training environment.
  4. Deploy a RayCluster custom resource by using Dynamic Resource Allocation (DRA) for network attachments.
  5. Create a Python training script by utilizing Ray Train's JaxTrainer to orchestrate the MaxText training loop across the TPU slices.
  6. Run a baseline Llama 3 8B training job.
  7. Scale up to Llama 3 70B utilizing 2D sharding (Tensor Parallelism and FSDP) over the DCN.

Before you begin

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • Install the Google Cloud CLI.

  • If you're using an external identity provider (IdP), you must first sign in to the gcloud CLI with your federated identity.

  • To initialize the gcloud CLI, run the following command:

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required APIs:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Grant roles to your user account. Run the following command once for each of the following IAM roles: roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editor

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Replace the following:

    • PROJECT_ID: Your project ID.
    • USER_IDENTIFIER: The identifier for your user account. For example, myemail@example.com.
    • ROLE: The IAM role that you grant to your user account.
  • Because this tutorial utilizes TPU Trillium (v6e), select a region or zone with availability. For more information, see Cloud TPU quotas.

Prepare your environment

In this tutorial, you use Cloud Shell. Cloud Shell comes preinstalled with the gcloud, helm, and kubectl command-line tools that are used in this tutorial.

  1. Go to the Google Cloud console.

  2. At the top of the Google Cloud console window, click the Activate Cloud Shell button.

    A Cloud Shell session opens inside a new frame in the Google Cloud console and displays a command-line prompt.

  3. In your terminal, clone the kubernetes-engine-samples repository:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    
  4. Change to the directory containing the sample files:

    cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtext
    
  5. Create and activate a Python virtual environment:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  6. Install the Ray CLI:

    pip install "ray[default]==2.55.0"
    
  7. Set the following environment variables:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export CLUSTER_VERSION=1.35.2-gke.1842000
    

    Replace the following:

    • GS_BUCKET: the name of the Cloud Storage bucket.
    • KSA_NAME: the name of the Kubernetes Service Account.
    • CLUSTER_NAME: the name of the new cluster.
    • REGION: the region where your TPU Trillium capacity is available.
    • ZONE: the zone where your TPU Trillium capacity is available. For more information, see TPU availability in GKE.

Configure cluster networking for Cloud TPU Multislice

Within a multi-host TPU slice, TPU devices communicate over the high-speed inter-chip interconnects. However, when running Multislice jobs, the TPU slices must communicate with each other over the DCN. Standard Kubernetes Pod networks can bottleneck this traffic. The ct6e-standard-4t machine type is backed by multiple physical network interface cards (NICs). To achieve the best performance, you create two additional VPC networks and use GKE DRANET to connect them directly to the Ray Pods.

  1. Create the two additional VPC networks with a large maximum transmission unit (MTU):

    gcloud compute networks create ${CLUSTER_NAME}-net-1 \
      --subnet-mode=custom \
      --mtu=8896
    
    gcloud compute networks create ${CLUSTER_NAME}-net-2 \
      --subnet-mode=custom \
      --mtu=8896
    
  2. Create the dedicated subnets:

    gcloud compute networks subnets create tpu-subnet-1 \
      --network=${CLUSTER_NAME}-net-1 \
      --region=${REGION} \
      --range=10.50.0.0/16
    
    gcloud compute networks subnets create tpu-subnet-2 \
      --network=${CLUSTER_NAME}-net-2 \
      --region=${REGION} \
      --range=10.60.0.0/16
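
Before attaching both networks to the same TPU nodes, it can help to confirm that the two subnet ranges don't overlap, because each node receives one interface per network. The following is a minimal sketch that uses Python's standard ipaddress module. Only the CIDR ranges come from the preceding commands; the variable names are illustrative.

```python
import ipaddress

# CIDR ranges copied from the two `subnets create` commands above.
subnet_1 = ipaddress.ip_network("10.50.0.0/16")
subnet_2 = ipaddress.ip_network("10.60.0.0/16")

# Each TPU node gets one interface in each subnet, so the ranges
# should be disjoint to keep the address plan unambiguous.
ranges_overlap = subnet_1.overlaps(subnet_2)

# A /16 range leaves ample room for the eight TPU hosts in this tutorial.
addresses_per_subnet = subnet_1.num_addresses

print(f"ranges overlap: {ranges_overlap}")
print(f"addresses per subnet: {addresses_per_subnet}")
```

If you change either range, rerun the check before you create the node pools.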
    

Create a GKE cluster

You can configure KubeRay on TPUs in a GKE Autopilot or Standard cluster. We recommend that you use an Autopilot cluster for a fully managed Kubernetes experience. To choose the GKE mode of operation that's the best fit for your workloads, see About GKE modes of operation.

To use GKE managed DRANET, your cluster must use version 1.35.2-gke.1842000 or later for Autopilot mode, or 1.34.1-gke.1829001 or later for Standard mode. This tutorial uses version 1.35.2-gke.1842000.

Autopilot

  1. In Cloud Shell, run the following command:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION \
        --cluster-version=${CLUSTER_VERSION}
    
  2. To communicate with your cluster, configure kubectl:

    gcloud container clusters get-credentials $CLUSTER_NAME \
        --location=$REGION
    

Standard

  1. In Cloud Shell, create a Standard cluster that enables the Ray operator add-on by running the following command:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator,GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --enable-dataplane-v2 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE \
        --cluster-version=${CLUSTER_VERSION}
    

    This command also enables the GcsFuseCsiDriver, which allows Pods to mount Cloud Storage buckets as local file systems. The cluster creation might take several minutes.

  2. To communicate with your cluster, configure kubectl:

    gcloud container clusters get-credentials $CLUSTER_NAME \
        --location=$ZONE
    
  3. Create the first multi-host TPU slice node pool with GKE DRANET enabled:

    gcloud container node-pools create v6e-16-0 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    
  4. Create the second TPU slice node pool:

    gcloud container node-pools create v6e-16-1 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    

GKE provisions two node pools, each consisting of four TPU Trillium (v6e) VMs that are configured together as a multi-host TPU slice with a 4x4 topology. These node pools are ready for distributed training workloads.

The Ray operator-enabled GKE cluster automatically installs KubeRay and the KubeRay TPU webhook in your cluster.

Configure a Cloud Storage bucket and a service account

  1. Create a Cloud Storage bucket for shared checkpoints between the multi-host TPU nodes.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. To enable access to the Cloud Storage bucket, create a Kubernetes Service Account:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. To enable access to the Cloud Storage bucket, add the required IAM policy bindings to the service account:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
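
The --member value in the preceding command follows a fixed pattern built from the project number, project ID, namespace, and Kubernetes Service Account name. As a rough illustration of how those parts fit together, the following sketch assembles the same string in Python. The helper name workload_identity_principal and the example values are hypothetical; only the string layout comes from the command above.

```python
def workload_identity_principal(
    project_number: str, project_id: str, namespace: str, ksa_name: str
) -> str:
    """Build the principal identifier for a Kubernetes Service Account.

    Mirrors the layout of the --member flag in the preceding command.
    """
    pool = f"{project_id}.svc.id.goog"
    return (
        "principal://iam.googleapis.com/"
        f"projects/{project_number}/locations/global/"
        f"workloadIdentityPools/{pool}/"
        f"subject/ns/{namespace}/sa/{ksa_name}"
    )


# Hypothetical example values; substitute the values for your own project.
principal = workload_identity_principal(
    project_number="123456789012",
    project_id="my-project",
    namespace="default",
    ksa_name="ray-ksa",
)
print(principal)
```

Note that the identifier uses the project number in the projects/ segment and the project ID in the workload identity pool name.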
    

Create the training script

The maxtext_multi_slice_trainer.py script uses Ray Train's JaxTrainer to run a distributed MaxText training job across two TPU slices. The script configures the training environment for eight TPU workers, one per host, and runs the MaxText training job on each worker. The train_loop_per_worker function wraps the MaxText main entry point and uses Ray's distributed scheduler to execute the MaxText trainer across the multi-host TPU slices:

import os
from absl import app
import logging
from typing import Sequence
import ray
from ray.train.v2.api.config import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer

def train_loop_per_worker(config):
    import maxtext
    from maxtext.trainers.pre_train.train import main as maxtext_main

    argv = config["argv"]
    maxtext_main(argv)

def main(argv: Sequence[str]):
    # Convert the config file path to an absolute path
    argv = list(argv)
    if len(argv) > 1:
        argv[1] = os.path.abspath(argv[1])

    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": argv},
        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=8,
            topology="4x4",
            accelerator_type="TPU-V6E",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "uv": {
                    # maxtext requires some additional deps
                    "packages": ["maxtext[tpu]==0.2.1"],
                    "uv_pip_install_options": ["--resolution=lowest"]
                },
            },
        ),
    )
    result = trainer.fit()
    logging.info("Training complete!")
    ray.shutdown()

if __name__ == "__main__":
    app.run(main)

The preceding script defines a JaxTrainer instance requesting eight workers and a topology of 4x4. Internally, Ray provisions a SlicePlacementGroup across the two TPU slices and helps ensure that the Ray Train workers run atomically across both slices, with one worker per host.
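
The relationship between the topology, the TPU chips per worker, and the worker count in the ScalingConfig can be checked with simple arithmetic. The following is a minimal sketch, not part of the sample repository; the slice_layout helper is a hypothetical name, and the inputs are the values from the script above.

```python
import math


def slice_layout(topology: str, chips_per_host: int, num_workers: int) -> dict:
    """Derive per-slice and cluster-wide counts from a TPU topology string."""
    # A "4x4" topology has 4 * 4 = 16 chips in one slice.
    chips_per_slice = math.prod(int(dim) for dim in topology.split("x"))
    if chips_per_slice % chips_per_host:
        raise ValueError("chips per slice must be a multiple of chips per host")
    # One Ray Train worker runs on each host.
    hosts_per_slice = chips_per_slice // chips_per_host
    if num_workers % hosts_per_slice:
        raise ValueError("num_workers must cover a whole number of slices")
    num_slices = num_workers // hosts_per_slice
    return {
        "chips_per_slice": chips_per_slice,
        "hosts_per_slice": hosts_per_slice,
        "num_slices": num_slices,
        "total_chips": chips_per_slice * num_slices,
    }


# Values from the ScalingConfig: topology="4x4", {"TPU": 4}, num_workers=8.
layout = slice_layout("4x4", chips_per_host=4, num_workers=8)
print(layout)
```

With these inputs, the result is two slices of four hosts each, which accounts for the 32 TPU chips used in this tutorial.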

Train the model

  1. The ray-cluster.tpu-multi-slice.yaml manifest in the current directory defines the RayCluster custom resource. This manifest includes the DRANET ResourceClaimTemplate to provision the network devices for GKE DRANET and Multislice:

    apiVersion: resource.k8s.io/v1
    kind: ResourceClaimTemplate
    metadata:
      name: two-netdev
    spec:
      spec:
        devices:
          requests:
          - name: req-netdev
            exactly:
              deviceClassName: netdev.google.com
              allocationMode: ExactCount
              count: 2
    ---
    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: rayproject/ray:nightly-py312-tpu
                imagePullPolicy: Always
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
            nodeSelector:
              iam.gke.io/gke-metadata-server-enabled: "true"
      workerGroupSpecs:
        - replicas: 2
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: 
            metrics-export-port: "8082"
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              resourceClaims:
              - name: netdev
                resourceClaimTemplateName: two-netdev
              containers:
                - name: ray-worker
                  image: rayproject/ray:nightly-py312-tpu
                  imagePullPolicy: Always
                  resources:
                    claims:
                    - name: netdev
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: MEGASCALE_NUM_SLICES
                      value: "2"
                    - name: MEGASCALE_PORT
                      value: "9915"
                    - name: JAX_PLATFORMS
                      value: tpu,cpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                    - name: LIBTPU_INIT_ARGS
                      value: "--xla_tpu_scoped_vmem_limit_kib=122880 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --megascale_grpc_interface_prefixes=eth1,eth2,lo"
                  securityContext:
                    privileged: true
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
              nodeSelector:
                iam.gke.io/gke-metadata-server-enabled: "true"
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
    

    The preceding RayCluster spec creates a TPU worker group with two replicas, each containing four workers (numOfHosts: 4), for a total of eight workers. Each worker requests four TPU chips (google.com/tpu: "4"). Each worker is scheduled on a TPU Trillium node (tpu-v6e-slice), and the four workers in a replica share the same colocated multi-host slice. KubeRay scales all four workers in a slice atomically. The required JAX environment variables, as well as Pod Affinities for scheduling, are bootstrapped by GKE through a mutating webhook.

  2. To create the RayCluster, apply the manifest:

    envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -
    
  3. Verify that the cluster is ready and running:

    kubectl get rayclusters maxtext-tpu-cluster
    

    The output should be similar to the following:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY         GPUS   STATUS   AGE
    maxtext-tpu-cluster   8                 8                   72     1579277216Ki   0      ready    2m11s
    
  4. To access the Ray Dashboard through the Ray head service, establish a port-forwarding session:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. Verify that the RayCluster is reachable from your local environment:

    ray list nodes --address http://localhost:8265
    

    The output should be similar to the following:

    ray list nodes --address http://localhost:8265
    2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
    2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads.
    
    ======== List: 2026-04-21 10:20:20.945431 ========
    Stats:
    ------------------------------
    Total: 9
    
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP     IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                   LABELS
    0  4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1  10.68.9.5   False           ALIVE                     10.68.9.5    CPU: 8.0                          ray.io/accelerator-type: TPU-V6E
                                                                                                                                    TPU: 4.0                          ray.io/node-group: tpu-group
                                                                                                                                    accelerator_type:TPU-V6E: 1.0     ray.io/node-id: 4f0e4d742...
                                                                                                                                    memory: 186.265 GiB               ray.io/tpu-pod-type: v6e-16
                                                                                                                                    node:10.68.9.5: 1.0               ray.io/tpu-slice-name: tpu-group-0
                                                                                                                                    object_store_memory: 186.265 GiB  ray.io/tpu-topology: 4x4
                                                                                                                                    tpu-group-0: 1.0                  ray.io/tpu-worker-id: '1'
    ...
    6  ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938  10.68.2.9   True            ALIVE                     10.68.2.9    CPU: 8.0                          ray.io/node-group: headgroup
                                                                                                                                    memory: 16.000 GiB                ray.io/node-id: ce7056807...
                                                                                                                                    node:10.68.2.9: 1.0
                                                                                                                                    node:__internal_head__: 1.0
                                                                                                                                    object_store_memory: 4.765 GiB
    ...
    
  6. Download the base MaxText configuration file. This file is required by the training script to set the model's default hyperparameters:

    curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.yml
    
  7. Submit the JaxTrainer script to the RayCluster and check that the Ray job completes successfully. Use the command for the model that you want to train:

Llama 3 8B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=4 \
      max_target_length=4096 \
      model_name=llama3-8b \
      steps=100 \
      ici_fsdp_parallelism=4 \
      ici_tensor_parallelism=4 \
      run_name=rayjob-multi-slice

Llama 3 70B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=2 \
      max_target_length=4096 \
      model_name=llama3-70b \
      steps=100 \
      ici_tensor_parallelism=4 \
      ici_fsdp_parallelism=4 \
      dcn_fsdp_parallelism=2 \
      dcn_data_parallelism=1 \
      remat_policy=full \
      run_name=rayjob-multi-slice-70b-fsdp

The preceding command submits the Python script, which invokes Ray Train's JaxTrainer, to the RayCluster. The ray job submit command also includes MaxText-specific arguments that are passed to the model configuration.
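
The parallelism arguments in the Llama 3 70B command describe how the device mesh is split within a slice (the ici_ arguments) and across slices over the DCN (the dcn_ arguments). The following sketch checks that the values line up with the hardware, under the assumption that the product of the ici_ values should equal the chips in one slice, that the product of the dcn_ values should equal the number of slices, and that any parallelism axes not listed are left at 1. The dictionary names are illustrative.

```python
import math

# Flags copied from the Llama 3 70B `ray job submit` command.
ici_parallelism = {"ici_tensor_parallelism": 4, "ici_fsdp_parallelism": 4}
dcn_parallelism = {"dcn_fsdp_parallelism": 2, "dcn_data_parallelism": 1}

chips_per_slice = 16  # one 4x4 TPU Trillium slice
num_slices = 2        # matches MEGASCALE_NUM_SLICES in the manifest

# Assumption: axes that are not listed contribute a factor of 1.
ici_product = math.prod(ici_parallelism.values())
dcn_product = math.prod(dcn_parallelism.values())

ici_matches = ici_product == chips_per_slice
dcn_matches = dcn_product == num_slices

print(f"within a slice: {ici_product} devices, matches hardware: {ici_matches}")
print(f"across slices:  {dcn_product} slices, matches hardware: {dcn_matches}")
```

In this configuration, tensor parallelism stays on the inter-chip interconnect, and only FSDP spans the DCN.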

In your terminal, you should see output similar to the following for the Llama 3 70B job:

[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]

------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------
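
The throughput figures in the sample log can be cross-checked against the job arguments. The following sketch assumes that total_weights counts the tokens in one global batch, which the matching value suggests but the log doesn't state, and that all 32 chips participate in every step.

```python
# Arguments from the Llama 3 70B `ray job submit` command.
per_device_batch_size = 2
max_target_length = 4096
num_devices = 32  # two slices of 16 chips

tokens_per_device_step = per_device_batch_size * max_target_length
tokens_per_global_step = tokens_per_device_step * num_devices

# Values read from the "completed step: 99" log line.
step_seconds = 0.693
logged_tokens_per_sec_per_device = 11819.175

# The log prints a rounded step time, so expect a small mismatch.
estimated_tokens_per_sec_per_device = tokens_per_device_step / step_seconds
relative_error = (
    abs(estimated_tokens_per_sec_per_device - logged_tokens_per_sec_per_device)
    / logged_tokens_per_sec_per_device
)

print(f"tokens per global step: {tokens_per_global_step}")
print(f"estimated tokens/s/device: {estimated_tokens_per_sec_per_device:.1f}")
print(f"relative error vs. log: {relative_error:.4%}")
```

The global token count equals the total_weights value in the log (262144), and the estimated per-device throughput agrees with the logged figure to within the rounding of the step time.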

Run Multislice elastic training on Spot VMs

When using highly sought-after accelerators like TPUs, utilizing Spot VMs might significantly reduce costs. However, Spot VMs may be preempted unexpectedly.

Ray Train supports elastic training, which allows your job to dynamically scale the number of participating TPU slices up or down without failing. If a slice is preempted, Ray pauses the training loop, waits for the remaining workers to reorganize, restores from the latest MaxText checkpoint, and resumes training on the smaller footprint.

To enable elastic training, change the num_workers parameter in your ScalingConfig from a static integer to a tuple representing (minimum_workers, maximum_workers). Additionally, add a FailureConfig(max_failures=3) to the RunConfig, which instructs Ray Train to retry the training loop up to 3 times instead of failing the job entirely when a worker is preempted.
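
To build intuition for how a (minimum_workers, maximum_workers) range maps to whole TPU slices, consider the following simplified model. It is an illustrative sketch only, not Ray Train's actual resizing logic: the choose_worker_count function is hypothetical, and it assumes that workers are always added or removed one full four-host slice at a time.

```python
def choose_worker_count(
    available_hosts: int,
    min_workers: int,
    max_workers: int,
    workers_per_slice: int = 4,
) -> int:
    """Return the largest whole-slice worker count within [min, max].

    Returns 0 when fewer than min_workers can be placed, which in this
    simplified model means that training waits for more capacity.
    """
    usable_hosts = min(available_hosts, max_workers)
    whole_slices = usable_hosts // workers_per_slice
    candidate = whole_slices * workers_per_slice
    return candidate if candidate >= min_workers else 0


# Simulate full capacity, a partially preempted slice, one slice, and too few hosts.
worker_counts = {
    hosts: choose_worker_count(hosts, min_workers=4, max_workers=8)
    for hosts in (8, 7, 4, 3)
}
for hosts, workers in worker_counts.items():
    print(f"{hosts} hosts available -> {workers} workers")
```

In this model, losing a single host from a slice drops the job from eight workers to four, because a partial slice can't host a 4x4 topology.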

Update the Ray Train script

  1. The maxtext_elastic_trainer.py script in the current directory enables elastic training. Notice that it sets num_workers=(4,8), which tells Ray to proceed if at least one 16-chip slice (four workers) is available, but to scale up to two slices (eight workers) if possible. It includes a FailureConfig to enable elastic training, define the number of retries, and help ensure the job survives preemptions:

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig, FailureConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        import maxtext
        from maxtext.trainers.pre_train.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        # Convert the config file path to an absolute path
        argv = list(argv)
        if len(argv) > 1:
            argv[1] = os.path.abspath(argv[1])
    
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                # This tells Ray to scale the number of workers between 4 and 8 (i.e. 1 to 2 TPU slices).
                num_workers=(4,8),
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                # Define a FailureConfig to enable fault tolerance by automatically restarting failed workers.
                failure_config=FailureConfig(max_failures=3),
                worker_runtime_env={
                    "uv": {
                        # maxtext requires some additional deps
                        "packages": ["maxtext[tpu]==0.2.1"],
                        "uv_pip_install_options": ["--resolution=lowest"]
                    },
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
    
  2. Submit the job by using the Ray Job CLI. Be sure to provide a unique run_name so the checkpoints don't conflict with previous runs.

    ray job submit \
      --address http://localhost:8265 \
      --working-dir . \
      --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
      -- python maxtext_elastic_trainer.py \
          base.yml \
          base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=4 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-elastic-8b
    
  3. To simulate a node termination or preemption during training, delete a Pod.

    kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
    

The terminal logs a worker failure, but the orchestration controller keeps the job alive and automatically resumes from the /data/rayjob-elastic-8b/checkpoints checkpoint after the minimum topology is available.

Because MaxText dynamically recalculates the device mesh upon resumption, you don't need to write any custom logic to handle checkpoint re-sharding when the topology changes. JAX's Orbax checkpointer automatically re-shards the saved weights into the new physical layout before continuing the training loop.

The following output shows the Ray Train controller detecting newly available TPU resources in the cluster and scaling from one slice (four workers) to two slices (eight workers) during training:

...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048   20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

  1. Delete the RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. Delete the GKE cluster. If you created an Autopilot cluster, replace --zone=$ZONE with --location=$REGION:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Delete the Cloud Storage bucket:

    gsutil rm -r gs://${GS_BUCKET}
    

What's next