Multislice- und elastisches Training auf TPUs mit Ray Train in GKE

In dieser Anleitung erfahren Sie, wie Sie Large Language Models (LLMs) wie Llama 3 70B in Google Kubernetes Engine (GKE) mit MaxText, Ray Train und Multislice-Trillium-TPUs trainieren. In dieser Anleitung wird der gesamte Prozess von der Konfiguration der erforderlichen sekundären Rechenzentrumsnetzwerke bis zum Einreichen und erfolgreichen Ausführen einer verteilten Trainingsarbeitslast auf 32 physischen TPU-Chips beschrieben.

Diese Anleitung richtet sich an Plattformadministratoren, Betreiber und KI-Spezialisten, die erfahren möchten, wie sie die Speicher- und Netzwerkherausforderungen beim Trainieren von Modellen mit 70 Milliarden Parametern auf verteilten TPU-Slices mit mehreren Hosts bewältigen können.

Hintergrund

Die Kombination aus GKE, KubeRay, MaxText und TPUs bietet eine leistungsstarke und skalierbare Plattform für das Training großer Modelle. In diesem Abschnitt werden die in diesem Leitfaden verwendeten Schlüsseltechnologien beschrieben:

JAX

JAX ist eine Python-Bibliothek für die beschleunigerorientierte Array-Berechnung und Programmtransformation, die den XLA-Compiler verwendet, um hochoptimierten Code zu erstellen, der effizient auf Beschleunigern skaliert.

MaxText

MaxText ist ein leistungsstarkes Open-Source-LLM-Framework, das auf Skalierbarkeit und Anpassbarkeit ausgelegt ist. MaxText basiert auf JAX und ist für die effiziente Ausführung auf Cloud TPUs optimiert.

TPUs

Tensor Processing Units (TPUs) sind von Google speziell entwickelte Beschleuniger zur Optimierung von Arbeitslasten für maschinelles Lernen. Im Gegensatz zu CPUs für allgemeine Zwecke oder GPUs für die Parallelverarbeitung sind TPUs hochgradig auf die massiven Matrix- und Tensorberechnungen spezialisiert, die die Grundlage von Deep Learning bilden. Dadurch sind sie für diese spezielle Aufgabe effizient. Der Hauptvorteil von TPUs ist die Leistung bei großem Umfang.

In dieser Anleitung wird TPU Trillium, die sechste Generation von TPUs, in einem Multislice-Bereitstellungsmuster verwendet. Bei Cloud TPU-Multislice kommunizieren zwei oder mehr Cloud TPU-Slices über das Rechenzentrumsnetzwerk (Data Center Network, DCN). Multislice ermöglicht ein kostengünstiges Full-Stack-Training mit nahezu linearer Skalierung bis zu Zehntausenden von TPU-Chips. Weitere Informationen zu Multislice finden Sie unter Cloud TPU Multislice – Übersicht.

KubeRay

KubeRay ist ein Kubernetes-Operator, der eine einheitliche Möglichkeit zum Bereitstellen, Verwalten und Überwachen von Ray-Anwendungen in Kubernetes bietet. Der KubeRay-Operator wird über das Ray on GKE-Add-on installiert und verwaltet. Dies ist die empfohlene Methode zum Bereitstellen und Verwalten von Ray-Clustern in GKE.

GKE Dynamic Resource Allocation Network (DRANET)

GKE DRANET (Dynamic Resource Allocation Network) ist eine Funktion, mit der leistungsstarke Netzwerkgeräte dynamisch an Pods angehängt werden. Dabei wird das Standard-Kubernetes-Netzwerk umgangen und eine hohe Leistung über das DCN ermöglicht.

Ziele

In dieser Anleitung wird Folgendes beschrieben:

  1. Richten Sie einen GKE-Cluster mit zwei TPU-Knotenpools mit mehreren Hosts ein.
  2. Konfigurieren Sie ein sekundäres DCN für die TPU-Kommunikation zwischen Slices.
  3. KubeRay so konfigurieren, dass die Umgebung für verteiltes Training verwaltet wird.
  4. Stellen Sie eine benutzerdefinierte RayCluster-Ressource mit Dynamic Resource Allocation (DRA) für Netzwerkverbindungen bereit.
  5. Erstellen Sie ein Python-Trainingsscript, in dem Sie JaxTrainer von Ray Train verwenden, um den MaxText-Trainingsloop über die TPU-Slices hinweg zu orchestrieren.
  6. Führen Sie einen Baseline-Trainingsjob für Llama 3 8B aus.
  7. Skalieren Sie auf Llama 3 70B mit 2D-Sharding (Tensor-Parallelität und FSDP) über das DCN.

