Addestra un LLM utilizzando JAX, Ray Train e TPU Trillium su GKE

Questo tutorial mostra come addestrare il modello linguistico di grandi dimensioni (LLM) Llama 3 8B su Google Kubernetes Engine (GKE) utilizzando MaxText, Ray Train e TPU.

Questo tutorial fornisce una procedura dettagliata completa end-to-end, dalla configurazione dell'infrastruttura cloud necessaria all'invio e all'esecuzione corretta del carico di lavoro di addestramento sulle TPU multi-host.

Questo tutorial è rivolto ad amministratori e operatori della piattaforma e a specialisti di dati e AI che vogliono imparare ad addestrare modelli di grandi dimensioni su uno slice TPU distribuito e multi-host.

Sfondo

La combinazione di GKE, KubeRay, MaxText e TPU fornisce una piattaforma potente e scalabile per l'addestramento di modelli su larga scala. Questa sezione descrive le tecnologie chiave utilizzate in questa guida:

JAX

JAX è una libreria Python per il calcolo di array e la trasformazione di programmi orientati agli acceleratori, progettata per il calcolo numerico ad alte prestazioni e il machine learning su larga scala.

JAX fornisce un sistema estensibile per trasformare funzioni numeriche come jax.grad, jax.jit e jax.vmap, utilizzando il compilatore XLA per creare codice altamente ottimizzato che viene scalato in modo efficiente su acceleratori come GPU e TPU. La potenza principale di JAX risiede nella sua composizione, che consente agli utenti di combinare queste trasformazioni per creare programmi numerici complessi e ad alte prestazioni per l'esecuzione distribuita.

MaxText

MaxText è un modello linguistico di grandi dimensioni (LLM) open source ad alte prestazioni progettato per la scalabilità e la personalizzazione. MaxText è basato su JAX e ottimizzato per essere eseguito in modo efficiente su Cloud TPU e GPU.

TPU

Le Tensor Processing Unit (TPU) sono acceleratori progettati su misura creati da Google per ottimizzare i carichi di lavoro di machine learning. A differenza delle CPU per uso generico o delle GPU per l'elaborazione parallela, le TPU sono altamente specializzate per i calcoli massicci di matrici e tensori alla base del deep learning, il che le rende efficienti in questa attività specifica. Il vantaggio principale delle TPU è il rendimento su larga scala.

Questo tutorial utilizza la TPU Trillium, che è la sesta generazione di TPU. Per saperne di più, consulta Vantaggi dell'utilizzo di TPU Trillium.

KubeRay

KubeRay è un operatore Kubernetes che fornisce un modo unificato per eseguire il deployment, gestire e monitorare le applicazioni Ray su Kubernetes. L'operatore KubeRay viene installato e gestito tramite il componente aggiuntivo Ray su GKE, che è il modo consigliato per eseguire il deployment e gestire i cluster Ray su GKE.

Obiettivi

Questo tutorial mostra come:

  1. Configura un cluster GKE con un pool di nodi TPU multi-host.
  2. Configura KubeRay per gestire l'ambiente di addestramento distribuito.
  3. Crea un'immagine Docker personalizzata che contenga le dipendenze di MaxText, Ray e JAX.
  4. Crea uno script di addestramento Python che utilizzi JaxTrainer di Ray Train per orchestrare il ciclo di addestramento MaxText nella sezione TPU.
  5. Definisci una risorsa personalizzata RayCluster per eseguire il provisioning dei nodi head e worker con le risorse TPU necessarie.
  6. Invia il job di addestramento a RayCluster e monitorane l'avanzamento.
  7. Utilizza Cloud Storage per archiviare i checkpoint del modello.

Prima di iniziare

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • Install the Google Cloud CLI.

  • Se utilizzi un provider di identità (IdP) esterno, devi prima accedere a gcloud CLI con la tua identità federata.

  • Per inizializzare gcloud CLI, esegui questo comando:

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Install the Google Cloud CLI.

  • Se utilizzi un provider di identità (IdP) esterno, devi prima accedere a gcloud CLI con la tua identità federata.

  • Per inizializzare gcloud CLI, esegui questo comando:

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Grant roles to your user account. Run the following command once for each of the following IAM roles: roles/container.admin, roles/iam.serviceAccountAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Replace the following:

    • PROJECT_ID: Your project ID.
    • USER_IDENTIFIER: The identifier for your user account. For example, myemail@example.com.
    • ROLE: The IAM role that you grant to your user account.
  • Poiché questo tutorial utilizza TPU Trillium (v6e), seleziona una regione o una zona con disponibilità. Per maggiori informazioni, consulta la sezione Quote di Cloud TPU.

