Entrenamiento elástico y de Multislice en TPU con Ray Train en GKE

En este instructivo, se muestra cómo entrenar modelos de lenguaje grandes (LLM) como Llama 3 70B en Google Kubernetes Engine (GKE) con MaxText, Ray Train y TPU Trillium de Multislice. En este instructivo, se proporciona una guía completa de extremo a extremo, desde la configuración de las redes secundarias necesarias del centro de datos hasta el envío y la ejecución exitosa de una carga de trabajo de entrenamiento distribuido en 32 chips de TPU físicos.

Este instructivo está dirigido a los administradores, operadores y especialistas en IA de la plataforma que desean aprender a superar los desafíos de memoria y redes del entrenamiento de modelos con 70,000 millones de parámetros en porciones de TPU distribuidas y de varios hosts.

Fondo

La combinación de GKE, KubeRay, MaxText y las TPU proporciona una plataforma potente y escalable para el entrenamiento de modelos a gran escala. En esta sección, se describen las tecnologías clave que se usan en esta guía:

JAX

JAX es una biblioteca de Python para la transformación de programas y el procesamiento de arrays orientados a aceleradores, que utiliza el compilador XLA para crear código altamente optimizado que se adapta de manera eficiente a los aceleradores.

MaxText

MaxText es un marco de trabajo de LLM de código abierto y alto rendimiento diseñado para la escalabilidad y la personalización. MaxText se basa en JAX y se optimizó para ejecutarse de manera eficiente en Cloud TPU.

TPU

Las unidades de procesamiento tensorial (TPU) son aceleradores diseñados de forma personalizada y creados por Google para optimizar las cargas de trabajo de aprendizaje automático. A diferencia de las CPU de uso general o las GPU de procesamiento paralelo, las TPU están altamente especializadas para los cálculos masivos de matrices y tensores que son la base del aprendizaje profundo, lo que las hace eficientes para esta tarea específica. La principal ventaja de las TPU es el rendimiento a gran escala.

En este instructivo, se usa TPU Trillium, la sexta generación de TPU, en un patrón de implementación de Multislice. Cloud TPU Multislice es donde dos o más porciones de Cloud TPU se comunican a través de la red del centro de datos (DCN). Multislice permite el entrenamiento de pila completa, rentable y a gran escala con escalamiento casi lineal hasta decenas de miles de chips TPU. Para obtener más información sobre Multislice, consulta Descripción general de Cloud TPU Multislice.

KubeRay

KubeRay es un operador de Kubernetes que proporciona una forma unificada de implementar, administrar y supervisar aplicaciones de Ray en Kubernetes. El operador de KubeRay se instala y administra a través del complemento Ray en GKE, que es la forma recomendada de implementar y administrar clústeres de Ray en GKE.

Red de asignación dinámica de recursos de GKE (DRANET)

GKE DRANET (red de asignación dinámica de recursos) es una función que conecta de forma dinámica dispositivos de red de alto rendimiento a los Pods, lo que omite las redes estándar de Kubernetes y habilita el alto rendimiento en la DCN.

Objetivos

En este instructivo, se muestra cómo realizar lo siguiente:

  1. Configura un clúster de GKE con dos grupos de nodo TPU de varios hosts.
  2. Configura una DCN secundaria para la comunicación entre porciones de TPU.
  3. Configura KubeRay para administrar el entorno de entrenamiento distribuido.
  4. Implementa un recurso personalizado de RayCluster con la asignación dinámica de recursos (DRA) para las vinculaciones de red.
  5. Crea una secuencia de comandos de entrenamiento de Python con JaxTrainer de Ray Train para coordinar el bucle de entrenamiento de MaxText en las porciones de TPU.
  6. Ejecuta un trabajo de entrenamiento del modelo de referencia Llama 3 8B.
  7. Escala verticalmente hasta Llama 3 70B con la fragmentación 2D (paralelismo de tensores y FSDP) en la DCN.

