Entrenar un LLM con JAX, Ray Train y TPU Trillium en GKE

En este tutorial se muestra cómo entrenar el modelo de lenguaje grande (LLM) Llama 3 8B en Google Kubernetes Engine (GKE) con MaxText, Ray Train y TPUs.

En este tutorial se explica todo el proceso, desde la configuración de la infraestructura en la nube necesaria hasta el envío y la ejecución correcta de la carga de trabajo de entrenamiento en las TPUs de varios hosts.

Este tutorial está dirigido a administradores y operadores de la plataforma, así como a especialistas en datos e IA que quieran aprender a entrenar modelos grandes en un slice de TPU distribuido y multihost.

Fondo

La combinación de GKE, KubeRay, MaxText y las TPUs proporciona una plataforma potente y escalable para el entrenamiento de modelos a gran escala. En esta sección se describen las tecnologías clave que se usan en esta guía:

JAX

JAX es una biblioteca de Python para la computación de arrays orientada a aceleradores y la transformación de programas, diseñada para la computación numérica de alto rendimiento y el aprendizaje automático a gran escala.

JAX proporciona un sistema extensible para transformar funciones numéricas como jax.grad, jax.jit y jax.vmap. Utiliza el compilador XLA para crear código altamente optimizado que se adapta de forma eficiente a aceleradores como GPUs y TPUs. La principal ventaja de JAX es su capacidad de composición, que permite a los usuarios combinar estas transformaciones para crear programas numéricos complejos y de alto rendimiento para la ejecución distribuida.

MaxText

MaxText es un modelo de lenguaje extenso (LLM) de código abierto y alto rendimiento diseñado para ofrecer escalabilidad y personalización. MaxText se basa en JAX y se ha optimizado para ejecutarse de forma eficiente en las TPU y las GPUs de Cloud.

TPUs

Las unidades de procesamiento de tensor (TPUs) son aceleradores diseñados a medida por Google para optimizar las cargas de trabajo de aprendizaje automático. A diferencia de las CPUs de uso general o las GPUs de procesamiento paralelo, las TPUs están altamente especializadas en las enormes computaciones de matrices y tensores que constituyen la base del aprendizaje profundo, lo que las hace eficientes en esta tarea específica. La principal ventaja de las TPUs es el rendimiento a gran escala.

En este tutorial se usa la TPU Trillium, que es la sexta generación de TPUs. Para obtener más información, consulta Ventajas de usar la TPU Trillium.

KubeRay

KubeRay es un operador de Kubernetes que proporciona una forma unificada de desplegar, gestionar y monitorizar aplicaciones de Ray en Kubernetes. El operador KubeRay se instala y gestiona a través del complemento Ray on GKE, que es la forma recomendada de desplegar y gestionar clústeres de Ray en GKE.

Objetivos

En este tutorial te explicamos cómo hacer lo siguiente:

  1. Configura un clúster de GKE con un grupo de nodos de TPU de varios hosts.
  2. Configura KubeRay para gestionar el entorno de entrenamiento distribuido.
  3. Crea una imagen de Docker personalizada que contenga las dependencias de MaxText, Ray y JAX.
  4. Crea una secuencia de comandos de entrenamiento de Python que use JaxTrainer de Ray Train para orquestar el bucle de entrenamiento de MaxText en el segmento de TPU.
  5. Define un RayCluster recurso personalizado para aprovisionar los nodos principales y de trabajo con los recursos de TPU necesarios.
  6. Envía el trabajo de entrenamiento a RayCluster y monitoriza su progreso.
  7. Usa Cloud Storage para almacenar los puntos de control del modelo.

Antes de empezar

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • Install the Google Cloud CLI.

  • Si utilizas un proveedor de identidades (IdP) externo, primero debes iniciar sesión en la CLI de gcloud con tu identidad federada.

  • Para inicializar gcloud CLI, ejecuta el siguiente comando:

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Install the Google Cloud CLI.

  • Si utilizas un proveedor de identidades (IdP) externo, primero debes iniciar sesión en la CLI de gcloud con tu identidad federada.

  • Para inicializar gcloud CLI, ejecuta el siguiente comando:

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Grant roles to your user account. Run the following command once for each of the following IAM roles: roles/container.admin, roles/iam.serviceAccountAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Replace the following:

    • PROJECT_ID: Your project ID.
    • USER_IDENTIFIER: The identifier for your user account. For example, myemail@example.com.
    • ROLE: The IAM role that you grant to your user account.
  • Como en este tutorial se utiliza la TPU Trillium (v6e), selecciona una región o una zona que esté disponible. Para obtener más información, consulta las cuotas de Cloud TPU.

