Entraînement multislices et élastique sur des TPU à l'aide de Ray Train sur GKE

Ce tutoriel explique comment entraîner des grands modèles de langage (LLM) comme Llama 3 70B sur Google Kubernetes Engine (GKE) à l'aide de MaxText, Ray Train et des TPU Trillium multislices. Ce tutoriel fournit un guide complet de bout en bout, de la configuration de la mise en réseau du centre de données secondaire nécessaire à l'envoi et à l'exécution réussie d'une charge de travail d'entraînement distribuée sur 32 puces TPU physiques.

Ce tutoriel s'adresse aux administrateurs de plate-forme, aux opérateurs et aux spécialistes de l'IA qui souhaitent apprendre à surmonter les problèmes de mémoire et de réseau liés à l'entraînement de modèles à 70 milliards de paramètres sur des tranches de TPU distribuées et multi-hôtes.

Arrière-plan

La combinaison de GKE, KubeRay, MaxText et des TPU fournit une plate-forme puissante et évolutive pour l'entraînement de modèles à grande échelle. Cette section décrit les principales technologies utilisées dans ce guide :

JAX

JAX est une bibliothèque Python pour le calcul de tableaux et la transformation de programmes orientés accélérateur, qui utilise le compilateur XLA pour créer du code hautement optimisé qui s'adapte efficacement aux accélérateurs.

MaxText

MaxText est un framework LLM Open Source hautes performances conçu pour l'évolutivité et la personnalisation. MaxText est basé sur JAX et optimisé pour s'exécuter efficacement sur les Cloud TPU.

TPU

Les Tensor Processing Units (TPU) sont des accélérateurs conçus sur mesure par Google pour optimiser les charges de travail de machine learning. Contrairement aux CPU à usage général ou aux GPU à traitement parallèle, les TPU sont hautement spécialisés dans les calculs matriciels et tensoriels massifs qui sont à la base du deep learning, ce qui les rend efficaces pour cette tâche spécifique. L'avantage principal des TPU est leur capacité à offrir des performances à grande échelle.

Ce tutoriel utilise TPU Trillium, la sixième génération de TPU, dans un modèle de déploiement Multislice. Cloud TPU Multislice est un environnement dans lequel au moins deux tranches Cloud TPU communiquent sur le réseau du centre de données (DCN). Multislice permet un entraînement full stack à grande échelle et économique, avec un scaling presque linéaire jusqu'à plusieurs dizaines de milliers de puces TPU. Pour en savoir plus sur Multislice, consultez la présentation de Cloud TPU Multislice.

KubeRay

KubeRay est un opérateur Kubernetes qui fournit une méthode unifiée pour déployer, gérer et surveiller les applications Ray sur Kubernetes. L'opérateur KubeRay est installé et géré via le module complémentaire Ray sur GKE, qui est la méthode recommandée pour déployer et gérer les clusters Ray sur GKE.

GKE Dynamic Resource Allocation Network (DRANET)

GKE DRANET (Dynamic Resource Allocation Network) est une fonctionnalité qui associe de manière dynamique des périphériques réseau hautes performances aux pods, en contournant la mise en réseau Kubernetes standard et en permettant des performances élevées sur le DCN.

Objectifs

Ce tutoriel vous explique comment effectuer les tâches suivantes :

  1. Configurez un cluster GKE avec deux pools de nœuds TPU multi-hôtes.
  2. Configurez un DCN secondaire pour la communication TPU entre les tranches.
  3. Configurez KubeRay pour gérer l'environnement d'entraînement distribué.
  4. Déployez une ressource personnalisée RayCluster à l'aide de l'allocation dynamique de ressources (DRA) pour les pièces jointes réseau.
  5. Créez un script d'entraînement Python en utilisant JaxTrainer de Ray Train pour orchestrer la boucle d'entraînement MaxText sur les tranches de TPU.
  6. Exécutez une tâche d'entraînement de référence Llama 3 8B.
  7. Faites évoluer votre modèle jusqu'à Llama 3 70B en utilisant le sharding 2D (parallélisme de tenseur et FSDP) sur le DCN.