Hinweis

  • Melden Sie sich in Ihrem Google Cloud -Konto an. Wenn Sie mit Google Cloudnoch nicht vertraut sind, erstellen Sie ein Konto, um die Leistungsfähigkeit unserer Produkte in der Praxis sehen und bewerten zu können. Neukunden erhalten außerdem ein Guthaben von 300 $, um Arbeitslasten auszuführen, zu testen und bereitzustellen.
  • Installieren Sie die Google Cloud CLI.

  • Wenn Sie einen externen Identitätsanbieter (IdP) verwenden, müssen Sie sich zuerst mit Ihrer föderierten Identität in der gcloud CLI anmelden.

  • Führen Sie den folgenden Befehl aus, um die gcloud CLI zu initialisieren:

    gcloud init
  • Erstellen Sie ein Google Cloud Projekt oder wählen Sie eines aus.

    Rollen, die zum Auswählen oder Erstellen eines Projekts erforderlich sind

    • Projekt auswählen: Für die Auswahl eines Projekts ist keine bestimmte IAM-Rolle erforderlich. Sie können jedes Projekt auswählen, für das Ihnen eine Rolle zugewiesen wurde.
    • Projekt erstellen: Zum Erstellen eines Projekts benötigen Sie die Rolle „Projektersteller“ (roles/resourcemanager.projectCreator), die die Berechtigung resourcemanager.projects.create enthält. Weitere Informationen zum Zuweisen von Rollen
    • So erstellen Sie ein Google Cloud -Projekt:

      gcloud projects create PROJECT_ID

      Ersetzen Sie PROJECT_ID durch einen Namen für das Google Cloud -Projekt, das Sie erstellen.

    • Wählen Sie das von Ihnen erstellte Google Cloud Projekt aus:

      gcloud config set project PROJECT_ID

      Ersetzen Sie PROJECT_ID durch den Namen Ihres Projekts in Google Cloud .

  • Prüfen Sie, ob für Ihr Google Cloud Projekt die Abrechnung aktiviert ist.

  • Aktivieren Sie die erforderlichen APIs:

    Rollen, die zum Aktivieren von APIs erforderlich sind

    Zum Aktivieren von APIs benötigen Sie die IAM-Rolle „Service Usage-Administrator“ (roles/serviceusage.serviceUsageAdmin), die die Berechtigung serviceusage.services.enable enthält. Weitere Informationen zum Zuweisen von Rollen

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Installieren Sie die Google Cloud CLI.

  • Wenn Sie einen externen Identitätsanbieter (IdP) verwenden, müssen Sie sich zuerst mit Ihrer föderierten Identität in der gcloud CLI anmelden.

  • Führen Sie den folgenden Befehl aus, um die gcloud CLI zu initialisieren:

    gcloud init
  • Erstellen Sie ein Google Cloud Projekt oder wählen Sie eines aus.

    Rollen, die zum Auswählen oder Erstellen eines Projekts erforderlich sind

    • Projekt auswählen: Für die Auswahl eines Projekts ist keine bestimmte IAM-Rolle erforderlich. Sie können jedes Projekt auswählen, für das Ihnen eine Rolle zugewiesen wurde.
    • Projekt erstellen: Zum Erstellen eines Projekts benötigen Sie die Rolle „Projektersteller“ (roles/resourcemanager.projectCreator), die die Berechtigung resourcemanager.projects.create enthält. Weitere Informationen zum Zuweisen von Rollen
    • So erstellen Sie ein Google Cloud -Projekt:

      gcloud projects create PROJECT_ID

      Ersetzen Sie PROJECT_ID durch einen Namen für das Google Cloud -Projekt, das Sie erstellen.

    • Wählen Sie das von Ihnen erstellte Google Cloud Projekt aus:

      gcloud config set project PROJECT_ID

      Ersetzen Sie PROJECT_ID durch den Namen Ihres Projekts in Google Cloud .

  • Prüfen Sie, ob für Ihr Google Cloud Projekt die Abrechnung aktiviert ist.

  • Aktivieren Sie die erforderlichen APIs:

    Rollen, die zum Aktivieren von APIs erforderlich sind

    Zum Aktivieren von APIs benötigen Sie die IAM-Rolle „Service Usage-Administrator“ (roles/serviceusage.serviceUsageAdmin), die die Berechtigung serviceusage.services.enable enthält. Weitere Informationen zum Zuweisen von Rollen

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Weisen Sie Ihrem Nutzerkonto Rollen zu. Führen Sie den folgenden Befehl für jede der folgenden IAM-Rollen einmal aus: roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editor

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Ersetzen Sie Folgendes:

    • PROJECT_ID: Ihre Projekt-ID.
    • USER_IDENTIFIER: Die Kennung für Ihr Nutzerkonto . Beispiel: myemail@example.com
    • ROLE: Die IAM-Rolle, die Sie Ihrem Nutzerkonto zuweisen.
  • Da in dieser Anleitung TPU Trillium (v6e) verwendet wird, wählen Sie eine Region oder Zone aus, in der sie verfügbar ist. Weitere Informationen finden Sie unter Cloud TPU-Kontingente.