Antes de comenzar

  • Accede a tu cuenta de Google Cloud . Si eres nuevo en Google Cloud, crea una cuenta para evaluar el rendimiento de nuestros productos en situaciones reales. Los clientes nuevos también obtienen $300 en créditos gratuitos para ejecutar, probar y, además, implementar cargas de trabajo.
  • Instala Google Cloud CLI.

  • Si usas un proveedor de identidad externo (IdP), primero debes acceder a la gcloud CLI con tu identidad federada.

  • Para inicializar gcloud CLI, ejecuta el siguiente comando:

    gcloud init
  • Crea o selecciona un Google Cloud proyecto.

    Roles necesarios para seleccionar o crear un proyecto

    • Selecciona un proyecto: Para seleccionar un proyecto, no se requiere un rol de IAM específico. Puedes seleccionar cualquier proyecto en el que se te haya otorgado un rol.
    • Crear un proyecto: Para crear un proyecto, necesitas el rol de Creador de proyectos (roles/resourcemanager.projectCreator), que contiene el permiso resourcemanager.projects.create. Obtén más información para otorgar roles.
    • Crea un Google Cloud proyecto:

      gcloud projects create PROJECT_ID

      Reemplaza PROJECT_ID por un nombre para el proyecto Google Cloud que estás creando.

    • Selecciona el proyecto Google Cloud que creaste:

      gcloud config set project PROJECT_ID

      Reemplaza PROJECT_ID por el nombre de tu Google Cloud proyecto.

  • Verifica que la facturación esté habilitada para tu proyecto de Google Cloud .

  • Habilita las APIs necesarias:

    Roles necesarios para habilitar las APIs

    Para habilitar las APIs, necesitas el rol de IAM de administrador de Service Usage (roles/serviceusage.serviceUsageAdmin), que contiene el permiso serviceusage.services.enable. Obtén más información para otorgar roles.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Instala Google Cloud CLI.

  • Si usas un proveedor de identidad externo (IdP), primero debes acceder a la gcloud CLI con tu identidad federada.

  • Para inicializar gcloud CLI, ejecuta el siguiente comando:

    gcloud init
  • Crea o selecciona un Google Cloud proyecto.

    Roles necesarios para seleccionar o crear un proyecto

    • Selecciona un proyecto: Para seleccionar un proyecto, no se requiere un rol de IAM específico. Puedes seleccionar cualquier proyecto en el que se te haya otorgado un rol.
    • Crear un proyecto: Para crear un proyecto, necesitas el rol de Creador de proyectos (roles/resourcemanager.projectCreator), que contiene el permiso resourcemanager.projects.create. Obtén más información para otorgar roles.
    • Crea un Google Cloud proyecto:

      gcloud projects create PROJECT_ID

      Reemplaza PROJECT_ID por un nombre para el proyecto Google Cloud que estás creando.

    • Selecciona el proyecto Google Cloud que creaste:

      gcloud config set project PROJECT_ID

      Reemplaza PROJECT_ID por el nombre de tu Google Cloud proyecto.

  • Verifica que la facturación esté habilitada para tu proyecto de Google Cloud .

  • Habilita las APIs necesarias:

    Roles necesarios para habilitar las APIs

    Para habilitar las APIs, necesitas el rol de IAM de administrador de Service Usage (roles/serviceusage.serviceUsageAdmin), que contiene el permiso serviceusage.services.enable. Obtén más información para otorgar roles.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Otorga roles a tu cuenta de usuario. Ejecuta el siguiente comando una vez para cada uno de los siguientes roles de IAM: roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editor

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Reemplaza lo siguiente:

    • PROJECT_ID: ID del proyecto
    • USER_IDENTIFIER: Es el identificador de tu cuenta de usuario de . Por ejemplo, myemail@example.com.
    • ROLE: Es el rol de IAM que otorgas a tu cuenta de usuario.
  • Como en este instructivo se usa la TPU Trillium (v6e), selecciona una región o zona con disponibilidad. Para obtener más información, consulta Cuotas de Cloud TPU.

Prepara el entorno