Avant de commencer

  • Connectez-vous à votre compte Google Cloud . Si vous débutez sur Google Cloud, créez un compte pour évaluer les performances de nos produits en conditions réelles. Les nouveaux clients bénéficient également de 300 $de crédits sans frais pour exécuter, tester et déployer des charges de travail.
  • Installez la Google Cloud CLI.

  • Si vous utilisez un fournisseur d'identité (IdP) externe, vous devez d'abord vous connecter à la gcloud CLI avec votre identité fédérée.

  • Pour initialiser la gcloud CLI, exécutez la commande suivante :

    gcloud init
  • Créez ou sélectionnez un projet Google Cloud .

    Rôles requis pour sélectionner ou créer un projet

    • Sélectionnez un projet : la sélection d'un projet ne nécessite pas de rôle IAM spécifique. Vous pouvez sélectionner n'importe quel projet pour lequel un rôle vous a été attribué.
    • Créer un projet : pour créer un projet, vous devez disposer du rôle Créateur de projet (roles/resourcemanager.projectCreator), qui contient l'autorisation resourcemanager.projects.create. Découvrez comment attribuer des rôles.
    • Créez un projet Google Cloud  :

      gcloud projects create PROJECT_ID

      Remplacez PROJECT_ID par le nom du projet Google Cloud que vous créez.

    • Sélectionnez le projet Google Cloud que vous avez créé :

      gcloud config set project PROJECT_ID

      Remplacez PROJECT_ID par le nom de votre projet Google Cloud .

  • Vérifiez que la facturation est activée pour votre projet Google Cloud .

  • Activez les API requises :

    Rôles requis pour activer les API

    Pour activer les API, vous avez besoin du rôle IAM Administrateur Service Usage (roles/serviceusage.serviceUsageAdmin), qui contient l'autorisation serviceusage.services.enable. Découvrez comment attribuer des rôles.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Installez la Google Cloud CLI.

  • Si vous utilisez un fournisseur d'identité (IdP) externe, vous devez d'abord vous connecter à la gcloud CLI avec votre identité fédérée.

  • Pour initialiser la gcloud CLI, exécutez la commande suivante :

    gcloud init
  • Créez ou sélectionnez un projet Google Cloud .

    Rôles requis pour sélectionner ou créer un projet

    • Sélectionnez un projet : la sélection d'un projet ne nécessite pas de rôle IAM spécifique. Vous pouvez sélectionner n'importe quel projet pour lequel un rôle vous a été attribué.
    • Créer un projet : pour créer un projet, vous devez disposer du rôle Créateur de projet (roles/resourcemanager.projectCreator), qui contient l'autorisation resourcemanager.projects.create. Découvrez comment attribuer des rôles.
    • Créez un projet Google Cloud  :

      gcloud projects create PROJECT_ID

      Remplacez PROJECT_ID par le nom du projet Google Cloud que vous créez.

    • Sélectionnez le projet Google Cloud que vous avez créé :

      gcloud config set project PROJECT_ID

      Remplacez PROJECT_ID par le nom de votre projet Google Cloud .

  • Vérifiez que la facturation est activée pour votre projet Google Cloud .

  • Activez les API requises :

    Rôles requis pour activer les API

    Pour activer les API, vous avez besoin du rôle IAM Administrateur Service Usage (roles/serviceusage.serviceUsageAdmin), qui contient l'autorisation serviceusage.services.enable. Découvrez comment attribuer des rôles.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Attribuez des rôles à votre compte utilisateur. Exécutez la commande suivante une fois pour chacun des rôles IAM suivants : roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editor

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Remplacez les éléments suivants :

    • PROJECT_ID : ID de votre projet
    • USER_IDENTIFIER : identifiant de votre compte d'utilisateur. Par exemple, myemail@example.com.
    • ROLE : rôle IAM que vous accordez à votre compte utilisateur.
  • Étant donné que ce tutoriel utilise des TPU Trillium (v6e), sélectionnez une région ou une zone où ils sont disponibles. Pour en savoir plus, consultez Quotas Cloud TPU.

Préparer votre environnement

