Entraîner un LLM à l'aide de JAX, Ray Train et TPU Trillium sur GKE

Ce tutoriel vous explique comment entraîner le grand modèle de langage (LLM) Llama 3 8B sur Google Kubernetes Engine (GKE) à l'aide de MaxText, Ray Train et des TPU.

Ce tutoriel fournit un guide complet de bout en bout, de la configuration de l'infrastructure cloud nécessaire à l'envoi et à l'exécution réussie de la charge de travail d'entraînement sur des TPU multi-hôtes.

Ce tutoriel s'adresse aux administrateurs et opérateurs de plate-forme, ainsi qu'aux spécialistes des données et de l'IA qui souhaitent apprendre à entraîner de grands modèles sur une tranche de TPU distribuée et multi-hôte.

Arrière-plan

La combinaison de GKE, KubeRay, MaxText et des TPU fournit une plate-forme puissante et évolutive pour l'entraînement de modèles à grande échelle. Cette section décrit les principales technologies utilisées dans ce guide :

JAX

JAX est une bibliothèque Python pour le calcul de tableaux et la transformation de programmes orientés accélérateur, conçue pour le calcul numérique hautes performances et le machine learning à grande échelle.

JAX fournit un système extensible pour transformer des fonctions numériques telles que jax.grad, jax.jit et jax.vmap, en utilisant le compilateur XLA pour créer du code hautement optimisé qui s'adapte efficacement aux accélérateurs tels que les GPU et les TPU. La puissance de JAX réside dans sa composabilité, qui permet aux utilisateurs de combiner ces transformations pour créer des programmes numériques complexes et performants pour l'exécution distribuée.

MaxText

MaxText est un grand modèle de langage (LLM) Open Source hautes performances conçu pour l'évolutivité et la personnalisation. MaxText est basé sur JAX et optimisé pour s'exécuter efficacement sur les Cloud TPU et les GPU.

TPU

Les Tensor Processing Units (TPU) sont des accélérateurs conçus sur mesure par Google pour optimiser les charges de travail de machine learning. Contrairement aux CPU à usage général ou aux GPU à traitement parallèle, les TPU sont hautement spécialisés dans les calculs matriciels et tensoriels massifs qui sont à la base du deep learning, ce qui les rend efficaces pour cette tâche spécifique. L'avantage principal des TPU est leur capacité à offrir des performances à grande échelle.

Ce tutoriel utilise le TPU Trillium, qui est la sixième génération de TPU. Pour en savoir plus, consultez Avantages de l'utilisation de TPU Trillium.

KubeRay

KubeRay est un opérateur Kubernetes qui fournit une méthode unifiée pour déployer, gérer et surveiller les applications Ray sur Kubernetes. L'opérateur KubeRay est installé et géré via le module complémentaire Ray sur GKE, qui est la méthode recommandée pour déployer et gérer les clusters Ray sur GKE.

Objectifs

Ce tutoriel vous explique comment effectuer les tâches suivantes :

  1. Configurez un cluster GKE avec un pool de nœuds TPU multi-hôtes.
  2. Configurez KubeRay pour gérer l'environnement d'entraînement distribué.
  3. Créez une image Docker personnalisée contenant les dépendances MaxText, Ray et JAX.
  4. Créez un script d'entraînement Python qui utilise JaxTrainer de Ray Train pour orchestrer la boucle d'entraînement MaxText dans la tranche de TPU.
  5. Définissez une ressource personnalisée RayCluster pour provisionner les nœuds principaux et de calcul avec les ressources TPU nécessaires.
  6. Envoyez le job d'entraînement à RayCluster et surveillez sa progression.
  7. Utilisez Cloud Storage pour stocker les points de contrôle du modèle.

Avant de commencer

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • Install the Google Cloud CLI.

  • Si vous utilisez un fournisseur d'identité (IdP) externe, vous devez d'abord vous connecter à la gcloud CLI avec votre identité fédérée.

  • Pour initialiser la gcloud CLI, exécutez la commande suivante :

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Install the Google Cloud CLI.

  • Si vous utilisez un fournisseur d'identité (IdP) externe, vous devez d'abord vous connecter à la gcloud CLI avec votre identité fédérée.

  • Pour initialiser la gcloud CLI, exécutez la commande suivante :

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Grant roles to your user account. Run the following command once for each of the following IAM roles: roles/container.admin, roles/iam.serviceAccountAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Replace the following:

    • PROJECT_ID: Your project ID.
    • USER_IDENTIFIER: The identifier for your user account. For example, myemail@example.com.
    • ROLE: The IAM role that you grant to your user account.
  • Étant donné que ce tutoriel utilise des TPU Trillium (v6e), sélectionnez une région ou une zone où ils sont disponibles. Pour en savoir plus, consultez Quotas Cloud TPU.