En este instructivo, usarás Cloud Shell. Cloud Shell viene preinstalado con las herramientas de línea de comandos de gcloud, helm y kubectl que se usan en este instructivo.

  1. Ve a la consola deGoogle Cloud .

  2. En la parte superior de la Google Cloud ventana de la consola, haz clic en el botón Activar Cloud Shell Botón de activar Shell.

    Se abrirá una sesión de Cloud Shell en un marco nuevo en la consola deGoogle Cloud , y se mostrará una ventana de línea de comandos.

  3. En tu terminal, clona el repositorio kubernetes-engine-samples:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    
  4. Cambia al directorio que contiene los archivos de muestra:

    cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtext
    
  5. Crea y activa un entorno virtual de Python:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  6. Instala la CLI de Ray:

    pip install "ray[default]==2.55.0"
    
  7. Configura las siguientes variables de entorno:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export CLUSTER_VERSION=1.35.2-gke.1842000
    

    Reemplaza lo siguiente:

    • GS_BUCKET: Es el nombre del bucket de Cloud Storage.
    • KSA_NAME: Es el nombre de la cuenta de servicio de Kubernetes.
    • CLUSTER_NAME es el nombre del clúster nuevo.
    • REGION: Es la región en la que está disponible tu capacidad de TPU Trillium.
    • ZONE: Es la zona en la que está disponible tu capacidad de TPU Trillium. Para obtener más información, consulta la disponibilidad de TPU en GKE.

Configura la red del clúster para Cloud TPU Multislice

Dentro de una porción de TPU de varios hosts, los dispositivos de TPU se comunican a través de las interconexiones entre chips de alta velocidad. Sin embargo, cuando se ejecutan trabajos de Multislice, las porciones de TPU deben comunicarse entre sí a través de la DCN. Las redes de Pod de Kubernetes estándar pueden generar un cuello de botella en este tráfico. El tipo de máquina ct6e-standard-4t está respaldado por varias tarjetas de interfaz de red (NIC) físicas. Para lograr el mejor rendimiento, crea dos redes de VPC adicionales y usa GKE DRANET para conectarlas directamente a los Pods de Ray.

  1. Crea las dos redes de VPC adicionales con una unidad de transmisión máxima (MTU) grande:

    gcloud compute networks create ${CLUSTER_NAME}-net-1 \
      --subnet-mode=custom \
      --mtu=8896
    
    gcloud compute networks create ${CLUSTER_NAME}-net-2 \
      --subnet-mode=custom \
      --mtu=8896
    
  2. Crea las subredes dedicadas:

    gcloud compute networks subnets create tpu-subnet-1 \
      --network=${CLUSTER_NAME}-net-1 \
      --region=${REGION} \
      --range=10.50.0.0/16
    
    gcloud compute networks subnets create tpu-subnet-2 \
      --network=${CLUSTER_NAME}-net-2 \
      --region=${REGION} \
      --range=10.60.0.0/16
    

Crea un clúster de GKE

Puedes configurar KubeRay en TPU en un clúster de GKE Autopilot o Standard. Te recomendamos que uses un clúster de Autopilot para una experiencia de Kubernetes completamente administrada. Para elegir el modo de operación de GKE que se adapte mejor a tus cargas de trabajo, consulta Acerca de los modos de operación de GKE.

Para usar DRANET administrado por GKE, tu clúster debe usar la versión 1.35.2-gke.1842000 o posterior para el modo Autopilot, o la versión 1.34.1-gke.1829001 o posterior para el modo Standard. En este instructivo, se usa la versión 1.35.2-gke.1842000.

Autopilot

  1. En Cloud Shell, ejecuta el siguiente comando:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION \
        --cluster-version=${CLUSTER_VERSION}
    
  2. Para comunicarte con tu clúster, configura kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$REGION
    

Estándar

  1. En Cloud Shell, ejecuta el siguiente comando para crear un clúster estándar que habilite el complemento Ray operator:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator,GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --enable-dataplane-v2 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE \
        --cluster-version=${CLUSTER_VERSION}
    

    Este comando también habilita GcsFuseCsiDriver, lo que permite que los Pods activen buckets de Cloud Storage como sistemas de archivos locales. La creación del clúster puede tomar varios minutos.

  2. Para comunicarte con tu clúster, configura kubectl:

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    
  3. Crea el primer grupo de nodos de porción de TPU multihost con GKE DRANET habilitado:

    gcloud container node-pools create v6e-16-0 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    
  4. Crea el segundo grupo de nodos de porción de TPU:

    gcloud container node-pools create v6e-16-1 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    