Umgebung vorbereiten

In dieser Anleitung verwenden Sie Cloud Shell. Die in dieser Anleitung verwendeten Befehlszeilentools gcloud, helm und kubectl sind in Cloud Shell vorinstalliert.

  1. Rufen Sie die Google Cloud Console auf.

  2. Klicken Sie oben im Google Cloud Console-Fenster auf die Schaltfläche Cloud Shell aktivieren Button zum Aktivieren von Cloud Shell.

    In einem neuen Frame in derGoogle Cloud Console wird eine Cloud Shell-Sitzung geöffnet und darin eine Eingabeaufforderung angezeigt.

  3. Klonen Sie in Ihrem Terminal das Repository kubernetes-engine-samples:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    
  4. Wechseln Sie in das Verzeichnis, das die Beispieldateien enthält:

    cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtext
    
  5. Erstellen und aktivieren Sie eine virtuelle Python-Umgebung:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  6. Installieren Sie die Ray-Befehlszeile:

    pip install "ray[default]==2.55.0"
    
  7. Legen Sie die folgenden Umgebungsvariablen fest:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export CLUSTER_VERSION=1.35.2-gke.1842000
    

    Ersetzen Sie Folgendes:

    • GS_BUCKET: der Name des Cloud Storage-Buckets.
    • KSA_NAME: Der Name des Kubernetes-ServiceAccount.
    • CLUSTER_NAME ist der Name des neuen Clusters.
    • REGION: Die Region, in der Ihre TPU-Trillium-Kapazität verfügbar ist.
    • ZONE: Die Zone, in der Ihre TPU Trillium-Kapazität verfügbar ist. Weitere Informationen finden Sie unter TPU-Verfügbarkeit in GKE.

Clusternetzwerk für Cloud TPU Multislice konfigurieren

In einem TPU-Slice mit mehreren Hosts kommunizieren TPU-Geräte über die Hochgeschwindigkeitsverbindungen zwischen den Chips. Bei der Ausführung von Multislice-Jobs müssen die TPU-Slices jedoch über das DCN miteinander kommunizieren. Standardmäßige Kubernetes-Pod-Netzwerke können diesen Traffic verlangsamen. Der Maschinentyp ct6e-standard-4t wird von mehreren physischen Netzwerkkarten (NICs) unterstützt. Um die beste Leistung zu erzielen, erstellen Sie zwei zusätzliche VPC-Netzwerke und verwenden GKE DRANET, um sie direkt mit den Ray-Pods zu verbinden.

  1. Erstellen Sie die beiden zusätzlichen VPC-Netzwerke mit einer großen maximalen Übertragungseinheit (MTU):

    gcloud compute networks create ${CLUSTER_NAME}-net-1 \
      --subnet-mode=custom \
      --mtu=8896
    
    gcloud compute networks create ${CLUSTER_NAME}-net-2 \
      --subnet-mode=custom \
      --mtu=8896
    
  2. Erstellen Sie die dedizierten Subnetze:

    gcloud compute networks subnets create tpu-subnet-1 \
      --network=${CLUSTER_NAME}-net-1 \
      --region=${REGION} \
      --range=10.50.0.0/16
    
    gcloud compute networks subnets create tpu-subnet-2 \
      --network=${CLUSTER_NAME}-net-2 \
      --region=${REGION} \
      --range=10.60.0.0/16
    

