Addestramento multislice ed elastico sulle TPU utilizzando Ray Train su GKE

Questo tutorial mostra come addestrare modelli linguistici di grandi dimensioni (LLM) come Llama 3 70B su Google Kubernetes Engine (GKE) utilizzando MaxText, Ray Train e TPU Trillium multislice. Questo tutorial fornisce una procedura dettagliata completa ed end-to-end, dalla configurazione del networking del data center secondario necessario all'invio e all'esecuzione corretta di un workload di addestramento distribuito su 32 chip TPU fisici.

Questo tutorial è rivolto ad amministratori della piattaforma, operatori e specialisti di AI che vogliono imparare a superare le sfide di memoria e networking dell'addestramento di modelli con 70 miliardi di parametri su slice TPU multi-host distribuite.

Sfondo

La combinazione di GKE, KubeRay, MaxText e TPU fornisce una piattaforma potente e scalabile per l'addestramento di modelli su larga scala. Questa sezione descrive le tecnologie chiave utilizzate in questa guida:

JAX

JAX è una libreria Python per il calcolo di array e la trasformazione di programmi orientati agli acceleratori, che utilizza il compilatore XLA per creare codice altamente ottimizzato che viene scalato in modo efficiente sugli acceleratori.

MaxText

MaxText è un framework LLM open source ad alte prestazioni progettato per la scalabilità e la personalizzazione. MaxText è basato su JAX ed è ottimizzato per essere eseguito in modo efficiente sulle Cloud TPU.

TPU

Le Tensor Processing Unit (TPU) sono acceleratori progettati su misura e creati da Google per ottimizzare i carichi di lavoro di machine learning. A differenza delle CPU per uso generico o delle GPU per l'elaborazione parallela, le TPU sono altamente specializzate per i calcoli massicci di matrici e tensori alla base del deep learning, il che le rende efficienti in questo compito specifico. Il vantaggio principale delle TPU è il rendimento su larga scala.

Questo tutorial utilizza TPU Trillium, la sesta generazione di TPU, in un pattern di deployment multislice. Cloud TPU Multislice è la soluzione in cui due o più sezioni di Cloud TPU comunicano tramite la rete del data center (DCN). Multislice consente l'addestramento full stack, economico e su larga scala con scalabilità quasi lineare fino a decine di migliaia di chip TPU. Per maggiori informazioni su Multislice, consulta la panoramica di Cloud TPU Multislice.

KubeRay

KubeRay è un operatore Kubernetes che fornisce un modo unificato per eseguire il deployment, gestire e monitorare le applicazioni Ray su Kubernetes. L'operatore KubeRay viene installato e gestito tramite il componente aggiuntivo Ray su GKE, che è il modo consigliato per eseguire il deployment e gestire i cluster Ray su GKE.

GKE Dynamic Resource Allocation Network (DRANET)

GKE DRANET (Dynamic Resource Allocation Network) è una funzionalità che collega dinamicamente dispositivi di rete ad alte prestazioni ai pod, bypassando il networking Kubernetes standard e consentendo prestazioni elevate sulla DCN.

Obiettivi

Questo tutorial mostra gli aspetti seguenti:

  1. Configura un cluster GKE con due node pool TPU multi-host.
  2. Configura una DCN secondaria per la comunicazione TPU tra sezioni.
  3. Configura KubeRay per gestire l'ambiente di addestramento distribuito.
  4. Esegui il deployment di una risorsa personalizzata RayCluster utilizzando l'allocazione dinamica delle risorse (DRA) per i collegamenti di rete.
  5. Crea uno script di addestramento Python utilizzando JaxTrainer di Ray Train per orchestrare il ciclo di addestramento MaxText nelle sezioni TPU.
  6. Esegui un job di addestramento di base di Llama 3 8B.
  7. Aumenta lo scale up a Llama 3 70B utilizzando lo sharding 2D (parallelismo tensoriale e FSDP) sulla DCN.