GKE aprovisiona un grupo de nodos que consta de cuatro VMs de TPU Trillium (v6e), que se configuran juntas como una porción de TPU de varios hosts que tiene una topología 4x4. Este grupo de nodos está listo para las cargas de trabajo de entrenamiento distribuido.

El clúster de GKE habilitado para el operador de Ray instala automáticamente KubeRay y el webhook de KubeRay TPU en tu clúster.

Configura un bucket de Cloud Storage y una cuenta de servicio

  1. Crea un bucket de Cloud Storage para los puntos de control compartidos entre los nodos de TPU de varios hosts.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Para habilitar el acceso al bucket de Cloud Storage, crea una cuenta de servicio de Kubernetes:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Para habilitar el acceso al bucket de Cloud Storage, agrega las vinculaciones de políticas de IAM necesarias a la cuenta de servicio:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

Crear la secuencia de comandos de entrenamiento

La secuencia de comandos maxtext_multi_slice_trainer.py usa JaxTrainer de Ray Train para ejecutar un trabajo de entrenamiento distribuido de MaxText en dos segmentos de TPU. La secuencia de comandos configura el entorno de entrenamiento para ocho trabajadores TPU de varios hosts y ejecuta el trabajo de entrenamiento de MaxText en cada nodo trabajador. La función train_loop_per_worker encapsula el punto de entrada principal de MaxText y usa el programador distribuido de Ray para ejecutar el entrenador de MaxText en una porción de TPU de varios hosts:

import os
from absl import app
import logging
from typing import Sequence
import ray
from ray.train.v2.api.config import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer

def train_loop_per_worker(config):
    import maxtext
    from maxtext.trainers.pre_train.train import main as maxtext_main

    argv = config["argv"]
    maxtext_main(argv)

def main(argv: Sequence[str]):
    # Convert the config file path to an absolute path
    argv = list(argv)
    if len(argv) > 1:
        argv[1] = os.path.abspath(argv[1])

    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": argv},
        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=8,
            topology="4x4",
            accelerator_type="TPU-V6E",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "uv": {
                    # maxtext requires some additional deps
                    "packages": ["maxtext[tpu]==0.2.1"],
                    "uv_pip_install_options": ["--resolution=lowest"]
                },
            },
        ),
    )
    result = trainer.fit()
    logging.info("Training complete!")
    ray.shutdown()

if __name__ == "__main__":
    app.run(main)

La secuencia de comandos anterior define una instancia de JaxTrainer que solicita ocho trabajadores y una topología de 4x4. Internamente, Ray aprovisiona un SlicePlacementGroup en las dos porciones de TPU y ayuda a garantizar que los trabajadores de Ray Train se ejecuten de forma atómica en ambas porciones, con un trabajador por host.

