Entraîner un LLM à l'aide de JAX, Ray Train et TPU Trillium sur GKE

Ce tutoriel explique comment entraîner le grand modèle de langage (LLM) Llama 3 8B sur Google Kubernetes Engine (GKE) à l'aide de MaxText, Ray Train, et de TPU.

Ce tutoriel fournit une procédure complète de bout en bout, depuis la configuration de l'infrastructure cloud nécessaire jusqu'à l'envoi et l'exécution réussie de la charge de travail d'entraînement sur des TPU multi-hôtes.

Ce tutoriel s'adresse aux administrateurs et opérateurs de plate-forme, ainsi qu'aux spécialistes des données et de l'IA qui souhaitent apprendre à entraîner des modèles volumineux sur une tranche TPU distribuée et multi-hôte.

Arrière-plan

La combinaison de GKE, KubeRay, MaxText et des TPU fournit une plate-forme puissante et évolutive pour l'entraînement de modèles à grande échelle. Cette section décrit les principales technologies utilisées dans ce guide :

JAX

JAX est une bibliothèque Python pour le calcul de tableaux et la transformation de programmes orientés accélérateurs, conçue pour le calcul numérique hautes performances et le machine learning à grande échelle.

JAX fournit un système extensible pour transformer des fonctions numériques telles que jax.grad, jax.jit et jax.vmap, en utilisant le compilateur XLA pour créer un code hautement optimisé qui s'adapte efficacement aux accélérateurs tels que les GPU et les TPU. La puissance de base de JAX réside dans sa composabilité, qui permet aux utilisateurs de combiner ces transformations pour créer des programmes numériques complexes et hautes performances pour une exécution distribuée.

MaxText

MaxText est un grand modèle de langage (LLM) Open Source hautes performances conçu pour l'évolutivité et la personnalisation. MaxText est basé sur JAX et optimisé pour s'exécuter efficacement sur Cloud TPU et les GPU.

TPU

Les TPU (Tensor Processing Unit) sont des accélérateurs conçus sur mesure par Google pour optimiser les charges de travail de machine learning. Contrairement aux processeurs à usage général ou aux GPU de traitement parallèle, les TPU sont hautement spécialisés dans les calculs matriciels et tensoriels massifs qui sont à la base du deep learning, ce qui les rend efficaces pour cette tâche spécifique. Le principal avantage des TPU est leur performance à grande échelle.

Ce tutoriel utilise TPU Trillium, qui correspond à la sixième génération de TPU. Pour en savoir plus, consultez Avantages de l'utilisation de TPU Trillium.

KubeRay

KubeRay est un opérateur Kubernetes qui fournit un moyen unifié de déployer, de gérer et de surveiller les applications Ray sur Kubernetes. L'opérateur KubeRay est installé et géré via le module complémentaire Ray sur GKE, qui est le moyen recommandé de déployer et de gérer des clusters Ray sur GKE.

Objectifs

Ce tutoriel vous explique comment effectuer les tâches suivantes :

  1. Configurer un cluster GKE avec un pool de nœuds TPU multi-hôtes.
  2. Configurer KubeRay pour gérer l'environnement d'entraînement distribué.
  3. Créer une image Docker personnalisée contenant les dépendances MaxText, Ray et JAX.
  4. Créer un script d'entraînement Python qui utilise JaxTrainer de Ray Train pour orchestrer la boucle d'entraînement MaxText sur la tranche TPU.
  5. Définir une RayCluster ressource personnalisée pour provisionner les nœuds principaux et de calcul avec les ressources TPU nécessaires.
  6. Envoyer la tâche d'entraînement au RayCluster et surveiller sa progression.
  7. Utiliser Cloud Storage pour stocker les points de contrôle du modèle.