Dans ce tutoriel, vous utilisez Cloud Shell. Cloud Shell est préinstallé sur les outils de ligne de commande gcloud, helm et kubectl utilisés dans ce tutoriel.

  1. Accédez à la consoleGoogle Cloud .

  2. En haut de la fenêtre de la console Google Cloud , cliquez sur le bouton Activer Cloud Shell Bouton d'activation de Cloud Shell.

    Une session Cloud Shell s'ouvre dans un nouveau cadre dans la console Google Cloud et affiche une invite de ligne de commande.

  3. Dans votre terminal, clonez le dépôt kubernetes-engine-samples :

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    
  4. Accédez au répertoire contenant les fichiers d'exemple :

    cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtext
    
  5. Créez et activez un environnement virtuel Python :

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  6. Installez la CLI Ray :

    pip install "ray[default]==2.55.0"
    
  7. Définissez les variables d'environnement suivantes :

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export CLUSTER_VERSION=1.35.2-gke.1842000
    

    Remplacez les éléments suivants :

    • GS_BUCKET : nom du bucket Cloud Storage.
    • KSA_NAME : nom du compte de service Kubernetes.
    • CLUSTER_NAME : nom du nouveau cluster
    • REGION : région dans laquelle votre capacité de TPU Trillium est disponible.
    • ZONE : zone dans laquelle votre capacité de TPU Trillium est disponible. Pour en savoir plus, consultez la section Disponibilité des TPU dans GKE.

Configurer la mise en réseau du cluster pour Cloud TPU Multislice

Dans une tranche de TPU multi-hôte, les appareils TPU communiquent via des interconnexions à haut débit entre les puces. Toutefois, lorsque vous exécutez des jobs Multislice, les tranches de TPU doivent communiquer entre elles sur le DCN. Les réseaux de pods Kubernetes standards peuvent limiter ce trafic. Le type de machine ct6e-standard-4t est associé à plusieurs cartes d'interface réseau (NIC) physiques. Pour obtenir les meilleures performances, créez deux réseaux VPC supplémentaires et utilisez GKE DRANET pour les connecter directement aux pods Ray.

  1. Créez les deux réseaux VPC supplémentaires avec une unité de transmission maximale (MTU) élevée :

    gcloud compute networks create ${CLUSTER_NAME}-net-1 \
      --subnet-mode=custom \
      --mtu=8896
    
    gcloud compute networks create ${CLUSTER_NAME}-net-2 \
      --subnet-mode=custom \
      --mtu=8896
    
  2. Créez les sous-réseaux dédiés :

    gcloud compute networks subnets create tpu-subnet-1 \
      --network=${CLUSTER_NAME}-net-1 \
      --region=${REGION} \
      --range=10.50.0.0/16
    
    gcloud compute networks subnets create tpu-subnet-2 \
      --network=${CLUSTER_NAME}-net-2 \
      --region=${REGION} \
      --range=10.60.0.0/16
    

Créer un cluster GKE

Vous pouvez configurer KubeRay sur des TPU dans un cluster GKE Autopilot ou Standard. Nous vous recommandons d'utiliser un cluster Autopilot pour une expérience Kubernetes entièrement gérée. Pour choisir le mode de fonctionnement GKE le mieux adapté à vos charges de travail, consultez À propos des modes de fonctionnement GKE.

Pour utiliser DRANET géré par GKE, votre cluster doit utiliser la version 1.35.2-gke.1842000 ou ultérieure pour le mode Autopilot, ou la version 1.34.1-gke.1829001 ou ultérieure pour le mode Standard. Ce tutoriel utilise la version 1.35.2-gke.1842000.

Autopilot

  1. Dans Cloud Shell, exécutez la commande suivante :

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION \
        --cluster-version=${CLUSTER_VERSION}
    
  2. Pour communiquer avec votre cluster, configurez kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$REGION
    

Standard

  1. Dans Cloud Shell, créez un cluster Standard qui active le module complémentaire Ray Operator en exécutant la commande suivante :

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator,GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --enable-dataplane-v2 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE \
        --cluster-version=${CLUSTER_VERSION}
    

    Cette commande active également GcsFuseCsiDriver, ce qui permet aux pods d'installer des buckets Cloud Storage en tant que systèmes de fichiers locaux. La création du cluster peut prendre plusieurs minutes.

  2. Pour communiquer avec votre cluster, configurez kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    
  3. Créez le premier pool de nœuds de tranche TPU multi-hôte avec GKE DRANET activé :

    gcloud container node-pools create v6e-16-0 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    
  4. Créez le deuxième pool de nœuds de tranche TPU :

    gcloud container node-pools create v6e-16-1 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    