Entrena el modelo

  1. El manifiesto ray-cluster.tpu-multi-slice.yaml en el directorio actual define el recurso personalizado de RayCluster. Este manifiesto incluye el objeto DRANET ResourceClaimTemplate para aprovisionar los dispositivos de red para DRANET de GKE y Multislice:

    apiVersion: resource.k8s.io/v1
    kind: ResourceClaimTemplate
    metadata:
      name: two-netdev
    spec:
      spec:
        devices:
          requests:
          - name: req-netdev
            exactly:
              deviceClassName: netdev.google.com
              allocationMode: ExactCount
              count: 2
    ---
    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: rayproject/ray:nightly-py312-tpu
                imagePullPolicy: Always
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
            nodeSelector:
              iam.gke.io/gke-metadata-server-enabled: "true"
      workerGroupSpecs:
        - replicas: 2
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: 
            metrics-export-port: "8082"
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              resourceClaims:
              - name: netdev
                resourceClaimTemplateName: two-netdev
              containers:
                - name: ray-worker
                  image: rayproject/ray:nightly-py312-tpu
                  imagePullPolicy: Always
                  resources:
                    claims:
                    - name: netdev
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: MEGASCALE_NUM_SLICES
                      value: "2"
                    - name: MEGASCALE_PORT
                      value: "9915"
                    - name: JAX_PLATFORMS
                      value: tpu,cpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                    - name: LIBTPU_INIT_ARGS
                      value: "--xla_tpu_scoped_vmem_limit_kib=122880 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --megascale_grpc_interface_prefixes=eth1,eth2,lo"
                  securityContext:
                    privileged: true
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
              nodeSelector:
                iam.gke.io/gke-metadata-server-enabled: "true"
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
    

    La especificación de RayCluster anterior crea un grupo de trabajadores de TPU con ocho trabajadores de TPU (numOfHosts: 4) por réplica, con dos réplicas. Cada trabajador solicita cuatro chips TPU (google.com/tpu: "4"). Cada trabajador se programa en un nodo TPU Trillium (tpu-v6e-slice), que forma parte de la misma porción multihost ubicada en el mismo lugar. KubeRay escala los cuatro trabajadores de una porción de forma atómica. GKE inicializa las variables de entorno de JAX requeridas, así como las afinidades de Pod para la programación, a través de un webhook de mutación.

  2. Para crear el RayCluster, aplica el manifiesto:

    envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -
    
  3. Verifica que el clúster esté listo y en ejecución:

    kubectl get rayclusters maxtext-tpu-cluster
    

    El resultado debería ser similar al siguiente ejemplo:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY         GPUS   STATUS   AGE
    maxtext-tpu-cluster   8                 8                   72     1579277216Ki   0      ready    2m11s
    
  4. Para acceder al panel de Ray a través del servicio principal de Ray, establece una sesión de redirección de puertos:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. Verifica que se pueda acceder a RayCluster desde tu entorno local:

    ray list nodes --address http://localhost:8265
    

    El resultado debería ser similar al siguiente ejemplo:

    ray list nodes --address http://localhost:8265
    2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
    2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads.
    
    ======== List: 2026-04-21 10:20:20.945431 ========
    Stats:
    ------------------------------
    Total: 9
    
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP     IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                   LABELS
    0  4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1  10.68.9.5   False           ALIVE                     10.68.9.5    CPU: 8.0                          ray.io/accelerator-type: TPU-V6E
                                                                                                                                    TPU: 4.0                          ray.io/node-group: tpu-group
                                                                                                                                    accelerator_type:TPU-V6E: 1.0     ray.io/node-id: 4f0e4d742...
                                                                                                                                    memory: 186.265 GiB               ray.io/tpu-pod-type: v6e-16
                                                                                                                                    node:10.68.9.5: 1.0               ray.io/tpu-slice-name: tpu-group-0
                                                                                                                                    object_store_memory: 186.265 GiB  ray.io/tpu-topology: 4x4
                                                                                                                                    tpu-group-0: 1.0                  ray.io/tpu-worker-id: '1'
    ...
    6  ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938  10.68.2.9   True            ALIVE                     10.68.2.9    CPU: 8.0                          ray.io/node-group: headgroup
                                                                                                                                    memory: 16.000 GiB                ray.io/node-id: ce7056807...
                                                                                                                                    node:10.68.2.9: 1.0
                                                                                                                                    node:__internal_head__: 1.0
                                                                                                                                    object_store_memory: 4.765 GiB
    ...
    
  6. Descarga el archivo de configuración base de MaxText. El script de entrenamiento requiere este archivo para establecer los hiperparámetros predeterminados del modelo:

    curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.yml
    
  7. Envía el script de JaxTrainer al RayCluster y verifica que el RayJob se complete correctamente:

Llama 3 8B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=4 \
      max_target_length=4096 \
      model_name=llama3-8b \
      steps=100 \
      ici_fsdp_parallelism=4 \
      ici_tensor_parallelism=4 \
      run_name=rayjob-multi-slice

Llama 3 70B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=2 \
      max_target_length=4096 \
      model_name=llama3-70b \
      steps=100 \
      ici_tensor_parallelism=4 \
      ici_fsdp_parallelism=4 \
      dcn_fsdp_parallelism=2 \
      dcn_data_parallelism=1 \
      remat_policy=full \
      run_name=rayjob-multi-slice-70b-fsdp

El comando anterior envía la secuencia de comandos de Python, que llama al código de JaxTrainer Ray al clúster de Ray. El comando ray job submit incluye algunos argumentos específicos de MaxText para pasar a la configuración del modelo.

En el terminal, deberías ver un resultado similar al siguiente para el trabajo de Llama 3 70B:

[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]

------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------

