Addestra un LLM utilizzando JAX, Ray Train e TPU Trillium su GKE

Questo tutorial mostra come addestrare il modello linguistico di grandi dimensioni (LLM) Llama 3 8B su Google Kubernetes Engine (GKE) utilizzando MaxText, Ray Train e TPU.

Questo tutorial fornisce una procedura dettagliata completa end-to-end, dalla configurazione dell'infrastruttura cloud necessaria all'invio e all'esecuzione corretta del carico di lavoro di addestramento su TPU multi-host.

Questo tutorial è rivolto ad amministratori e operatori della piattaforma e a specialisti di dati e AI che vogliono imparare ad addestrare modelli di grandi dimensioni su una sezione TPU multi-host distribuita.

Sfondo

La combinazione di GKE, KubeRay, MaxText e TPU fornisce una piattaforma potente e scalabile per l'addestramento di modelli su larga scala. Questa sezione descrive le tecnologie chiave utilizzate in questa guida:

JAX

JAX è una libreria Python per il calcolo di array e la trasformazione di programmi orientati agli acceleratori, progettata per il calcolo numerico ad alte prestazioni e il machine learning su larga scala.

JAX fornisce un sistema estensibile per la trasformazione di funzioni numeriche come jax.grad, jax.jit e jax.vmap, utilizzando il compilatore XLA per creare codice altamente ottimizzato che si adatta in modo efficiente agli acceleratori come GPU e TPU. La potenza principale di JAX risiede nella sua componibilità, che consente agli utenti di combinare queste trasformazioni per creare programmi numerici complessi e ad alte prestazioni per l'esecuzione distribuita.

MaxText

MaxText è un modello linguistico di grandi dimensioni (LLM) open source ad alte prestazioni progettato per la scalabilità e la personalizzazione. MaxText è basato su JAX e ottimizzato per essere eseguito in modo efficiente su Cloud TPU e GPU.

TPU

Le TPU (Tensor Processing Unit) sono acceleratori progettati su misura creati da Google per ottimizzare i carichi di lavoro di machine learning. A differenza delle CPU per uso generico o delle GPU per l'elaborazione parallela, le TPU sono altamente specializzate per i calcoli di matrici e tensori di grandi dimensioni alla base del deep learning, il che le rende efficienti in questa attività specifica. Il vantaggio principale delle TPU è il rendimento su larga scala.

Questo tutorial utilizza TPU Trillium, la sesta generazione di TPU. Per saperne di più, consulta Vantaggi dell'utilizzo di TPU Trillium.

KubeRay

KubeRay è un operatore Kubernetes che fornisce un modo unificato per eseguire il deployment, gestire e monitorare le applicazioni Ray su Kubernetes. L'operatore KubeRay viene installato e gestito tramite il componente aggiuntivo Ray su GKE, che è il modo consigliato per eseguire il deployment e gestire i cluster Ray su GKE.

Obiettivi

Questo tutorial mostra gli aspetti seguenti:

  1. Configurare un cluster GKE con un pool di nodi TPU multi-host.
  2. Configurare KubeRay per gestire l'ambiente di addestramento distribuito.
  3. Creare un'immagine Docker personalizzata che contenga le dipendenze di MaxText, Ray e JAX.
  4. Creare uno script di addestramento Python che utilizzi JaxTrainer di Ray Train per orchestrare il loop di addestramento MaxText nella sezione TPU.
  5. Definire una RayCluster risorsa personalizzata per eseguire il provisioning dei nodi head e worker con le risorse TPU necessarie.
  6. Inviare il job di addestramento a RayCluster e monitorarne l'avanzamento.
  7. Utilizzare Cloud Storage per archiviare i checkpoint del modello.