GKE-Cluster erstellen

Sie können KubeRay auf TPUs in einem GKE Autopilot- oder Standardcluster konfigurieren. Für eine vollständig verwaltete Kubernetes-Umgebung empfehlen wir die Verwendung eines Autopilot-Clusters. Informationen zum Auswählen des GKE-Betriebsmodus, der für Ihre Arbeitslasten am besten geeignet ist, finden Sie unter GKE-Betriebsmodi.

Wenn Sie GKE Managed DRANET verwenden möchten, muss Ihr Cluster im Autopilot-Modus die Version 1.35.2-gke.1842000 oder höher und im Standardmodus die Version 1.34.1-gke.1829001 oder höher verwenden. In dieser Anleitung wird die Version 1.35.2-gke.1842000 verwendet.

Autopilot

  1. Führen Sie in Cloud Shell den folgenden Befehl aus:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION \
        --cluster-version=${CLUSTER_VERSION}
    
  2. Konfigurieren Sie kubectl für die Kommunikation mit Ihrem Cluster :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$REGION
    

Standard

  1. Erstellen Sie in Cloud Shell einen Standardcluster, in dem das Add-on Ray-Operator aktiviert ist, indem Sie den folgenden Befehl ausführen:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator,GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --enable-dataplane-v2 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE \
        --cluster-version=${CLUSTER_VERSION}
    

    Mit diesem Befehl wird auch GcsFuseCsiDriver aktiviert, sodass Pods Cloud Storage-Buckets als lokale Dateisysteme bereitstellen können. Die Erstellung eines Clusters kann einige Minuten dauern.

  2. Konfigurieren Sie kubectl für die Kommunikation mit Ihrem Cluster:

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    
  3. Erstellen Sie den ersten TPU-Slice-Knotenpool mit mehreren Hosts mit aktiviertem GKE DRANET:

    gcloud container node-pools create v6e-16-0 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    
  4. Erstellen Sie den zweiten TPU-Slice-Knotenpool:

    gcloud container node-pools create v6e-16-1 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    

GKE stellt einen Knotenpool mit vier TPU Trillium-VMs (v6e) bereit, die zusammen als TPU-Slice mit mehreren Hosts mit einer 4x4-Topologie konfiguriert sind. Dieser Knotenpool ist für verteilte Trainingsarbeitslasten bereit.

In einem GKE-Cluster mit aktiviertem Ray-Operator werden KubeRay und der KubeRay-TPU-Webhook automatisch in Ihrem Cluster installiert.

Cloud Storage-Bucket und Dienstkonto konfigurieren

  1. Erstellen Sie einen Cloud Storage-Bucket für freigegebene Prüfpunkte zwischen den TPU-Knoten mit mehreren Hosts.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. So aktivieren Sie den Zugriff auf den Cloud Storage-Bucket:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Fügen Sie dem Dienstkonto die erforderlichen IAM-Richtlinienbindungen hinzu, um den Zugriff auf den Cloud Storage-Bucket zu ermöglichen:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

Trainingsskript erstellen

Im maxtext_multi_slice_trainer.py-Script wird der JaxTrainer von Ray Train verwendet, um einen verteilten MaxText-Trainingsjob auf zwei TPU-Slices auszuführen. Das Skript konfiguriert die Trainingsumgebung für acht TPU-Worker mit mehreren Hosts und führt den MaxText-Trainingsjob auf jedem Worker-Knoten aus. Die Funktion train_loop_per_worker umschließt den MaxText-Haupteinstiegspunkt und verwendet den verteilten Scheduler von Ray, um den MaxText-Trainer auf einem TPU-Slice mit mehreren Hosts auszuführen:

import os
from absl import app
import logging
from typing import Sequence
import ray
from ray.train.v2.api.config import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer

def train_loop_per_worker(config):
    import maxtext
    from maxtext.trainers.pre_train.train import main as maxtext_main

    argv = config["argv"]
    maxtext_main(argv)

def main(argv: Sequence[str]):
    # Convert the config file path to an absolute path
    argv = list(argv)
    if len(argv) > 1:
        argv[1] = os.path.abspath(argv[1])

    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": argv},
        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=8,
            topology="4x4",
            accelerator_type="TPU-V6E",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "uv": {
                    # maxtext requires some additional deps
                    "packages": ["maxtext[tpu]==0.2.1"],
                    "uv_pip_install_options": ["--resolution=lowest"]
                },
            },
        ),
    )
    result = trainer.fit()
    logging.info("Training complete!")
    ray.shutdown()

if __name__ == "__main__":
    app.run(main)

Im vorherigen Skript wird eine JaxTrainer-Instanz definiert, die acht Worker und eine Topologie von 4x4 anfordert. Intern stellt Ray eine SlicePlacementGroup für die beiden TPU-Slices bereit und sorgt dafür, dass die Ray Train-Worker atomar auf beiden Slices ausgeführt werden, mit einem Worker pro Host.

Modell trainieren

  1. Das Manifest ray-cluster.tpu-multi-slice.yaml im aktuellen Verzeichnis definiert die benutzerdefinierte RayCluster-Ressource. Dieses Manifest enthält das DRANET ResourceClaimTemplate, um die Netzwerkgeräte für GKE DRANET und Multislice bereitzustellen:

    apiVersion: resource.k8s.io/v1
    kind: ResourceClaimTemplate
    metadata:
      name: two-netdev
    spec:
      spec:
        devices:
          requests:
          - name: req-netdev
            exactly:
              deviceClassName: netdev.google.com
              allocationMode: ExactCount
              count: 2
    ---
    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: rayproject/ray:nightly-py312-tpu
                imagePullPolicy: Always
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
            nodeSelector:
              iam.gke.io/gke-metadata-server-enabled: "true"
      workerGroupSpecs:
        - replicas: 2
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: 
            metrics-export-port: "8082"
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              resourceClaims:
              - name: netdev
                resourceClaimTemplateName: two-netdev
              containers:
                - name: ray-worker
                  image: rayproject/ray:nightly-py312-tpu
                  imagePullPolicy: Always
                  resources:
                    claims:
                    - name: netdev
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: MEGASCALE_NUM_SLICES
                      value: "2"
                    - name: MEGASCALE_PORT
                      value: "9915"
                    - name: JAX_PLATFORMS
                      value: tpu,cpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                    - name: LIBTPU_INIT_ARGS
                      value: "--xla_tpu_scoped_vmem_limit_kib=122880 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --megascale_grpc_interface_prefixes=eth1,eth2,lo"
                  securityContext:
                    privileged: true
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
              nodeSelector:
                iam.gke.io/gke-metadata-server-enabled: "true"
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
    

    Mit der vorherigen RayCluster-Spezifikation wird eine TPU-Worker-Gruppe mit acht Workern (numOfHosts: 4) pro Replikat und zwei Replikaten erstellt. Jeder Worker fordert vier TPU-Chips (google.com/tpu: "4") an. Die Worker werden jeweils auf einem TPU-Trillium-Knoten (tpu-v6e-slice) geplant, der Teil desselben gemeinsam untergebrachten Multi-Host-Slice ist. KubeRay skaliert alle vier Worker in einem Slice atomar. Die erforderlichen JAX-Umgebungsvariablen sowie die Pod-Affinitäten für die Planung werden von GKE über einen mutierenden Webhook gebootstrapped.

  2. Wenden Sie das Manifest an, um den RayCluster zu erstellen:

    envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -
    
  3. Prüfen Sie, ob der Cluster bereit ist und ausgeführt wird:

    kubectl get rayclusters maxtext-tpu-cluster
    

    Die Ausgabe sollte in etwa so aussehen:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY         GPUS   STATUS   AGE
    maxtext-tpu-cluster   8                 8                   72     1579277216Ki   0      ready    2m11s
    
  4. Wenn Sie über den Ray-Head-Dienst auf das Ray-Dashboard zugreifen möchten, richten Sie eine Portweiterleitungssitzung ein:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. Prüfen Sie, ob der RayCluster von Ihrer lokalen Umgebung aus erreichbar ist:

    ray list nodes --address http://localhost:8265
    

    Die Ausgabe sollte in etwa so aussehen:

    ray list nodes --address http://localhost:8265
    2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
    2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads.
    
    ======== List: 2026-04-21 10:20:20.945431 ========
    Stats:
    ------------------------------
    Total: 9
    
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP     IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                   LABELS
    0  4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1  10.68.9.5   False           ALIVE                     10.68.9.5    CPU: 8.0                          ray.io/accelerator-type: TPU-V6E
                                                                                                                                    TPU: 4.0                          ray.io/node-group: tpu-group
                                                                                                                                    accelerator_type:TPU-V6E: 1.0     ray.io/node-id: 4f0e4d742...
                                                                                                                                    memory: 186.265 GiB               ray.io/tpu-pod-type: v6e-16
                                                                                                                                    node:10.68.9.5: 1.0               ray.io/tpu-slice-name: tpu-group-0
                                                                                                                                    object_store_memory: 186.265 GiB  ray.io/tpu-topology: 4x4
                                                                                                                                    tpu-group-0: 1.0                  ray.io/tpu-worker-id: '1'
    ...
    6  ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938  10.68.2.9   True            ALIVE                     10.68.2.9    CPU: 8.0                          ray.io/node-group: headgroup
                                                                                                                                    memory: 16.000 GiB                ray.io/node-id: ce7056807...
                                                                                                                                    node:10.68.2.9: 1.0
                                                                                                                                    node:__internal_head__: 1.0
                                                                                                                                    object_store_memory: 4.765 GiB
    ...
    
  6. Laden Sie die grundlegende MaxText-Konfigurationsdatei herunter. Diese Datei ist für das Trainingsskript erforderlich, um die Standard-Hyperparameter des Modells festzulegen:

    curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.yml
    
  7. Senden Sie das JaxTrainer-Script an den RayCluster und prüfen Sie, ob der RayJob erfolgreich abgeschlossen wurde:

Llama 3 8B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=4 \
      max_target_length=4096 \
      model_name=llama3-8b \
      steps=100 \
      ici_fsdp_parallelism=4 \
      ici_tensor_parallelism=4 \
      run_name=rayjob-multi-slice

Llama 3 70B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=2 \
      max_target_length=4096 \
      model_name=llama3-70b \
      steps=100 \
      ici_tensor_parallelism=4 \
      ici_fsdp_parallelism=4 \
      dcn_fsdp_parallelism=2 \
      dcn_data_parallelism=1 \
      remat_policy=full \
      run_name=rayjob-multi-slice-70b-fsdp

Mit dem vorherigen Befehl wird das Python-Script, das den JaxTrainer-Ray-Code aufruft, an den Ray-Cluster gesendet. Der Befehl ray job submit enthält einige MaxText-spezifische Argumente, die an die Modellkonfiguration übergeben werden.

Im Terminal sollte für den Llama 3 70B-Job eine ähnliche Ausgabe angezeigt werden:

[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]

------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------

Elastisches Multislice-Training auf Spot-VMs ausführen

Bei der Verwendung von stark nachgefragten Beschleunigern wie TPUs können die Kosten durch die Nutzung von Spot-VMs erheblich gesenkt werden. Spot-VMs können jedoch unerwartet vorzeitig beendet werden.

Ray Train unterstützt elastisches Training, sodass die Anzahl der beteiligten TPU-Slices in Ihrem Job dynamisch skaliert werden kann, ohne dass der Job fehlschlägt. Wenn ein Slice vorzeitig beendet wird, pausiert Ray die Trainingsschleife, wartet, bis die verbleibenden Worker sich neu organisiert haben, stellt den letzten MaxText-Checkpoint wieder her und setzt das Training mit dem kleineren Umfang fort.