GKE provisionne un pool de nœuds composé de quatre VM TPU Trillium (v6e), qui sont configurées ensemble en tant que tranche TPU multi-hôte avec une topologie 4x4. Ce pool de nœuds est prêt pour les charges de travail d'entraînement distribué.

Le cluster GKE sur lequel le Ray Operator est activé installe automatiquement KubeRay et le webhook KubeRay TPU dans votre cluster.

Configurer un bucket Cloud Storage et un compte de service

  1. Créez un bucket Cloud Storage pour les points de contrôle partagés entre les nœuds TPU multihôtes.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Pour activer l'accès au bucket Cloud Storage, créez un compte de service Kubernetes :

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Pour activer l'accès au bucket Cloud Storage, ajoutez les liaisons de stratégie IAM requises au compte de service :

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

Créer un script d'entraînement

Le script maxtext_multi_slice_trainer.py utilise JaxTrainer de Ray Train pour exécuter un job d'entraînement MaxText distribué sur deux tranches de TPU. Le script configure l'environnement d'entraînement pour huit nœuds de calcul TPU multi-hôtes et exécute le job d'entraînement MaxText sur chaque nœud de calcul. La fonction train_loop_per_worker encapsule le point d'entrée principal de MaxText et utilise le planificateur distribué de Ray pour exécuter l'outil d'entraînement MaxText sur une tranche TPU multi-hôtes :

import os
from absl import app
import logging
from typing import Sequence
import ray
from ray.train.v2.api.config import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer

def train_loop_per_worker(config):
    import maxtext
    from maxtext.trainers.pre_train.train import main as maxtext_main

    argv = config["argv"]
    maxtext_main(argv)

def main(argv: Sequence[str]):
    # Convert the config file path to an absolute path
    argv = list(argv)
    if len(argv) > 1:
        argv[1] = os.path.abspath(argv[1])

    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": argv},
        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=8,
            topology="4x4",
            accelerator_type="TPU-V6E",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "uv": {
                    # maxtext requires some additional deps
                    "packages": ["maxtext[tpu]==0.2.1"],
                    "uv_pip_install_options": ["--resolution=lowest"]
                },
            },
        ),
    )
    result = trainer.fit()
    logging.info("Training complete!")
    ray.shutdown()

if __name__ == "__main__":
    app.run(main)

Le script précédent définit une instance JaxTrainer demandant huit workers et une topologie de 4x4. En interne, Ray provisionne un SlicePlacementGroup sur les deux tranches de TPU et permet de s'assurer que les nœuds de calcul Ray Train s'exécutent de manière atomique sur les deux tranches, avec un nœud de calcul par hôte.