Prima di iniziare

  • Accedi al tuo Google Cloud account. Se non conosci Google Cloud, crea un account per valutare le prestazioni dei nostri prodotti in scenari reali. I nuovi clienti ricevono anche 300 $di crediti senza costi per l'esecuzione, il test e il deployment dei carichi di lavoro.
  • Installa Google Cloud CLI.

  • Se utilizzi un provider di identità (IdP) esterno, devi prima accedere a gcloud CLI con la tua identità federata.

  • Per inizializzare gcloud CLI, esegui questo comando:

    gcloud init
  • Crea o seleziona un Google Cloud progetto.

    Ruoli richiesti per selezionare o creare un progetto

    • Seleziona un progetto: la selezione di un progetto non richiede un ruolo IAM specifico. Puoi selezionare qualsiasi progetto su cui ti è stato concesso un ruolo.
    • Crea un progetto: per creare un progetto, devi disporre del ruolo Autore progetto (roles/resourcemanager.projectCreator), che contiene l' resourcemanager.projects.create autorizzazione. Scopri come concedere i ruoli.
    • Crea un Google Cloud progetto:

      gcloud projects create PROJECT_ID

      Sostituisci PROJECT_ID con un nome per il Google Cloud progetto che stai creando.

    • Seleziona il Google Cloud progetto che hai creato:

      gcloud config set project PROJECT_ID

      Sostituisci PROJECT_ID con il nome del Google Cloud progetto.

  • Verifica che la fatturazione sia attivata per il tuo Google Cloud progetto.

  • Abilita le API richieste:

    Ruoli richiesti per abilitare le API

    Per abilitare le API, devi disporre del ruolo IAM Amministratore utilizzo servizi (roles/serviceusage.serviceUsageAdmin), che contiene l' serviceusage.services.enable autorizzazione. Scopri come concedere i ruoli.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Installa Google Cloud CLI.

  • Se utilizzi un provider di identità (IdP) esterno, devi prima accedere a gcloud CLI con la tua identità federata.

  • Per inizializzare gcloud CLI, esegui questo comando:

    gcloud init
  • Crea o seleziona un Google Cloud progetto.

    Ruoli richiesti per selezionare o creare un progetto

    • Seleziona un progetto: la selezione di un progetto non richiede un ruolo IAM specifico. Puoi selezionare qualsiasi progetto su cui ti è stato concesso un ruolo.
    • Crea un progetto: per creare un progetto, devi disporre del ruolo Autore progetto (roles/resourcemanager.projectCreator), che contiene l' resourcemanager.projects.create autorizzazione. Scopri come concedere i ruoli.
    • Crea un Google Cloud progetto:

      gcloud projects create PROJECT_ID

      Sostituisci PROJECT_ID con un nome per il Google Cloud progetto che stai creando.

    • Seleziona il Google Cloud progetto che hai creato:

      gcloud config set project PROJECT_ID

      Sostituisci PROJECT_ID con il nome del Google Cloud progetto.

  • Verifica che la fatturazione sia attivata per il tuo Google Cloud progetto.

  • Abilita le API richieste:

    Ruoli richiesti per abilitare le API

    Per abilitare le API, devi disporre del ruolo IAM Amministratore utilizzo servizi (roles/serviceusage.serviceUsageAdmin), che contiene l' serviceusage.services.enable autorizzazione. Scopri come concedere i ruoli.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Concedi i ruoli al tuo account utente. Esegui il seguente comando una volta per ciascuno dei seguenti ruoli IAM: roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editor

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Sostituisci quanto segue:

    • PROJECT_ID: il tuo ID progetto.
    • USER_IDENTIFIER: l'identificatore del tuo account utente. Ad esempio, myemail@example.com.
    • ROLE: il ruolo IAM che concedi al tuo account utente.
  • Poiché questo tutorial utilizza TPU Trillium (v6e), seleziona una regione o una zona con disponibilità. Per saperne di più, consulta Quote di Cloud TPU.

Prepara l'ambiente