Prepara tu entorno

En este tutorial, usarás Cloud Shell. Cloud Shell tiene preinstaladas las herramientas de línea de comandos gcloud, helm y kubectl que se usan en este tutorial.

  1. Ve a la Google Cloud consola.

  2. En la parte superior de la ventana de la consola Google Cloud , haz clic en el botón Activar Cloud Shell Activar Shell
Button.

    Se abrirá una sesión de Cloud Shell dentro de un nuevo marco en laGoogle Cloud consola y se mostrará en ella un mensaje de la línea de comandos.

  3. Crea y activa un entorno virtual de Python:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  4. Instala la CLI de Ray y otras dependencias:

    pip install "ray[default]==2.49.1"
    
  5. Define las siguientes variables de entorno:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY
    

    Haz los cambios siguientes:

    • GS_BUCKET: el nombre del segmento de Cloud Storage.
    • KSA_NAME: el nombre de la cuenta de servicio de Kubernetes.
    • CLUSTER_NAME: el nombre del nuevo clúster.
    • REGION: la región en la que está disponible tu capacidad de TPU Trillium.
    • ZONE: la zona en la que está disponible tu capacidad de TPU Trillium. Para obtener más información, consulta Disponibilidad de las TPU en GKE.
    • ARTIFACT_REGISTRY: nombre del repositorio de Artifact Registry.

Crear un clúster de GKE

Puedes configurar KubeRay en TPUs en un clúster Autopilot o Standard de GKE. Te recomendamos que uses un clúster de Autopilot para disfrutar de una experiencia de Kubernetes totalmente gestionada. Para elegir el modo de funcionamiento de GKE que mejor se adapte a tus cargas de trabajo, consulta Acerca de los modos de funcionamiento de GKE.

Autopilot

  1. En Cloud Shell, ejecuta el siguiente comando:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION
    
  2. Para comunicarte con tu clúster, configura kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    

Estándar

  1. En Cloud Shell, crea un clúster estándar que habilite el complemento operador de Ray ejecutando el siguiente comando:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator \
        --addons GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE
    

    Este comando también habilita GcsFuseCsiDriver, que permite que los pods monten segmentos de Cloud Storage como sistemas de archivos locales. La creación del clúster puede tardar varios minutos.

  2. Para comunicarte con tu clúster, configura kubectl:

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=LOCATION
    
  3. Crea un grupo de nodos de un slice de TPU de varios hosts:

    gcloud container node-pools create v6e-16 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4
    

GKE aprovisiona un grupo de nodos formado por cuatro máquinas virtuales de TPU Trillium (v6e), que se configuran juntas como un segmento de TPU multihost con una 4x4topología, que está listo para cargas de trabajo de entrenamiento distribuido.

El clúster de GKE con el operador de Ray habilitado instala automáticamente KubeRay y el webhook de TPU de KubeRay en tu clúster.

Configurar un segmento de Cloud Storage y una cuenta de servicio

  1. Crea un segmento de Cloud Storage para los puntos de control compartidos entre los nodos de TPU de varios hosts.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Para habilitar el acceso al segmento de Cloud Storage, crea una cuenta de servicio de Kubernetes:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Para habilitar el acceso al segmento de Cloud Storage, añade los enlaces de política de gestión de identidades y accesos necesarios a la cuenta de servicio:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

Crear la secuencia de comandos de entrenamiento