Prima di iniziare

  • Accedi al tuo account Google Cloud . Se non conosci Google Cloud, crea un account per valutare le prestazioni dei nostri prodotti in scenari reali. I nuovi clienti ricevono anche 300 $di crediti senza costi per l'esecuzione, il test e il deployment dei workload.
  • Installa Google Cloud CLI.

  • Se utilizzi un provider di identità (IdP) esterno, devi prima accedere a gcloud CLI con la tua identità federata.

  • Per inizializzare gcloud CLI, esegui questo comando:

    gcloud init
  • Crea o seleziona un Google Cloud progetto.

    Ruoli richiesti per selezionare o creare un progetto

    • Seleziona un progetto: la selezione di un progetto non richiede un ruolo IAM specifico. Puoi selezionare qualsiasi progetto per il quale ti è stato concesso un ruolo.
    • Crea un progetto: per creare un progetto, devi disporre del ruolo Autore progetto (roles/resourcemanager.projectCreator), che contiene l'autorizzazione resourcemanager.projects.create. Scopri come concedere i ruoli.
    • Creare un progetto Google Cloud :

      gcloud projects create PROJECT_ID

      Sostituisci PROJECT_ID con un nome per il progetto Google Cloud che stai creando.

    • Seleziona il progetto Google Cloud che hai creato:

      gcloud config set project PROJECT_ID

      Sostituisci PROJECT_ID con il nome del progetto Google Cloud .

  • Verifica che la fatturazione sia abilitata per il tuo progetto Google Cloud .

  • Abilita le API richieste:

    Ruoli richiesti per abilitare le API

    Per abilitare le API, devi disporre del ruolo IAM Amministratore utilizzo dei servizi (roles/serviceusage.serviceUsageAdmin), che include l'autorizzazione serviceusage.services.enable. Scopri come concedere i ruoli.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Installa Google Cloud CLI.

  • Se utilizzi un provider di identità (IdP) esterno, devi prima accedere a gcloud CLI con la tua identità federata.

  • Per inizializzare gcloud CLI, esegui questo comando:

    gcloud init
  • Crea o seleziona un Google Cloud progetto.

    Ruoli richiesti per selezionare o creare un progetto

    • Seleziona un progetto: la selezione di un progetto non richiede un ruolo IAM specifico. Puoi selezionare qualsiasi progetto per il quale ti è stato concesso un ruolo.
    • Crea un progetto: per creare un progetto, devi disporre del ruolo Autore progetto (roles/resourcemanager.projectCreator), che contiene l'autorizzazione resourcemanager.projects.create. Scopri come concedere i ruoli.
    • Creare un progetto Google Cloud :

      gcloud projects create PROJECT_ID

      Sostituisci PROJECT_ID con un nome per il progetto Google Cloud che stai creando.

    • Seleziona il progetto Google Cloud che hai creato:

      gcloud config set project PROJECT_ID

      Sostituisci PROJECT_ID con il nome del progetto Google Cloud .

  • Verifica che la fatturazione sia abilitata per il tuo progetto Google Cloud .

  • Abilita le API richieste:

    Ruoli richiesti per abilitare le API

    Per abilitare le API, devi disporre del ruolo IAM Amministratore utilizzo dei servizi (roles/serviceusage.serviceUsageAdmin), che include l'autorizzazione serviceusage.services.enable. Scopri come concedere i ruoli.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Concedi ruoli al tuo account utente. Esegui il seguente comando una volta per ciascuno dei seguenti ruoli IAM: roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editor

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Sostituisci quanto segue:

    • PROJECT_ID: il tuo ID progetto.
    • USER_IDENTIFIER: l'identificatore del tuo account utente . Ad esempio: myemail@example.com.
    • ROLE: il ruolo IAM che concedi al tuo account utente.
  • Poiché questo tutorial utilizza TPU Trillium (v6e), seleziona una regione o una zona con disponibilità. Per maggiori informazioni, consulta la sezione Quote di Cloud TPU.

prepara l'ambiente