Avant de commencer

  • Connectez-vous à votre Google Cloud compte. Si vous débutez sur Google Cloud, créez un compte pour évaluer les performances de nos produits en conditions réelles. Les nouveaux clients bénéficient également de 300 $de crédits sans frais pour exécuter, tester et déployer des charges de travail.
  • Installez la Google Cloud CLI.

  • Si vous utilisez un fournisseur d'identité (IdP) externe, vous devez d'abord vous connecter à la gcloud CLI avec votre identité fédérée.

  • Pour initialiser la gcloud CLI, exécutez la commande suivante :

    gcloud init
  • Créez ou sélectionnez un Google Cloud projet.

    Rôles requis pour sélectionner ou créer un projet

    • Sélectionner un projet : la sélection d'un projet ne nécessite pas de rôle IAM spécifique Vous pouvez sélectionner n'importe quel projet pour lequel un rôle vous a été attribué.
    • Créer un projet : pour créer un projet, vous avez besoin du rôle Créateur de projet (roles/resourcemanager.projectCreator), qui contient l'autorisation resourcemanager.projects.create. Découvrez comment attribuer des rôles.
    • Créez un Google Cloud projet :

      gcloud projects create PROJECT_ID

      Remplacez PROJECT_ID par un nom pour le Google Cloud projet que vous créez.

    • Sélectionnez le Google Cloud projet que vous avez créé :

      gcloud config set project PROJECT_ID

      Remplacez PROJECT_ID par le nom de votre Google Cloud projet.

  • Vérifiez que la facturation est activée pour votre Google Cloud projet.

  • Activez les API requises :

    Rôles requis pour activer les API

    Pour activer les API, vous avez besoin du rôle IAM Administrateur d'utilisation du service (roles/serviceusage.serviceUsageAdmin), qui contient l' serviceusage.services.enable autorisation. Découvrez comment attribuer des rôles.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Installez la Google Cloud CLI.

  • Si vous utilisez un fournisseur d'identité (IdP) externe, vous devez d'abord vous connecter à la gcloud CLI avec votre identité fédérée.

  • Pour initialiser la gcloud CLI, exécutez la commande suivante :

    gcloud init
  • Créez ou sélectionnez un Google Cloud projet.

    Rôles requis pour sélectionner ou créer un projet

    • Sélectionner un projet : la sélection d'un projet ne nécessite pas de rôle IAM spécifique Vous pouvez sélectionner n'importe quel projet pour lequel un rôle vous a été attribué.
    • Créer un projet : pour créer un projet, vous avez besoin du rôle Créateur de projet (roles/resourcemanager.projectCreator), qui contient l'autorisation resourcemanager.projects.create. Découvrez comment attribuer des rôles.
    • Créez un Google Cloud projet :

      gcloud projects create PROJECT_ID

      Remplacez PROJECT_ID par un nom pour le Google Cloud projet que vous créez.

    • Sélectionnez le Google Cloud projet que vous avez créé :

      gcloud config set project PROJECT_ID

      Remplacez PROJECT_ID par le nom de votre Google Cloud projet.

  • Vérifiez que la facturation est activée pour votre Google Cloud projet.

  • Activez les API requises :

    Rôles requis pour activer les API

    Pour activer les API, vous avez besoin du rôle IAM Administrateur d'utilisation du service (roles/serviceusage.serviceUsageAdmin), qui contient l' serviceusage.services.enable autorisation. Découvrez comment attribuer des rôles.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Attribuez des rôles à votre compte utilisateur. Exécutez la commande suivante une fois pour chacun des rôles IAM suivants : roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editor

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Remplacez les éléments suivants :

    • PROJECT_ID : ID de votre projet
    • USER_IDENTIFIER : identifiant de votre compte utilisateur Par exemple, myemail@example.com.
    • ROLE : rôle IAM que vous attribuez à votre compte utilisateur
  • Étant donné que ce tutoriel utilise TPU Trillium (v6e), sélectionnez une région ou une zone où il est disponible. Pour en savoir plus, consultez Quotas Cloud TPU.

Préparer votre environnement

Dans ce tutoriel, vous utilisez Cloud Shell. Cloud Shell est préinstallé sur les outils de ligne de commande gcloud, helm et kubectl utilisés dans ce tutoriel.

  1. Accédez à la Google Cloud console.

  2. En haut de la fenêtre de la Google Cloud console, cliquez sur le bouton Activer Cloud Shell Bouton d'activation de Cloud Shell.

    Une session Cloud Shell s'ouvre dans un nouveau cadre dans la Google Cloud console et affiche une invite de ligne de commande.

  3. Créez et activez un environnement virtuel Python :

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  4. Installez la CLI Ray et d'autres dépendances :

    pip install "ray[default]==2.49.1"
    
  5. Définissez les variables d'environnement suivantes :

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY
    

    Remplacez les éléments suivants :

    • GS_BUCKET: nom du bucket Cloud Storage
    • KSA_NAME: nom du compte de service Kubernetes
    • CLUSTER_NAME : nom du nouveau cluster
    • REGION: région où votre capacité TPU Trillium est disponible
    • ZONE: zone où votre capacité TPU Trillium est disponible Pour en savoir plus, consultez la section Disponibilité des TPU dans GKE.
    • ARTIFACT_REGISTRY : nom du dépôt Artifact Registry