La siguiente secuencia de comandos usa JaxTrainer de Ray Train para ejecutar una tarea de entrenamiento distribuida de MaxText. La secuencia de comandos configura el entorno de entrenamiento para un grupo de nodos de segmento de TPU de varios hosts y ejecuta la tarea de entrenamiento de MaxText en cada nodo de trabajador. La función train_loop_per_worker envuelve el punto de entrada principal de MaxText y usa el programador distribuido de Ray para ejecutar el entrenador de MaxText en un segmento de TPU multihost.

  1. Guarda la siguiente secuencia de comandos de Python como maxtext_ray_trainer.py:

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        from MaxText.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                num_workers=4,
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                worker_runtime_env={
                    "env_vars": {
                        "JAX_PLATFORMS": "tpu",
                        "ENABLE_PJRT_COMPATIBILITY": "true",
                        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                    }
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
  2. Para alojar la imagen personalizada, cree un repositorio de Artifact Registry:

    gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \
        --repository-format=docker --location=${REGION} && \
    gcloud auth configure-docker ${REGION}-docker.pkg.dev
    
  3. Para crear una imagen que incluya las dependencias de Ray y MaxText para el entrenamiento, crea un Dockerfile:

    # Start from a Ray base image which includes JaxTrainer API.
    # Maxtext with TPU requires Python 3.12.
    FROM rayproject/ray:2.49.1-py312
    
    USER root
    RUN groupadd -r ray 2>/dev/null || true && usermod -g ray ray
    
    RUN sudo apt-get update -y \
      && sudo apt-get install --no-install-recommends -y git \
      && sudo rm -rf /var/lib/apt/lists/*
    
    WORKDIR /app
    
    # Clone the Maxtext repo and build from source, installing TPU dependencies.
    RUN git clone https://github.com/AI-Hypercomputer/maxtext.git
    
    RUN pip install --no-cache-dir uv
    
    RUN cd maxtext && \
        uv pip install --no-cache --system -e .[tpu] --resolution=lowest && \
        install_maxtext_github_deps
    
    # Copy the Ray Maxtext trainer to run on the remote container.
    COPY maxtext_ray_trainer.py .
    
    RUN chown -R ray:ray .
    ENV PYTHONPATH=/app/maxtext/src:/app/maxtext:/app
    USER ray
  4. Crea, etiqueta y envía la imagen Docker a Artifact Registry:

    export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest
    gcloud builds submit --tag ${DOCKER_IMAGE}
    

Preparar el modelo

  1. Guarda el siguiente archivo de manifiesto de ejemplo como maxtext-tpu-cluster.yaml:

    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: ${DOCKER_IMAGE}
                imagePullPolicy: IfNotPresent
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: gcs-fuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs"
      workerGroupSpecs:
        - replicas: 1
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: {}
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              containers:
                - name: ray-worker
                  image: ${DOCKER_IMAGE}
                  imagePullPolicy: IfNotPresent
                  resources:
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: JAX_PLATFORMS
                      value: tpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: gcs-fuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs"
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4

    La especificación de RayCluster anterior crea un grupo de trabajadores de TPU con cuatro trabajadores (numOfHosts: 4) por réplica. Cada trabajador solicita cuatro chips de TPU (google.com/tpu: "4"). Los trabajadores se programarán en un nodo que ejecute TPU Trillium (tpu-v6e-slice) y que forme parte del mismo segmento multihost colocado. KubeRay escala los cuatro trabajadores de forma atómica y GKE inicializa las variables de entorno de JAX necesarias, así como las afinidades de pods para la programación, mediante un webhook de mutación.

  2. Para configurar los valores necesarios en el archivo YAML, crea el RayCluster con envsubst:

    envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -
    
  3. Comprueba que el clúster esté listo y en ejecución:

    kubectl get rayclusters maxtext-tpu-cluster
    

    La salida debería ser similar a la siguiente:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY        GPUS   STATUS   AGE
    maxtext-tpu-cluster   4                 4                   40     798027216Ki   0      ready    11m
    
  4. Para acceder al panel de control de Ray a través del servicio principal de Ray, establece una sesión de reenvío de puertos:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. Verifica que se puede acceder a RayCluster desde tu entorno local:

    ray list nodes --address http://localhost:8265
    

    La salida debería ser similar a la siguiente:

    ======== List: 2025-09-13 03:53:16.988269 ========
    Stats:
    ------------------------------
    Total: 5
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP    IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                  LABELS
    0  92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56  10.84.0.9
    (...)
    
  6. Envía la secuencia de comandos JaxTrainer a RayCluster y comprueba que RayJob se completa correctamente:

    ray job submit \
      --address http://localhost:8265 \
      -- python /app/maxtext_ray_trainer.py \
          /app/maxtext/src/MaxText/configs/base.yml \
           base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=1 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-8b-4096-tp4-4x4
    

    El comando anterior envía la secuencia de comandos de Python, que llama al código JaxTrainer Ray al clúster de Ray. El comando ray job submit incluye algunos argumentos específicos de MaxText que se deben transferir a la configuración del modelo.

    En el terminal, deberías ver un resultado similar al siguiente:

    (RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster]
    
    ------------------------------------------
    Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded
    ------------------------------------------
    

Limpieza

Para evitar que se apliquen cargos en tu Google Cloud cuenta por los recursos utilizados en este tutorial, elimina el proyecto que contiene los recursos o conserva el proyecto y elimina los recursos.

  1. Elimina el RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. Elimina el clúster de GKE:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Elimina el segmento de Cloud Storage:

    gsutil rm -r gs://${GS_BUCKET}
    
  4. Elimina el repositorio de Artifact Registry:

    gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
    

Siguientes pasos