prepara l'ambiente

In questo tutorial utilizzerai Cloud Shell. Cloud Shell include gli strumenti a riga di comando gcloud, helm e kubectl utilizzati in questo tutorial.

  1. Vai alla consoleGoogle Cloud .

  2. Nella parte superiore della finestra della console Google Cloud , fai clic sul pulsante Attiva Cloud Shell Pulsante Attiva Cloud Shell.

    All'interno di un nuovo frame nella consoleGoogle Cloud si apre una sessione di Cloud Shell e viene visualizzato un prompt della riga di comando.

  3. Crea e attiva un ambiente virtuale Python:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  4. Installa la CLI Ray e altre dipendenze:

    pip install "ray[default]==2.49.1"
    
  5. Imposta le seguenti variabili di ambiente:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY
    

    Sostituisci quanto segue:

    • GS_BUCKET: il nome del bucket Cloud Storage.
    • KSA_NAME: il nome del service account Kubernetes.
    • CLUSTER_NAME: il nome del nuovo cluster.
    • REGION: la regione in cui è disponibile la capacità TPU Trillium.
    • ZONE: la zona in cui è disponibile la capacità TPU Trillium. Per saperne di più, consulta Disponibilità di TPU in GKE.
    • ARTIFACT_REGISTRY: il nome del repository Artifact Registry.

Crea un cluster GKE

Puoi configurare KubeRay sulle TPU in un cluster GKE Autopilot o Standard. Ti consigliamo di utilizzare un cluster Autopilot per un'esperienza Kubernetes completamente gestita. Per scegliere la modalità operativa GKE più adatta ai tuoi workload, consulta Informazioni sulle modalità operative di GKE.

Autopilot

  1. In Cloud Shell, esegui questo comando:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION
    
  2. Per comunicare con il cluster, configura kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    

Standard

  1. In Cloud Shell, crea un cluster standard che abiliti il componente aggiuntivo operatore Ray eseguendo questo comando:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator \
        --addons GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE
    

    Questo comando attiva anche GcsFuseCsiDriver, che consente ai pod di montare i bucket Cloud Storage come file system locali. La creazione del cluster potrebbe richiedere diversi minuti.

  2. Per comunicare con il cluster, configura kubectl:

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=LOCATION
    
  3. Crea un pool di nodi TPU multi-host:

    gcloud container node-pools create v6e-16 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4
    

GKE esegue il provisioning di un pool di nodi composto da quattro VM TPU Trillium (v6e), configurate insieme come una sezione di TPU multi-host, con una topologia 4x4, pronta per i workload di addestramento distribuito.

Il cluster GKE con operatore Ray abilitato installa automaticamente KubeRay e il webhook KubeRay TPU nel cluster.

Configurare un bucket Cloud Storage e un account di servizio

  1. Crea un bucket Cloud Storage per i checkpoint condivisi tra i nodi TPU multi-host.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Per abilitare l'accesso al bucket Cloud Storage, crea un service account Kubernetes:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Per abilitare l'accesso al bucket Cloud Storage, aggiungi i binding dei criteri IAM richiesti all'account di servizio:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

Crea lo script di addestramento