In questo tutorial utilizzi Cloud Shell. Cloud Shell è preinstallato con gli strumenti a riga di comando gcloud, helm e kubectl utilizzati in questo tutorial.

  1. Vai alla Google Cloud console.

  2. Nella parte superiore della finestra della Google Cloud console, fai clic sul pulsante Attiva Cloud Shell Pulsante Attiva Cloud Shell.

    All'interno di un nuovo frame nella Google Cloud console si apre una sessione di Cloud Shell e viene visualizzato un prompt della riga di comando.

  3. Crea e attiva un ambiente virtuale Python:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  4. Installa l'interfaccia a riga di comando Ray e altre dipendenze:

    pip install "ray[default]==2.49.1"
    
  5. Imposta le seguenti variabili di ambiente:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY
    

    Sostituisci quanto segue:

    • GS_BUCKET: il nome del bucket Cloud Storage.
    • KSA_NAME: il nome del service account Kubernetes.
    • CLUSTER_NAME: il nome del nuovo cluster.
    • REGION: la regione in cui è disponibile la capacità di TPU Trillium.
    • ZONE: la zona in cui è disponibile la capacità di TPU Trillium. Per saperne di più, consulta Disponibilità di TPU in GKE.
    • ARTIFACT_REGISTRY: il nome del repository Artifact Registry.

Crea un cluster GKE

Puoi configurare KubeRay su TPU in un cluster GKE Autopilot o Standard. Ti consigliamo di utilizzare un cluster Autopilot per un'esperienza Kubernetes completamente gestita. Per scegliere la modalità operativa GKE più adatta ai tuoi carichi di lavoro, consulta Informazioni sulle modalità operative GKE.

Autopilot

  1. In Cloud Shell, esegui questo comando:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION
    
  2. Per comunicare con il cluster, configura kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    

Standard

  1. In Cloud Shell, crea un cluster Standard che abiliti il componente aggiuntivo dell'operatore Ray eseguendo questo comando:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator \
        --addons GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE
    

    Questo comando abilita anche GcsFuseCsiDriver, che consente ai pod di montare i bucket Cloud Storage come file system locali. La creazione del cluster potrebbe richiedere alcuni minuti.

  2. Per comunicare con il cluster, configura kubectl:

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=LOCATION
    
  3. Crea un pool di nodi di sezioni TPU multi-host:

    gcloud container node-pools create v6e-16 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4
    

GKE esegue il provisioning di un pool di nodi composto da quattro VM TPU Trillium (v6e), configurate insieme come una sezione TPU multi-host, con una topologia 4x4, pronta per i carichi di lavoro di addestramento distribuiti.

Il cluster GKE abilitato all' operatore Ray installa automaticamente KubeRay e il webhook TPU KubeRay nel cluster.

Configura un bucket Cloud Storage e un account di servizio

  1. Crea un bucket Cloud Storage per i checkpoint condivisi tra i nodi TPU multi-host.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Per abilitare l'accesso al bucket Cloud Storage, crea un service account Kubernetes:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Per abilitare l'accesso al bucket Cloud Storage, aggiungi i binding della policy IAM richiesti al account di servizio:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

Crea lo script di addestramento