Wenn Sie elastisches Training aktivieren möchten, ändern Sie den Parameter num_workers in Ihrem ScalingConfig von einer statischen Ganzzahl in ein Tupel, das (minimum_workers, maximum_workers) darstellt. Fügen Sie außerdem ein FailureConfig(max_failures=3) in die RunConfig ein. Dadurch wird Ray Train angewiesen, den Trainingsdurchlauf bis zu dreimal zu wiederholen, anstatt den Job vollständig zu beenden, wenn ein Worker unterbrochen wird.

Ray Train-Skript aktualisieren

  1. Das maxtext_elastic_trainer.py-Skript im aktuellen Verzeichnis ermöglicht elastisches Training. Beachten Sie, dass num_workers=(4,8) festgelegt wird. Dadurch wird Ray angewiesen, fortzufahren, wenn mindestens ein Slice mit 16 Chips (vier Worker) verfügbar ist, aber nach Möglichkeit auf zwei Slices (acht Worker) zu skalieren. Sie enthält ein FailureConfig, um elastisches Training zu ermöglichen, die Anzahl der Wiederholungen zu definieren und dafür zu sorgen, dass der Job bei Unterbrechungen fortgesetzt wird:

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig, FailureConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        import maxtext
        from maxtext.trainers.pre_train.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        # Convert the config file path to an absolute path
        argv = list(argv)
        if len(argv) > 1:
            argv[1] = os.path.abspath(argv[1])
    
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                # This tells Ray to scale the number of workers between 4 and 8 (i.e. 1 to 2 TPU slices).
                num_workers=(4,8),
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                # Define a FailureConfig to enable fault tolerance by automatically restarting failed workers.
                failure_config=FailureConfig(max_failures=3),
                worker_runtime_env={
                    "uv": {
                        # maxtext requires some additional deps
                        "packages": ["maxtext[tpu]==0.2.1"],
                        "uv_pip_install_options": ["--resolution=lowest"]
                    },
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
    
  2. Senden Sie den Job mit der Ray Job CLI. Achten Sie darauf, dass Sie eine eindeutige run_name angeben, damit die Prüfpunkte nicht mit früheren Läufen in Konflikt geraten.

    ray job submit \
      --address http://localhost:8265 \
      --working-dir . \
      --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
      -- python maxtext_elastic_trainer.py \
          base.yml \
          base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=4 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-elastic-8b
    
  3. Wenn Sie eine Knotenbeendigung oder ein vorzeitiges Beenden während des Trainings simulieren möchten, löschen Sie einen Pod.

    kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
    

Im Terminal wird ein Worker-Fehler protokolliert, aber der Orchestrierungscontroller hält den Job aktiv und setzt ihn automatisch am /data/rayjob-elastic-8b/checkpoints-Checkpoint fort, sobald die Mindesttopologie verfügbar ist.

Da MaxText das Gerätenetz bei der Wiederaufnahme dynamisch neu berechnet, müssen Sie keine benutzerdefinierte Logik schreiben, um das erneute Sharding von Prüfpunkten zu verarbeiten, wenn die Topologie kleiner wird. Der Orbax-Checkpointer von JAX führt automatisch ein Resharding der gespeicherten Gewichte in das neue physische Layout durch, bevor der Trainingszyklus fortgesetzt wird. In der folgenden Ausgabe ist zu sehen, dass der Ray Train-Controller neu verfügbare TPU-Ressourcen im Cluster erkennt und während des Trainings eine Skalierung von einem Slice (vier Worker) auf zwei Slices (acht Worker) durchführt.

...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048   20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8

Bereinigen

Damit Ihrem Google Cloud -Konto die in dieser Anleitung verwendeten Ressourcen nicht in Rechnung gestellt werden, können Sie entweder das Projekt löschen, das die Ressourcen enthält, oder das Projekt beibehalten und die einzelnen Ressourcen löschen.

  1. Löschen Sie den RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. Löschen Sie den GKE-Cluster:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Löschen Sie den Cloud Storage-Bucket:

    gsutil rm -r gs://${GS_BUCKET}
    

Nächste Schritte