Ejecuta el entrenamiento elástico de Multislice en VMs Spot

Cuando se usan aceleradores muy solicitados, como las TPU, utilizar VMs Spot podría reducir significativamente los costos. Sin embargo, las VMs Spot se pueden interrumpir de forma inesperada.

Ray Train admite el entrenamiento elástico, lo que permite que tu trabajo escale de forma dinámica la cantidad de segmentos de TPU participantes hacia arriba o hacia abajo sin fallar. Si se interrumpe una división, Ray pausa el bucle de entrenamiento, espera a que se reorganice el resto de los trabajadores, restablece el último punto de control de MaxText y reanuda el entrenamiento con la huella más pequeña.

Para habilitar el entrenamiento elástico, cambia el parámetro num_workers en tu ScalingConfig de un número entero estático a una tupla que represente (minimum_workers, maximum_workers). Además, agrega un FailureConfig(max_failures=3) al RunConfig, que le indica a Ray Train que vuelva a intentar el bucle de entrenamiento hasta 3 veces en lugar de hacer que falle todo el trabajo cuando se interrumpe un trabajador.

Actualiza el script de Ray Train

  1. La secuencia de comandos maxtext_elastic_trainer.py en el directorio actual habilita el entrenamiento elástico. Ten en cuenta que establece num_workers=(4,8), lo que le indica a Ray que continúe si hay disponible al menos una porción de 16 chips (cuatro trabajadores), pero que escale verticalmente a dos porciones (ocho trabajadores) si es posible. Incluye un FailureConfig para habilitar el entrenamiento elástico, definir la cantidad de reintentos y ayudar a garantizar que el trabajo sobreviva a las interrupciones:

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig, FailureConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        import maxtext
        from maxtext.trainers.pre_train.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        # Convert the config file path to an absolute path
        argv = list(argv)
        if len(argv) > 1:
            argv[1] = os.path.abspath(argv[1])
    
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                # This tells Ray to scale the number of workers between 4 and 8 (i.e. 1 to 2 TPU slices).
                num_workers=(4,8),
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                # Define a FailureConfig to enable fault tolerance by automatically restarting failed workers.
                failure_config=FailureConfig(max_failures=3),
                worker_runtime_env={
                    "uv": {
                        # maxtext requires some additional deps
                        "packages": ["maxtext[tpu]==0.2.1"],
                        "uv_pip_install_options": ["--resolution=lowest"]
                    },
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
    
  2. Envía el trabajo con la CLI de Ray Job. Asegúrate de proporcionar un run_name único para que los puntos de control no entren en conflicto con ejecuciones anteriores.

    ray job submit \
      --address http://localhost:8265 \
      --working-dir . \
      --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
      -- python maxtext_elastic_trainer.py \
          base.yml \
          base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=4 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-elastic-8b
    
  3. Para simular la finalización o interrupción de un nodo durante el entrenamiento, borra un Pod.

    kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
    

La terminal registra una falla del trabajador, pero el controlador de orquestación mantiene activo el trabajo y se reanuda automáticamente desde el punto de control /data/rayjob-elastic-8b/checkpoints después de que la topología mínima esté disponible.

Dado que MaxText vuelve a calcular de forma dinámica la malla de dispositivos cuando se reanuda el entrenamiento, no necesitas escribir ninguna lógica personalizada para controlar el nuevo fragmentado de puntos de control cuando se reduce la topología. El verificador de Orbax de JAX volverá a fragmentar automáticamente los pesos guardados en el nuevo diseño físico antes de continuar con el bucle de entrenamiento. En el siguiente resultado, se muestra que el controlador de Ray Train detecta recursos de TPU disponibles recientemente en el clúster y realiza una operación de ajuste de escala de una porción (cuatro trabajadores) a dos porciones (ocho trabajadores) durante el entrenamiento.

...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048   20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8

Realiza una limpieza

Para evitar que se apliquen cargos a tu Google Cloud cuenta por los recursos usados en este instructivo, borra el proyecto que contiene los recursos o conserva el proyecto y borra los recursos individuales.

  1. Borra el RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. Borra el clúster de GKE:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Borra el bucket de Cloud Storage:

    gsutil rm -r gs://${GS_BUCKET}
    

¿Qué sigue?