Préparer votre environnement

Dans ce tutoriel, vous utilisez Cloud Shell. Cloud Shell est préinstallé sur les outils de ligne de commande gcloud, helm et kubectl utilisés dans ce tutoriel.

  1. Accédez à la consoleGoogle Cloud .

  2. En haut de la fenêtre de la console Google Cloud , cliquez sur le bouton Activer Cloud Shell Bouton d'activation de Cloud Shell.

    Une session Cloud Shell s'ouvre dans un nouveau cadre dans la console Google Cloud et affiche une invite de ligne de commande.

  3. Créez et activez un environnement virtuel Python :

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  4. Installez la CLI Ray et les autres dépendances :

    pip install "ray[default]==2.49.1"
    
  5. Définissez les variables d'environnement suivantes :

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY
    

    Remplacez les éléments suivants :

    • GS_BUCKET : nom du bucket Cloud Storage.
    • KSA_NAME : nom du compte de service Kubernetes.
    • CLUSTER_NAME : nom du nouveau cluster
    • REGION : région dans laquelle votre capacité de TPU Trillium est disponible.
    • ZONE : zone dans laquelle votre capacité TPU Trillium est disponible. Pour en savoir plus, consultez la section Disponibilité des TPU dans GKE.
    • ARTIFACT_REGISTRY : nom du dépôt Artifact Registry.

Créer un cluster GKE

Vous pouvez configurer KubeRay sur des TPU dans un cluster GKE Autopilot ou Standard. Nous vous recommandons d'utiliser un cluster Autopilot pour une expérience Kubernetes entièrement gérée. Pour choisir le mode de fonctionnement GKE le mieux adapté à vos charges de travail, consultez À propos des modes de fonctionnement GKE.

Autopilot

  1. Dans Cloud Shell, exécutez la commande suivante :

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION
    
  2. Pour communiquer avec votre cluster, configurez kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    

Standard

  1. Dans Cloud Shell, créez un cluster Standard qui active le module complémentaire Ray Operator en exécutant la commande suivante :

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator \
        --addons GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE
    

    Cette commande active également GcsFuseCsiDriver, ce qui permet aux pods d'installer des buckets Cloud Storage en tant que systèmes de fichiers locaux. La création du cluster peut prendre plusieurs minutes.

  2. Pour communiquer avec votre cluster, configurez kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=LOCATION
    
  3. Créez un pool de nœuds de tranche TPU multi-hôte :

    gcloud container node-pools create v6e-16 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4
    

GKE provisionne un pool de nœuds composé de quatre VM TPU Trillium (v6e), qui sont configurées ensemble en tant que tranche TPU multi-hôte, avec une topologie 4x4, prête pour les charges de travail d'entraînement distribué.

Le cluster GKE sur lequel le Ray Operator est activé installe automatiquement KubeRay et le webhook KubeRay TPU dans votre cluster.

Configurer un bucket Cloud Storage et un compte de service

  1. Créez un bucket Cloud Storage pour les points de contrôle partagés entre les nœuds TPU multihôtes.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Pour activer l'accès au bucket Cloud Storage, créez un compte de service Kubernetes :

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Pour activer l'accès au bucket Cloud Storage, ajoutez les liaisons de stratégie IAM requises au compte de service :

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

Créer un script d'entraînement