Entraîner le modèle

  1. Le fichier manifeste ray-cluster.tpu-multi-slice.yaml du répertoire actuel définit la ressource personnalisée RayCluster. Ce fichier manifeste inclut le DRANET ResourceClaimTemplate pour provisionner les périphériques réseau pour GKE DRANET et Multislice :

    apiVersion: resource.k8s.io/v1
    kind: ResourceClaimTemplate
    metadata:
      name: two-netdev
    spec:
      spec:
        devices:
          requests:
          - name: req-netdev
            exactly:
              deviceClassName: netdev.google.com
              allocationMode: ExactCount
              count: 2
    ---
    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: rayproject/ray:nightly-py312-tpu
                imagePullPolicy: Always
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
            nodeSelector:
              iam.gke.io/gke-metadata-server-enabled: "true"
      workerGroupSpecs:
        - replicas: 2
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: 
            metrics-export-port: "8082"
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              resourceClaims:
              - name: netdev
                resourceClaimTemplateName: two-netdev
              containers:
                - name: ray-worker
                  image: rayproject/ray:nightly-py312-tpu
                  imagePullPolicy: Always
                  resources:
                    claims:
                    - name: netdev
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: MEGASCALE_NUM_SLICES
                      value: "2"
                    - name: MEGASCALE_PORT
                      value: "9915"
                    - name: JAX_PLATFORMS
                      value: tpu,cpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                    - name: LIBTPU_INIT_ARGS
                      value: "--xla_tpu_scoped_vmem_limit_kib=122880 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --megascale_grpc_interface_prefixes=eth1,eth2,lo"
                  securityContext:
                    privileged: true
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
              nodeSelector:
                iam.gke.io/gke-metadata-server-enabled: "true"
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
    

    La spécification RayCluster précédente crée un groupe de nœuds de calcul TPU avec huit nœuds de calcul (numOfHosts: 4) par réplica, avec deux réplicas. Chaque nœud de calcul demande quatre puces TPU (google.com/tpu: "4"). Les nœuds de calcul sont chacun planifiés sur un nœud TPU Trillium (tpu-v6e-slice), qui fait partie de la même tranche multi-hôte colocalisée. KubeRay met à l'échelle les quatre nœuds de calcul d'une tranche de manière atomique. Les variables d'environnement JAX requises, ainsi que les affinités de pod pour la planification, sont amorcées par GKE via un webhook en mutation.

  2. Pour créer le RayCluster, appliquez le fichier manifeste :

    envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -
    
  3. Vérifiez que le cluster est prêt et en cours d'exécution :

    kubectl get rayclusters maxtext-tpu-cluster
    

    La sortie devrait ressembler à ce qui suit :

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY         GPUS   STATUS   AGE
    maxtext-tpu-cluster   8                 8                   72     1579277216Ki   0      ready    2m11s
    
  4. Pour accéder au tableau de bord Ray via le service principal Ray, établissez une session de transfert de port :

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. Vérifiez que le RayCluster est accessible depuis votre environnement local :

    ray list nodes --address http://localhost:8265
    

    La sortie devrait ressembler à ce qui suit :

    ray list nodes --address http://localhost:8265
    2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
    2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads.
    
    ======== List: 2026-04-21 10:20:20.945431 ========
    Stats:
    ------------------------------
    Total: 9
    
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP     IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                   LABELS
    0  4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1  10.68.9.5   False           ALIVE                     10.68.9.5    CPU: 8.0                          ray.io/accelerator-type: TPU-V6E
                                                                                                                                    TPU: 4.0                          ray.io/node-group: tpu-group
                                                                                                                                    accelerator_type:TPU-V6E: 1.0     ray.io/node-id: 4f0e4d742...
                                                                                                                                    memory: 186.265 GiB               ray.io/tpu-pod-type: v6e-16
                                                                                                                                    node:10.68.9.5: 1.0               ray.io/tpu-slice-name: tpu-group-0
                                                                                                                                    object_store_memory: 186.265 GiB  ray.io/tpu-topology: 4x4
                                                                                                                                    tpu-group-0: 1.0                  ray.io/tpu-worker-id: '1'
    ...
    6  ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938  10.68.2.9   True            ALIVE                     10.68.2.9    CPU: 8.0                          ray.io/node-group: headgroup
                                                                                                                                    memory: 16.000 GiB                ray.io/node-id: ce7056807...
                                                                                                                                    node:10.68.2.9: 1.0
                                                                                                                                    node:__internal_head__: 1.0
                                                                                                                                    object_store_memory: 4.765 GiB
    ...
    
  6. Téléchargez le fichier de configuration MaxText de base. Ce fichier est requis par le script d'entraînement pour définir les hyperparamètres par défaut du modèle :

    curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.yml
    
  7. Envoyez le script JaxTrainer au RayCluster et vérifiez que le RayJob s'est terminé correctement :

Llama 3 8B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=4 \
      max_target_length=4096 \
      model_name=llama3-8b \
      steps=100 \
      ici_fsdp_parallelism=4 \
      ici_tensor_parallelism=4 \
      run_name=rayjob-multi-slice

Llama 3 70B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=2 \
      max_target_length=4096 \
      model_name=llama3-70b \
      steps=100 \
      ici_tensor_parallelism=4 \
      ici_fsdp_parallelism=4 \
      dcn_fsdp_parallelism=2 \
      dcn_data_parallelism=1 \
      remat_policy=full \
      run_name=rayjob-multi-slice-70b-fsdp

La commande précédente envoie le script Python, qui appelle le code JaxTrainer Ray au RayCluster. La commande ray job submit inclut des arguments spécifiques à MaxText à transmettre à la configuration du modèle.

Dans votre terminal, un résultat semblable à celui-ci doit s'afficher pour le job Llama 3 70B :

[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]

------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------

Exécuter un entraînement élastique Multislice sur des VM Spot