Créer un cluster GKE

Vous pouvez configurer KubeRay sur des TPU dans un cluster GKE Autopilot ou GKE Standard. Nous vous recommandons d'utiliser un cluster GKE Autopilot pour une expérience Kubernetes entièrement gérée. Pour choisir le mode de fonctionnement GKE le mieux adapté à vos charges de travail, consultez À propos des modes de fonctionnement de GKE.

Autopilot

  1. Dans Cloud Shell, exécutez la commande suivante :

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION
    
  2. Pour communiquer avec votre cluster, configurez kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    

Standard

  1. Dans Cloud Shell, créez un cluster GKE Standard qui active le module complémentaire de l'opérateur Ray en exécutant la commande suivante :

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator \
        --addons GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE
    

    Cette commande active également GcsFuseCsiDriver, qui permet aux pods d'installer des buckets Cloud Storage en tant que systèmes de fichiers locaux. La création du cluster peut prendre plusieurs minutes.

  2. Pour communiquer avec votre cluster, configurez kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=LOCATION
    
  3. Créez un pool de nœuds de tranche TPU multi-hôtes :

    gcloud container node-pools create v6e-16 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4
    

GKE provisionne un pool de nœuds composé de quatre VM TPU Trillium (v6e), qui sont configurées ensemble en tant que tranche TPU multi-hôte, avec une topologie 4x4, prête pour les charges de travail d'entraînement distribué.

Le cluster GKE compatible avec l'opérateur Ray installe automatiquement KubeRay et le webhook KubeRay TPU dans votre cluster.

Configurer un bucket Cloud Storage et un compte de service

  1. Créez un bucket Cloud Storage pour les points de contrôle partagés entre les nœuds TPU multi-hôtes.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Pour activer l'accès au bucket Cloud Storage, créez un compte de service Kubernetes :

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Pour activer l'accès au bucket Cloud Storage, ajoutez les liaisons de stratégie IAM requises au compte de service :

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

Créer un script d'entraînement