Le script suivant utilise JaxTrainer de Ray Train pour exécuter un job d'entraînement MaxText distribué. Le script configure l'environnement d'entraînement pour un pool de nœuds de tranche TPU multi-hôtes et exécute le job d'entraînement MaxText sur chaque nœud de calcul. La fonction train_loop_per_worker encapsule le point d'entrée principal de MaxText et utilise le planificateur distribué de Ray pour exécuter l'entraîneur MaxText sur une tranche TPU multihôte.

  1. Enregistrez le script Python suivant sous le nom maxtext_ray_trainer.py :

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        from MaxText.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                num_workers=4,
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                worker_runtime_env={
                    "env_vars": {
                        "JAX_PLATFORMS": "tpu",
                        "ENABLE_PJRT_COMPATIBILITY": "true",
                        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                    }
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
  2. Pour héberger l'image personnalisée, créez un dépôt Artifact Registry :

    gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \
        --repository-format=docker --location=${REGION} && \
    gcloud auth configure-docker ${REGION}-docker.pkg.dev
    
  3. Pour créer une image incluant les dépendances Ray et MaxText pour l'entraînement, créez un fichier Dockerfile :

    # Start from a Ray base image which includes JaxTrainer API.
    # Maxtext with TPU requires Python 3.12.
    FROM rayproject/ray:2.49.1-py312
    
    USER root
    RUN groupadd -r ray 2>/dev/null || true && usermod -g ray ray
    
    RUN sudo apt-get update -y \
      && sudo apt-get install --no-install-recommends -y git \
      && sudo rm -rf /var/lib/apt/lists/*
    
    WORKDIR /app
    
    # Clone the Maxtext repo and build from source, installing TPU dependencies.
    RUN git clone https://github.com/AI-Hypercomputer/maxtext.git
    
    RUN pip install --no-cache-dir uv
    
    RUN cd maxtext && \
        uv pip install --no-cache --system -e .[tpu] --resolution=lowest && \
        install_maxtext_github_deps
    
    # Copy the Ray Maxtext trainer to run on the remote container.
    COPY maxtext_ray_trainer.py .
    
    RUN chown -R ray:ray .
    ENV PYTHONPATH=/app/maxtext/src:/app/maxtext:/app
    USER ray
  4. Créez l'image Docker, ajoutez-y un tag et transférez-la vers Artifact Registry :

    export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest
    gcloud builds submit --tag ${DOCKER_IMAGE}
    

Entraîner le modèle

  1. Enregistrez l'exemple de fichier manifeste suivant sous le nom maxtext-tpu-cluster.yaml :

    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: ${DOCKER_IMAGE}
                imagePullPolicy: IfNotPresent
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: gcs-fuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs"
      workerGroupSpecs:
        - replicas: 1
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: {}
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              containers:
                - name: ray-worker
                  image: ${DOCKER_IMAGE}
                  imagePullPolicy: IfNotPresent
                  resources:
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: JAX_PLATFORMS
                      value: tpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: gcs-fuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs"
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4

    La spécification RayCluster précédente crée un groupe de nœuds de calcul TPU avec quatre nœuds de calcul (numOfHosts: 4) par réplica. Chaque nœud de calcul demande quatre puces TPU (google.com/tpu: "4"). Les nœuds de calcul seront planifiés sur un nœud qui exécute TPU Trillium (tpu-v6e-slice) et qui fait partie de la même tranche multi-hôte colocalisée. KubeRay met à l'échelle les quatre nœuds de calcul de manière atomique. Les variables d'environnement JAX requises, ainsi que les affinités de pod pour la planification, sont amorcées par GKE via un webhook de mutation.

  2. Pour configurer les valeurs requises dans le fichier YAML, créez le RayCluster à l'aide de envsubst :

    envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -
    
  3. Vérifiez que le cluster est prêt et en cours d'exécution :

    kubectl get rayclusters maxtext-tpu-cluster
    

    La sortie devrait ressembler à ce qui suit :

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY        GPUS   STATUS   AGE
    maxtext-tpu-cluster   4                 4                   40     798027216Ki   0      ready    11m
    
  4. Pour accéder au tableau de bord Ray via le service principal Ray, établissez une session de transfert de port :

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. Vérifiez que le RayCluster est accessible depuis votre environnement local :

    ray list nodes --address http://localhost:8265
    

    La sortie devrait ressembler à ce qui suit :

    ======== List: 2025-09-13 03:53:16.988269 ========
    Stats:
    ------------------------------
    Total: 5
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP    IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                  LABELS
    0  92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56  10.84.0.9
    (...)
    
  6. Envoyez le script JaxTrainer au RayCluster et vérifiez que le RayJob s'est terminé correctement :

    ray job submit \
      --address http://localhost:8265 \
      -- python /app/maxtext_ray_trainer.py \
          /app/maxtext/src/MaxText/configs/base.yml \
           base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=1 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-8b-4096-tp4-4x4
    

    La commande précédente envoie le script Python, qui appelle le code JaxTrainer Ray au RayCluster. La commande ray job submit inclut des arguments spécifiques à MaxText à transmettre à la configuration du modèle.

    Dans votre terminal, vous devriez obtenir un résultat semblable à celui-ci :

    (RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster]
    
    ------------------------------------------
    Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded
    ------------------------------------------
    

Effectuer un nettoyage

Pour éviter que les ressources utilisées dans ce tutoriel ne soient facturées sur votre compte Google Cloud , supprimez le projet contenant les ressources, ou conservez le projet et supprimez les ressources individuelles.

  1. Supprimez le RayCluster :

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. Supprimez le cluster GKE :

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Supprimez le bucket Cloud Storage :

    gsutil rm -r gs://${GS_BUCKET}
    
  4. Supprimez le dépôt Artifact Registry :

    gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
    

Étapes suivantes