Il seguente script utilizza JaxTrainer di Ray Train per eseguire un job di addestramento distribuito di MaxText. Lo script configura l'ambiente di addestramento per un pool di nodi di slice TPU multi-host ed esegue il job di addestramento MaxText su ogni nodo worker. La funzione train_loop_per_worker esegue il wrapping del punto di ingresso principale di MaxText e utilizza lo scheduler distribuito di Ray per eseguire il trainer MaxText su una sezione TPU multihost.

  1. Salva il seguente script Python come maxtext_ray_trainer.py:

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        from MaxText.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                num_workers=4,
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                worker_runtime_env={
                    "env_vars": {
                        "JAX_PLATFORMS": "tpu",
                        "ENABLE_PJRT_COMPATIBILITY": "true",
                        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                    }
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
  2. Per ospitare l'immagine personalizzata, crea un repository Artifact Registry:

    gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \
        --repository-format=docker --location=${REGION} && \
    gcloud auth configure-docker ${REGION}-docker.pkg.dev
    
  3. Per creare un'immagine che includa le dipendenze di Ray e MaxText per l'addestramento, crea un Dockerfile:

    # Start from a Ray base image which includes JaxTrainer API.
    # Maxtext with TPU requires Python 3.12.
    FROM rayproject/ray:2.49.1-py312
    
    USER root
    RUN groupadd -r ray 2>/dev/null || true && usermod -g ray ray
    
    RUN sudo apt-get update -y \
      && sudo apt-get install --no-install-recommends -y git \
      && sudo rm -rf /var/lib/apt/lists/*
    
    WORKDIR /app
    
    # Clone the Maxtext repo and build from source, installing TPU dependencies.
    RUN git clone https://github.com/AI-Hypercomputer/maxtext.git
    
    RUN pip install --no-cache-dir uv
    
    RUN cd maxtext && \
        uv pip install --no-cache --system -e .[tpu] --resolution=lowest && \
        install_maxtext_github_deps
    
    # Copy the Ray Maxtext trainer to run on the remote container.
    COPY maxtext_ray_trainer.py .
    
    RUN chown -R ray:ray .
    ENV PYTHONPATH=/app/maxtext/src:/app/maxtext:/app
    USER ray
  4. Crea, tagga ed esegui il push dell'immagine Docker in Artifact Registry:

    export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest
    gcloud builds submit --tag ${DOCKER_IMAGE}
    

Addestra il modello

  1. Salva il seguente manifest di esempio come maxtext-tpu-cluster.yaml:

    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: ${DOCKER_IMAGE}
                imagePullPolicy: IfNotPresent
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: gcs-fuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs"
      workerGroupSpecs:
        - replicas: 1
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: {}
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              containers:
                - name: ray-worker
                  image: ${DOCKER_IMAGE}
                  imagePullPolicy: IfNotPresent
                  resources:
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: JAX_PLATFORMS
                      value: tpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: gcs-fuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs"
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4

    La specifica RayCluster precedente crea un gruppo di worker TPU con quattro worker (numOfHosts: 4) per replica. Ogni worker richiede quattro chip TPU (google.com/tpu: "4"). I worker verranno pianificati su un nodo che esegue TPU Trillium (tpu-v6e-slice) e che fa parte della stessa slice multi-host collocate. KubeRay scala tutti e quattro i worker in modo atomico e le variabili di ambiente JAX richieste, nonché le affinità dei pod per la pianificazione, vengono avviate da GKE tramite un webhook di mutazione.

  2. Per configurare i valori richiesti nel file YAML, crea RayCluster utilizzando envsubst:

    envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -
    
  3. Verifica che il cluster sia pronto e in esecuzione:

    kubectl get rayclusters maxtext-tpu-cluster
    

    L'output dovrebbe essere simile al seguente:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY        GPUS   STATUS   AGE
    maxtext-tpu-cluster   4                 4                   40     798027216Ki   0      ready    11m
    
  4. Per accedere alla dashboard Ray tramite il servizio head Ray, stabilisci una sessione di port forwarding:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. Verifica che RayCluster sia raggiungibile dal tuo ambiente locale:

    ray list nodes --address http://localhost:8265
    

    L'output dovrebbe essere simile al seguente:

    ======== List: 2025-09-13 03:53:16.988269 ========
    Stats:
    ------------------------------
    Total: 5
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP    IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                  LABELS
    0  92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56  10.84.0.9
    (...)
    
  6. Invia lo script JaxTrainer a RayCluster e verifica che RayJob venga completato correttamente:

    ray job submit \
      --address http://localhost:8265 \
      -- python /app/maxtext_ray_trainer.py \
          /app/maxtext/src/MaxText/configs/base.yml \
           base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=1 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-8b-4096-tp4-4x4
    

    Il comando precedente invia lo script Python, che chiama il codice JaxTrainer Ray a RayCluster. Il comando ray job submit include alcuni argomenti specifici di MaxText da passare alla configurazione del modello.

    Nel terminale dovresti vedere un output simile al seguente:

    (RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster]
    
    ------------------------------------------
    Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded
    ------------------------------------------
    

Esegui la pulizia

Per evitare che al tuo account Google Cloud vengano addebitati costi relativi alle risorse utilizzate in questo tutorial, elimina il progetto che contiene le risorse oppure mantieni il progetto ed elimina le singole risorse.

  1. Elimina RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. Elimina il cluster GKE:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Elimina il bucket Cloud Storage:

    gsutil rm -r gs://${GS_BUCKET}
    
  4. Elimina il repository Artifact Registry:

    gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
    

Passaggi successivi