In questo tutorial utilizzi Cloud Shell. Cloud Shell viene fornito con gli strumenti a riga di comando gcloud, helm e kubectl utilizzati in questo tutorial.

  1. Vai alla consoleGoogle Cloud .

  2. Nella parte superiore della finestra della console Google Cloud , fai clic sul pulsante Attiva Cloud Shell Pulsante Attiva Cloud Shell.

    All'interno di un nuovo frame nella consoleGoogle Cloud si apre una sessione di Cloud Shell e viene visualizzato un prompt della riga di comando.

  3. Nel terminale, clona il repository kubernetes-engine-samples:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    
  4. Passa alla directory contenente i file di esempio:

    cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtext
    
  5. Crea e attiva un ambiente virtuale Python:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  6. Installa l'interfaccia a riga di comando Ray:

    pip install "ray[default]==2.55.0"
    
  7. Imposta le seguenti variabili di ambiente:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export CLUSTER_VERSION=1.35.2-gke.1842000
    

    Sostituisci quanto segue:

    • GS_BUCKET: il nome del bucket Cloud Storage.
    • KSA_NAME: il nome del service account Kubernetes.
    • CLUSTER_NAME: il nome del nuovo cluster.
    • REGION: la regione in cui è disponibile la capacità TPU Trillium.
    • ZONE: la zona in cui è disponibile la capacità TPU Trillium. Per saperne di più, consulta Disponibilità delle TPU in GKE.

Configura il networking del cluster per Cloud TPU Multislice

All'interno di una sezione TPU multi-host, i dispositivi TPU comunicano tramite le interconnessioni inter-chip ad alta velocità. Tuttavia, quando esegui job multislice, gli slice TPU devono comunicare tra loro tramite la DCN. Le reti di pod Kubernetes standard possono limitare questo traffico. Il tipo di macchina ct6e-standard-4t è supportato da più schede di interfaccia di rete (NIC) fisiche. Per ottenere le migliori prestazioni, crea due reti VPC aggiuntive e utilizza GKE DRANET per connetterle direttamente ai pod Ray.

  1. Crea le due reti VPC aggiuntive con un'unità massima di trasmissione (MTU) di grandi dimensioni:

    gcloud compute networks create ${CLUSTER_NAME}-net-1 \
      --subnet-mode=custom \
      --mtu=8896
    
    gcloud compute networks create ${CLUSTER_NAME}-net-2 \
      --subnet-mode=custom \
      --mtu=8896
    
  2. Crea le subnet dedicate:

    gcloud compute networks subnets create tpu-subnet-1 \
      --network=${CLUSTER_NAME}-net-1 \
      --region=${REGION} \
      --range=10.50.0.0/16
    
    gcloud compute networks subnets create tpu-subnet-2 \
      --network=${CLUSTER_NAME}-net-2 \
      --region=${REGION} \
      --range=10.60.0.0/16
    

Crea un cluster GKE

Puoi configurare KubeRay sulle TPU in un cluster GKE Autopilot o Standard. Ti consigliamo di utilizzare un cluster Autopilot per un'esperienza Kubernetes completamente gestita. Per scegliere la modalità operativa GKE più adatta ai tuoi workload, consulta Informazioni sulle modalità operative di GKE.

Per utilizzare DRANET gestito da GKE, il cluster deve utilizzare la versione 1.35.2-gke.1842000 o successive per la modalità Autopilot oppure la versione 1.34.1-gke.1829001 o successive per la modalità Standard. Questo tutorial utilizza la versione 1.35.2-gke.1842000.

Autopilot

  1. In Cloud Shell, esegui questo comando:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION \
        --cluster-version=${CLUSTER_VERSION}
    
  2. Per comunicare con il cluster, configura kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$REGION
    

Standard

  1. In Cloud Shell, crea un cluster standard che attiva il componente aggiuntivo Operatore Ray eseguendo questo comando:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator,GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --enable-dataplane-v2 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE \
        --cluster-version=${CLUSTER_VERSION}
    

    Questo comando attiva anche GcsFuseCsiDriver, che consente ai pod di montare i bucket Cloud Storage come file system locali. La creazione del cluster potrebbe richiedere diversi minuti.

  2. Per comunicare con il cluster, configura kubectl:

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    
  3. Crea il primo pool di nodi di sezioni TPU multi-host con GKE DRANET abilitato:

    gcloud container node-pools create v6e-16-0 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    
  4. Crea il secondo pool di nodi di sezioni TPU:

    gcloud container node-pools create v6e-16-1 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    