Le script suivant utilise JaxTrainer de Ray Train pour exécuter une tâche d'entraînement MaxText distribuée. Le script configure l'environnement d'entraînement pour un pool de nœuds de tranche TPU multi-hôtes et exécute la tâche d'entraînement MaxText sur chaque nœud de calcul. La fonction train_loop_per_worker encapsule le point d'entrée principal de MaxText et utilise le planificateur distribué de Ray pour exécuter l'entraîneur MaxText sur une tranche TPU multi-hôte.

  1. Enregistrez le script Python suivant sous le nom maxtext_ray_trainer.py :

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        from MaxText.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                num_workers=4,
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                worker_runtime_env={
                    "env_vars": {
                        "JAX_PLATFORMS": "tpu",
                        "ENABLE_PJRT_COMPATIBILITY": "true",
                        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                    }
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
  2. Pour héberger l'image personnalisée, créez un dépôt Artifact Registry :

    gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \
        --repository-format=docker --location=${REGION} && \
    gcloud auth configure-docker ${REGION}-docker.pkg.dev
    
  3. Pour créer une image incluant les dépendances Ray et MaxText pour l'entraînement, créez un Dockerfile :

    # Start from a Ray base image which includes JaxTrainer API.
    # Maxtext with TPU requires Python 3.12.
    FROM rayproject/ray:2.49.1-py312
    
    USER root
    RUN groupadd -r ray 2>/dev/null || true && usermod -g ray ray
    
    RUN sudo apt-get update -y \
      && sudo apt-get install --no-install-recommends -y git \
      && sudo rm -rf /var/lib/apt/lists/*
    
    WORKDIR /app
    
    # Clone the Maxtext repo and build from source, installing TPU dependencies.
    RUN git clone https://github.com/AI-Hypercomputer/maxtext.git
    
    RUN pip install --no-cache-dir uv
    
    RUN cd maxtext && \
        uv pip install --no-cache --system -e .[tpu] --resolution=lowest && \
        install_maxtext_github_deps
    
    # Copy the Ray Maxtext trainer to run on the remote container.
    COPY maxtext_ray_trainer.py .
    
    RUN chown -R ray:ray .
    ENV PYTHONPATH=/app/maxtext/src:/app/maxtext:/app
    USER ray
  4. Créez, balisez et transmettez l'image Docker à Artifact Registry :

    export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest
    gcloud builds submit --tag ${DOCKER_IMAGE}
    

Entraîner le modèle

  1. Enregistrez l'exemple de fichier manifeste suivant sous le nom maxtext-tpu-cluster.yaml :

    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: ${DOCKER_IMAGE}
                imagePullPolicy: IfNotPresent
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: gcs-fuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs"
      workerGroupSpecs:
        - replicas: 1
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: {}
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              containers:
                - name: ray-worker
                  image: ${DOCKER_IMAGE}
                  imagePullPolicy: IfNotPresent
                  resources:
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: JAX_PLATFORMS
                      value: tpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: gcs-fuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs"
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4

    La spécification RayCluster précédente crée un groupe de nœuds de calcul TPU avec quatre nœuds de calcul (numOfHosts: 4) par instance répliquée. Chaque nœud de calcul demande quatre puces TPU (google.com/tpu: "4"). Les nœuds de calcul seront planifiés sur un nœud qui exécute TPU Trillium (tpu-v6e-slice) et qui fait partie de la même tranche multi-hôte colocalisée. KubeRay met à l'échelle les quatre nœuds de calcul de manière atomique, et les variables d'environnement JAX requises, ainsi que les affinités de pods pour la planification, sont amorcées par GKE via un webhook de mutation.

  2. Pour configurer les valeurs requises dans le fichier YAML, créez le RayCluster à l'aide de envsubst :

    envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -
    
  3. Vérifiez que le cluster est prêt et en cours d'exécution :

    kubectl get rayclusters maxtext-tpu-cluster
    

    La sortie devrait ressembler à ce qui suit :

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY        GPUS   STATUS   AGE
    maxtext-tpu-cluster   4                 4                   40     798027216Ki   0      ready    11m
    
  4. Pour accéder au tableau de bord Ray via le service principal Ray, établissez une session de transfert de port :

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. Vérifiez que le RayCluster est accessible depuis votre environnement local :

    ray list nodes --address http://localhost:8265
    

    La sortie devrait ressembler à ce qui suit :

    ======== List: 2025-09-13 03:53:16.988269 ========
    Stats:
    ------------------------------
    Total: 5
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP    IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                  LABELS
    0  92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56  10.84.0.9
    (...)
    
  6. Envoyez le script JaxTrainer au RayCluster et vérifiez que le RayJob s'exécute correctement :

    ray job submit \
      --address http://localhost:8265 \
      -- python /app/maxtext_ray_trainer.py \
          /app/maxtext/src/MaxText/configs/base.yml \
           base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=1 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-8b-4096-tp4-4x4
    

    La commande précédente envoie le script Python, qui appelle le code Ray JaxTrainer au RayCluster. La commande ray job submit inclut des arguments spécifiques à MaxText à transmettre à la configuration du modèle.

    Dans votre terminal, vous devriez obtenir une sortie semblable à celle-ci :

    (RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster]
    
    ------------------------------------------
    Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded
    ------------------------------------------
    

Effectuer un nettoyage

Pour éviter que les ressources utilisées dans ce tutoriel ne soient facturées sur votre Google Cloud compte, supprimez le projet contenant les ressources, ou conservez le projet et supprimez les ressources individuelles.

  1. Supprimez le RayCluster :

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. Supprimez le cluster GKE :

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Supprimez le bucket Cloud Storage :

    gsutil rm -r gs://${GS_BUCKET}
    
  4. Supprimez le dépôt Artifact Registry :

    gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
    

Étape suivante