Lorsque vous utilisez des accélérateurs très demandés comme les TPU, l'utilisation de VM Spot peut réduire considérablement les coûts. Toutefois, les VM Spot peuvent être préemptées de manière inattendue.

Ray Train est compatible avec l'entraînement élastique, qui permet à votre job de faire évoluer dynamiquement le nombre de tranches TPU participantes à la hausse ou à la baisse sans échouer. Si un slice est préempté, Ray met en pause la boucle d'entraînement, attend que les nœuds de calcul restants se réorganisent, restaure le dernier point de contrôle MaxText et reprend l'entraînement sur l'empreinte plus petite.

Pour activer l'entraînement élastique, remplacez le paramètre num_workers dans votre ScalingConfig par un tuple représentant (minimum_workers, maximum_workers) au lieu d'un entier statique. Ajoutez également un FailureConfig(max_failures=3) au RunConfig, qui indique à Ray Train de réessayer la boucle d'entraînement jusqu'à trois fois au lieu de mettre en échec l'intégralité de la tâche lorsqu'un nœud de calcul est préempté.

Mettre à jour le script Ray Train

  1. Le script maxtext_elastic_trainer.py du répertoire actuel permet l'entraînement élastique. Notez qu'il définit num_workers=(4,8), ce qui indique à Ray de procéder si au moins une tranche de 16 puces (quatre nœuds de calcul) est disponible, mais de passer à deux tranches (huit nœuds de calcul) si possible. Il inclut un FailureConfig pour activer l'entraînement élastique, définir le nombre de tentatives et s'assurer que le job survit aux préemptions :

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig, FailureConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        import maxtext
        from maxtext.trainers.pre_train.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        # Convert the config file path to an absolute path
        argv = list(argv)
        if len(argv) > 1:
            argv[1] = os.path.abspath(argv[1])
    
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                # This tells Ray to scale the number of workers between 4 and 8 (i.e. 1 to 2 TPU slices).
                num_workers=(4,8),
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                # Define a FailureConfig to enable fault tolerance by automatically restarting failed workers.
                failure_config=FailureConfig(max_failures=3),
                worker_runtime_env={
                    "uv": {
                        # maxtext requires some additional deps
                        "packages": ["maxtext[tpu]==0.2.1"],
                        "uv_pip_install_options": ["--resolution=lowest"]
                    },
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
    
  2. Envoyez le job à l'aide de la CLI Ray Job. Veillez à fournir un run_name unique afin que les points de contrôle n'entrent pas en conflit avec les exécutions précédentes.

    ray job submit \
      --address http://localhost:8265 \
      --working-dir . \
      --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
      -- python maxtext_elastic_trainer.py \
          base.yml \
          base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=4 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-elastic-8b
    
  3. Pour simuler l'arrêt ou la préemption d'un nœud pendant l'entraînement, supprimez un pod.

    kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
    

Le terminal enregistre un échec du nœud de calcul, mais le contrôleur d'orchestration maintient le job actif et le reprend automatiquement à partir du point de contrôle /data/rayjob-elastic-8b/checkpoints une fois que la topologie minimale est disponible.

Étant donné que MaxText recalcule dynamiquement le maillage d'appareils lors de la reprise, vous n'avez pas besoin d'écrire de logique personnalisée pour gérer le re-sharding des points de contrôle lorsque la topologie se réduit. Le point de contrôle Orbax de JAX re-shardera automatiquement les poids enregistrés dans la nouvelle disposition physique avant de poursuivre la boucle d'entraînement. La sortie suivante montre que le contrôleur Ray Train détecte les ressources TPU nouvellement disponibles dans le cluster et effectue une opération de scaling d'un slice (quatre nœuds de calcul) à deux slices (huit nœuds de calcul) pendant l'entraînement.

...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048   20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8

Effectuer un nettoyage

Pour éviter que les ressources utilisées dans ce tutoriel ne soient facturées sur votre compte Google Cloud , supprimez le projet contenant les ressources, ou conservez le projet et supprimez les ressources individuelles.

  1. Supprimez le RayCluster :

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. Supprimez le cluster GKE :

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Supprimez le bucket Cloud Storage :

    gsutil rm -r gs://${GS_BUCKET}
    

Étapes suivantes