GKE esegue il provisioning di un pool di nodi costituito da quattro VM TPU Trillium (v6e), che vengono configurate insieme come una sezione TPU multi-host con una topologia 4x4. Questo pool di nodi è pronto per i carichi di lavoro di addestramento distribuito.

Il cluster GKE con operatore Ray abilitato installa automaticamente KubeRay e il webhook KubeRay TPU nel cluster.

Configurare un bucket Cloud Storage e un account di servizio

  1. Crea un bucket Cloud Storage per i checkpoint condivisi tra i nodi TPU multi-host.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Per abilitare l'accesso al bucket Cloud Storage, crea un service account Kubernetes:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Per abilitare l'accesso al bucket Cloud Storage, aggiungi i binding dei criteri IAM richiesti all'account di servizio:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

Crea lo script di addestramento

Lo script maxtext_multi_slice_trainer.py utilizza JaxTrainer di Ray Train per eseguire un job di addestramento distribuito MaxText su due slice TPU. Lo script configura l'ambiente di addestramento per otto worker TPU multi-host ed esegue il job di addestramento MaxText su ogni nodo worker. La funzione train_loop_per_worker esegue il wrapping del punto di ingresso principale di MaxText e utilizza lo scheduler distribuito di Ray per eseguire il trainer MaxText su uno slice TPU multihost:

import os
from absl import app
import logging
from typing import Sequence
import ray
from ray.train.v2.api.config import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer

def train_loop_per_worker(config):
    import maxtext
    from maxtext.trainers.pre_train.train import main as maxtext_main

    argv = config["argv"]
    maxtext_main(argv)

def main(argv: Sequence[str]):
    # Convert the config file path to an absolute path
    argv = list(argv)
    if len(argv) > 1:
        argv[1] = os.path.abspath(argv[1])

    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": argv},
        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=8,
            topology="4x4",
            accelerator_type="TPU-V6E",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "uv": {
                    # maxtext requires some additional deps
                    "packages": ["maxtext[tpu]==0.2.1"],
                    "uv_pip_install_options": ["--resolution=lowest"]
                },
            },
        ),
    )
    result = trainer.fit()
    logging.info("Training complete!")
    ray.shutdown()

if __name__ == "__main__":
    app.run(main)

Lo script precedente definisce un'istanza di JaxTrainer che richiede otto worker e una topologia di 4x4. Internamente, Ray esegue il provisioning di un SlicePlacementGroup nelle due sezioni TPU e contribuisce a garantire che i worker Ray Train vengano eseguiti in modo atomico in entrambe le sezioni, con un worker per host.