Lo script seguente utilizza JaxTrainer di Ray Train per eseguire un job di addestramento MaxText distribuito. Lo script configura l'ambiente di addestramento per un pool di nodi di sezioni TPU multi-host ed esegue il job di addestramento MaxText su ogni nodo worker. La funzione train_loop_per_worker esegue il wrapping del punto di ingresso principale di MaxText e utilizza lo scheduler distribuito di Ray per eseguire il trainer MaxText su una sezione TPU multi-host.

  1. Salva il seguente script Python come maxtext_ray_trainer.py:

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        from MaxText.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                num_workers=4,
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                worker_runtime_env={
                    "env_vars": {
                        "JAX_PLATFORMS": "tpu",
                        "ENABLE_PJRT_COMPATIBILITY": "true",
                        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                    }
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
  2. Per ospitare l'immagine personalizzata, crea un repository Artifact Registry:

    gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \
        --repository-format=docker --location=${REGION} && \
    gcloud auth configure-docker ${REGION}-docker.pkg.dev
    
  3. Per creare un'immagine che includa le dipendenze di Ray e MaxText per l'addestramento, crea un Dockerfile:

    # Start from a Ray base image which includes JaxTrainer API.
    # Maxtext with TPU requires Python 3.12.
    FROM rayproject/ray:2.49.1-py312
    
    USER root
    RUN groupadd -r ray 2>/dev/null || true && usermod -g ray ray
    
    RUN sudo apt-get update -y \
      && sudo apt-get install --no-install-recommends -y git \
      && sudo rm -rf /var/lib/apt/lists/*
    
    WORKDIR /app
    
    # Clone the Maxtext repo and build from source, installing TPU dependencies.
    RUN git clone https://github.com/AI-Hypercomputer/maxtext.git
    
    RUN pip install --no-cache-dir uv
    
    RUN cd maxtext && \
        uv pip install --no-cache --system -e .[tpu] --resolution=lowest && \
        install_maxtext_github_deps
    
    # Copy the Ray Maxtext trainer to run on the remote container.
    COPY maxtext_ray_trainer.py .
    
    RUN chown -R ray:ray .
    ENV PYTHONPATH=/app/maxtext/src:/app/maxtext:/app
    USER ray
  4. Crea, tagga ed esegui il push dell'immagine Docker in Artifact Registry:

    export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest
    gcloud builds submit --tag ${DOCKER_IMAGE}
    

Addestra il modello

  1. Salva il seguente manifest di esempio come maxtext-tpu-cluster.yaml:

    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: ${DOCKER_IMAGE}
                imagePullPolicy: IfNotPresent
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: gcs-fuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs"
      workerGroupSpecs:
        - replicas: 1
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: {}
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              containers:
                - name: ray-worker
                  image: ${DOCKER_IMAGE}
                  imagePullPolicy: IfNotPresent
                  resources:
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: JAX_PLATFORMS
                      value: tpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: gcs-fuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs"
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4

    La specifica RayCluster precedente crea un gruppo di worker TPU con quattro worker (numOfHosts: 4) per replica. Ogni worker richiede quattro chip TPU (google.com/tpu: "4"). I worker verranno pianificati su un nodo che esegue TPU Trillium (tpu-v6e-slice) e che fa parte della stessa sezione multi-host colocalizzata. KubeRay scala tutti e quattro i worker in modo atomico e le variabili di ambiente JAX richieste, nonché le affinità dei pod per la pianificazione, vengono avviate da GKE tramite un webhook di modifica.

  2. Per configurare i valori richiesti nel file YAML, crea RayCluster utilizzando envsubst:

    envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -
    
  3. Verifica che il cluster sia pronto e in esecuzione:

    kubectl get rayclusters maxtext-tpu-cluster
    

    L'output dovrebbe essere simile al seguente:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY        GPUS   STATUS   AGE
    maxtext-tpu-cluster   4                 4                   40     798027216Ki   0      ready    11m
    
  4. Per accedere alla dashboard di Ray tramite il servizio head di Ray, stabilisci una sessione di port forwarding:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. Verifica che RayCluster sia raggiungibile dal tuo ambiente locale:

    ray list nodes --address http://localhost:8265
    

    L'output dovrebbe essere simile al seguente:

    ======== List: 2025-09-13 03:53:16.988269 ========
    Stats:
    ------------------------------
    Total: 5
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP    IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                  LABELS
    0  92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56  10.84.0.9
    (...)
    
  6. Invia lo script JaxTrainer a RayCluster e verifica che RayJob venga completato correttamente:

    ray job submit \
      --address http://localhost:8265 \
      -- python /app/maxtext_ray_trainer.py \
          /app/maxtext/src/MaxText/configs/base.yml \
           base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=1 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-8b-4096-tp4-4x4
    

    Il comando precedente invia lo script Python, che chiama il codice Ray JaxTrainer a RayCluster. Il comando ray job submit include alcuni argomenti specifici di MaxText da passare alla configurazione del modello.

    Nel terminale dovresti vedere un output simile al seguente:

    (RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster]
    
    ------------------------------------------
    Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded
    ------------------------------------------
    

Libera spazio

Per evitare che al tuo Google Cloud account vengano addebitati costi relativi alle risorse utilizzate in questo tutorial, elimina il progetto che contiene le risorse oppure mantieni il progetto ed elimina le singole risorse.

  1. Elimina RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. Elimina il cluster GKE:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Elimina il bucket Cloud Storage:

    gsutil rm -r gs://${GS_BUCKET}
    
  4. Elimina il repository Artifact Registry:

    gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
    

Passaggi successivi