Addestra il modello

  1. Il manifest ray-cluster.tpu-multi-slice.yaml nella directory corrente definisce la risorsa personalizzata RayCluster. Questo manifest include DRANET ResourceClaimTemplate per eseguire il provisioning dei dispositivi di rete per GKE DRANET e Multislice:

    apiVersion: resource.k8s.io/v1
    kind: ResourceClaimTemplate
    metadata:
      name: two-netdev
    spec:
      spec:
        devices:
          requests:
          - name: req-netdev
            exactly:
              deviceClassName: netdev.google.com
              allocationMode: ExactCount
              count: 2
    ---
    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: rayproject/ray:nightly-py312-tpu
                imagePullPolicy: Always
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
            nodeSelector:
              iam.gke.io/gke-metadata-server-enabled: "true"
      workerGroupSpecs:
        - replicas: 2
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: 
            metrics-export-port: "8082"
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              resourceClaims:
              - name: netdev
                resourceClaimTemplateName: two-netdev
              containers:
                - name: ray-worker
                  image: rayproject/ray:nightly-py312-tpu
                  imagePullPolicy: Always
                  resources:
                    claims:
                    - name: netdev
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: MEGASCALE_NUM_SLICES
                      value: "2"
                    - name: MEGASCALE_PORT
                      value: "9915"
                    - name: JAX_PLATFORMS
                      value: tpu,cpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                    - name: LIBTPU_INIT_ARGS
                      value: "--xla_tpu_scoped_vmem_limit_kib=122880 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --megascale_grpc_interface_prefixes=eth1,eth2,lo"
                  securityContext:
                    privileged: true
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
              nodeSelector:
                iam.gke.io/gke-metadata-server-enabled: "true"
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
    

    La specifica RayCluster precedente crea un gruppo di worker TPU con otto worker (numOfHosts: 4) per replica, con due repliche. Ogni worker richiede quattro chip TPU (google.com/tpu: "4"). I worker sono pianificati su un nodo TPU Trillium (tpu-v6e-slice), che fa parte della stessa slice multihost collocate. KubeRay esegue lo scale di tutti e quattro i worker in una sezione in modo atomico. Le variabili di ambiente JAX richieste, nonché le affinità dei pod per la pianificazione, vengono avviate da GKE tramite un webhook di mutazione.

  2. Per creare RayCluster, applica il manifest:

    envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -
    
  3. Verifica che il cluster sia pronto e in esecuzione:

    kubectl get rayclusters maxtext-tpu-cluster
    

    L'output dovrebbe essere simile al seguente:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY         GPUS   STATUS   AGE
    maxtext-tpu-cluster   8                 8                   72     1579277216Ki   0      ready    2m11s
    
  4. Per accedere alla dashboard Ray tramite il servizio head Ray, stabilisci una sessione di port forwarding:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. Verifica che RayCluster sia raggiungibile dal tuo ambiente locale:

    ray list nodes --address http://localhost:8265
    

    L'output dovrebbe essere simile al seguente:

    ray list nodes --address http://localhost:8265
    2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
    2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads.
    
    ======== List: 2026-04-21 10:20:20.945431 ========
    Stats:
    ------------------------------
    Total: 9
    
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP     IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                   LABELS
    0  4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1  10.68.9.5   False           ALIVE                     10.68.9.5    CPU: 8.0                          ray.io/accelerator-type: TPU-V6E
                                                                                                                                    TPU: 4.0                          ray.io/node-group: tpu-group
                                                                                                                                    accelerator_type:TPU-V6E: 1.0     ray.io/node-id: 4f0e4d742...
                                                                                                                                    memory: 186.265 GiB               ray.io/tpu-pod-type: v6e-16
                                                                                                                                    node:10.68.9.5: 1.0               ray.io/tpu-slice-name: tpu-group-0
                                                                                                                                    object_store_memory: 186.265 GiB  ray.io/tpu-topology: 4x4
                                                                                                                                    tpu-group-0: 1.0                  ray.io/tpu-worker-id: '1'
    ...
    6  ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938  10.68.2.9   True            ALIVE                     10.68.2.9    CPU: 8.0                          ray.io/node-group: headgroup
                                                                                                                                    memory: 16.000 GiB                ray.io/node-id: ce7056807...
                                                                                                                                    node:10.68.2.9: 1.0
                                                                                                                                    node:__internal_head__: 1.0
                                                                                                                                    object_store_memory: 4.765 GiB
    ...
    
  6. Scarica il file di configurazione di base MaxText. Questo file è richiesto dallo script di addestramento per impostare gli iperparametri predefiniti del modello:

    curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.yml
    
  7. Invia lo script JaxTrainer a RayCluster e verifica che RayJob venga completato correttamente:

Llama 3 8B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=4 \
      max_target_length=4096 \
      model_name=llama3-8b \
      steps=100 \
      ici_fsdp_parallelism=4 \
      ici_tensor_parallelism=4 \
      run_name=rayjob-multi-slice

Llama 3 70B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=2 \
      max_target_length=4096 \
      model_name=llama3-70b \
      steps=100 \
      ici_tensor_parallelism=4 \
      ici_fsdp_parallelism=4 \
      dcn_fsdp_parallelism=2 \
      dcn_data_parallelism=1 \
      remat_policy=full \
      run_name=rayjob-multi-slice-70b-fsdp

Il comando precedente invia lo script Python, che chiama il codice JaxTrainer Ray a RayCluster. Il comando ray job submit include alcuni argomenti specifici di MaxText da passare alla configurazione del modello.

Nel terminale, dovresti vedere un output simile al seguente per il job Llama 3 70B:

[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]

------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------

Esegui l'addestramento elastico multislice sulle VM spot

Quando utilizzi acceleratori molto richiesti come le TPU, l'utilizzo di VM spot potrebbe ridurre significativamente i costi. Tuttavia, le VM spot potrebbero essere prerilasciate in modo imprevisto.

Ray Train supporta l'addestramento elastico, che consente al job di scalare dinamicamente il numero di slice TPU partecipanti verso l'alto o verso il basso senza errori. Se una sezione viene prerilasciata, Ray mette in pausa il loop di addestramento, attende che i worker rimanenti si riorganizzino, ripristina l'ultimo checkpoint MaxText e riprende l'addestramento con l'impronta più piccola.

Per attivare l'addestramento elastico, modifica il parametro num_workers nel file ScalingConfig da un numero intero statico a una tupla che rappresenta (minimum_workers, maximum_workers). Inoltre, aggiungi un FailureConfig(max_failures=3) a RunConfig, che indica a Ray Train di riprovare il ciclo di addestramento fino a 3 volte anziché interrompere completamente il job quando un worker viene interrotto.

Aggiorna lo script Ray Train

  1. Lo script maxtext_elastic_trainer.py nella directory corrente attiva l'addestramento elastico. Nota che imposta num_workers=(4,8), che indica a Ray di procedere se è disponibile almeno una slice da 16 chip (quattro worker), ma di fare lo scale up fino a due slice (otto worker) se possibile. Include un FailureConfig per attivare l'addestramento elastico, definire il numero di tentativi e contribuire a garantire che il job sopravviva ai preempt:

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig, FailureConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        import maxtext
        from maxtext.trainers.pre_train.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        # Convert the config file path to an absolute path
        argv = list(argv)
        if len(argv) > 1:
            argv[1] = os.path.abspath(argv[1])
    
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                # This tells Ray to scale the number of workers between 4 and 8 (i.e. 1 to 2 TPU slices).
                num_workers=(4,8),
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                # Define a FailureConfig to enable fault tolerance by automatically restarting failed workers.
                failure_config=FailureConfig(max_failures=3),
                worker_runtime_env={
                    "uv": {
                        # maxtext requires some additional deps
                        "packages": ["maxtext[tpu]==0.2.1"],
                        "uv_pip_install_options": ["--resolution=lowest"]
                    },
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
    
  2. Invia il job utilizzando l'interfaccia a riga di comando di Ray Job. Assicurati di fornire un run_name univoco in modo che i checkpoint non siano in conflitto con le esecuzioni precedenti.

    ray job submit \
      --address http://localhost:8265 \
      --working-dir . \
      --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
      -- python maxtext_elastic_trainer.py \
          base.yml \
          base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=4 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-elastic-8b
    
  3. Per simulare la terminazione o il prerilascio di un nodo durante l'addestramento, elimina un pod.

    kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
    

Il terminale registra un errore del worker, ma il controller di orchestrazione mantiene il job attivo e riprende automaticamente dal checkpoint /data/rayjob-elastic-8b/checkpoints dopo che la topologia minima è disponibile.

Poiché MaxText ricalcola dinamicamente la mesh dei dispositivi alla ripresa, non è necessario scrivere alcuna logica personalizzata per gestire la ridistribuzione dei checkpoint quando la topologia si riduce. JAX's Orbax checkpointer eseguirà automaticamente il resharding dei pesi salvati nel nuovo layout fisico prima di continuare il ciclo di addestramento. L'output seguente mostra il controller Ray Train che rileva le risorse TPU appena disponibili nel cluster ed esegue un'operazione di scalabilità da una slice (quattro worker) a due slice (otto worker) durante l'addestramento.

...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048   20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8

Esegui la pulizia

Per evitare che al tuo account Google Cloud vengano addebitati costi relativi alle risorse utilizzate in questo tutorial, elimina il progetto che contiene le risorse oppure mantieni il progetto ed elimina le singole risorse.

  1. Elimina RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. Elimina il cluster GKE:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Elimina il bucket Cloud Storage:

    gsutil rm -r gs://${GS_BUCKET}
    

Passaggi successivi