Treinamento multislice e elástico em TPUs usando o Ray Train no GKE

Este tutorial mostra como treinar modelos de linguagem grandes (LLMs), como o Llama 3 70B, no Google Kubernetes Engine (GKE) usando MaxText, Ray Train e TPUs Trillium multislice. Este tutorial fornece um tutorial completo de ponta a ponta, desde a configuração da rede secundária necessária do data center até o envio e a execução de uma carga de trabalho de treinamento distribuído em 32 chips físicos de TPU.

Este tutorial é destinado a administradores de plataforma, operadores e especialistas em IA que querem aprender a superar os desafios de memória e rede do treinamento de modelos de 70 bilhões de parâmetros em frações de TPU distribuídas e de vários hosts.

Contexto

A combinação de GKE, KubeRay, MaxText e TPUs oferece uma plataforma poderosa e escalonável para treinamento de modelo em grande escala. Esta seção descreve as principais tecnologias usadas neste guia:

JAX

O JAX é uma biblioteca Python para computação de matrizes e transformação de programas orientada a aceleradores, usando o compilador XLA para criar um código altamente otimizado que é dimensionado de maneira eficiente em aceleradores.

MaxText

O MaxText é uma estrutura de LLM de código aberto e alto desempenho projetada para escalonabilidade e personalização. O MaxText é criado com base no JAX e é otimizado para ser executado com eficiência em Cloud TPUs.

TPUs

As Unidades de Processamento de Tensor (TPUs) são aceleradores com design personalizado criados pelo Google para otimizar as cargas de trabalho de machine learning. Ao contrário das CPUs de uso geral ou das GPUs de processamento paralelo, as TPUs são altamente especializadas para cálculos de matrizes e tensores massivos na base do aprendizado profundo, o que as torna eficientes nessa tarefa específica. A principal vantagem das TPUs é o desempenho em escala.

Neste tutorial, usamos a TPU Trillium, a sexta geração de TPUs, em um padrão de implantação Multislice. Na Multislice do Cloud TPU, duas ou mais frações do Cloud TPU se comunicam pela rede do data center (DCN). O Multislice permite um treinamento de pilha completa, econômico e em grande escala com escalonamento vertical quase linear para até dezenas de milhares de chips de TPU. Para mais informações sobre o uso de várias frações, consulte Visão geral do uso de várias frações no Cloud TPU.

KubeRay

O KubeRay é um operador do Kubernetes que oferece uma maneira unificada de implantar, gerenciar e monitorar aplicativos do Ray no Kubernetes. O operador KubeRay é instalado e gerenciado pelo complemento Ray no GKE, que é a maneira recomendada de implantar e gerenciar clusters do Ray no GKE.

Rede de alocação dinâmica de recursos do GKE (DRANET)

O GKE DRANET (Dynamic Resource Allocation Network) é um recurso que conecta dinamicamente dispositivos de rede de alta performance a pods, ignorando a rede padrão do Kubernetes e permitindo alto desempenho na DCN.

Objetivos

Este tutorial mostra como fazer o seguinte:

  1. Configure um cluster do GKE com dois pools de nós de TPU de vários hosts.
  2. Configure uma DCN secundária para comunicação entre frações de TPU.
  3. Configure o KubeRay para gerenciar o ambiente de treinamento distribuído.
  4. Implante um recurso personalizado do RayCluster usando a alocação dinâmica de recursos (DRA, na sigla em inglês) para anexos de rede.
  5. Crie um script de treinamento em Python usando o JaxTrainer do Ray Train para orquestrar o loop de treinamento do MaxText nas fatias de TPU.
  6. Execute um job de treinamento de referência do Llama 3 8B.
  7. Escalonar verticalmente para o Llama 3 70B utilizando a fragmentação 2D (paralelismo de tensor e FSDP) na DCN.

Antes de começar

  • Faça login na sua conta do Google Cloud . Se você começou a usar o Google Cloud, crie uma conta para avaliar o desempenho de nossos produtos em situações reais. Clientes novos também recebem US$ 300 em créditos para executar, testar e implantar cargas de trabalho.
  • Instale a CLI do Google Cloud.

  • Ao usar um provedor de identidade (IdP) externo, primeiro faça login na gcloud CLI com sua identidade federada.

  • Para inicializar a gcloud CLI, execute o seguinte comando:

    gcloud init
  • Crie ou selecione um Google Cloud projeto.

    Funções necessárias para selecionar ou criar um projeto

    • Selecionar um projeto: não é necessário um papel específico do IAM para selecionar um projeto. Você pode escolher qualquer projeto em que tenha recebido um papel.
    • Criar um projeto: para criar um projeto, é necessário ter o papel de Criador de projetos (roles/resourcemanager.projectCreator), que contém a permissão resourcemanager.projects.create. Saiba como conceder papéis.
    • Crie um projeto do Google Cloud :

      gcloud projects create PROJECT_ID

      Substitua PROJECT_ID por um nome para o projeto Google Cloud que você está criando.

    • Selecione o projeto Google Cloud que você criou:

      gcloud config set project PROJECT_ID

      Substitua PROJECT_ID pelo nome do projeto do Google Cloud .

  • Verifique se o faturamento está ativado para o projeto do Google Cloud .

  • Ative as APIs necessárias:

    Funções necessárias para ativar APIs

    Para ativar as APIs, é necessário ter o papel do IAM de administrador de uso do serviço (roles/serviceusage.serviceUsageAdmin), que contém a permissão serviceusage.services.enable. Saiba como conceder papéis.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Instale a CLI do Google Cloud.

  • Ao usar um provedor de identidade (IdP) externo, primeiro faça login na gcloud CLI com sua identidade federada.

  • Para inicializar a gcloud CLI, execute o seguinte comando:

    gcloud init
  • Crie ou selecione um Google Cloud projeto.

    Funções necessárias para selecionar ou criar um projeto

    • Selecionar um projeto: não é necessário um papel específico do IAM para selecionar um projeto. Você pode escolher qualquer projeto em que tenha recebido um papel.
    • Criar um projeto: para criar um projeto, é necessário ter o papel de Criador de projetos (roles/resourcemanager.projectCreator), que contém a permissão resourcemanager.projects.create. Saiba como conceder papéis.
    • Crie um projeto do Google Cloud :

      gcloud projects create PROJECT_ID

      Substitua PROJECT_ID por um nome para o projeto Google Cloud que você está criando.

    • Selecione o projeto Google Cloud que você criou:

      gcloud config set project PROJECT_ID

      Substitua PROJECT_ID pelo nome do projeto do Google Cloud .

  • Verifique se o faturamento está ativado para o projeto do Google Cloud .

  • Ative as APIs necessárias:

    Funções necessárias para ativar APIs

    Para ativar as APIs, é necessário ter o papel do IAM de administrador de uso do serviço (roles/serviceusage.serviceUsageAdmin), que contém a permissão serviceusage.services.enable. Saiba como conceder papéis.

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Atribua papéis à sua conta de usuário. Execute o seguinte comando uma vez para cada um dos seguintes papéis do IAM: roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editor

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Substitua:

    • PROJECT_ID: o ID do projeto.
    • USER_IDENTIFIER: o identificador da sua conta de usuário . Por exemplo, myemail@example.com.
    • ROLE: o papel do IAM concedido à sua conta de usuário.
  • Como este tutorial usa a TPU Trillium (v6e), selecione uma região ou zona com disponibilidade. Para mais informações, consulte Cotas do Cloud TPU.

Preparar o ambiente

Neste tutorial, você vai usar o Cloud Shell. O Cloud Shell vem pré-instalado com as ferramentas de linha de comando gcloud, helm e kubectl que são usadas neste tutorial.

  1. Acesse o console doGoogle Cloud .

  2. Na parte de cima da janela do console do Google Cloud , clique no botão Ativar Cloud Shell Botão "Ativar shell".

    Uma sessão do Cloud Shell é aberta em um novo frame no consoleGoogle Cloud e exibe um prompt de linha de comando.

  3. No terminal, clone o repositório kubernetes-engine-samples:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    
  4. Mude para o diretório que contém os arquivos de exemplo:

    cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtext
    
  5. Crie e ative um ambiente virtual Python:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  6. Instale a CLI do Ray:

    pip install "ray[default]==2.55.0"
    
  7. Configure as variáveis de ambiente a seguir:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export CLUSTER_VERSION=1.35.2-gke.1842000
    

    Substitua:

    • GS_BUCKET: o nome do bucket do Cloud Storage.
    • KSA_NAME: o nome da conta de serviço do Kubernetes.
    • CLUSTER_NAME: o nome do novo cluster;
    • REGION: a região em que sua capacidade de TPU Trillium está disponível.
    • ZONE: a zona em que sua capacidade de TPU Trillium está disponível. Para mais informações, consulte Disponibilidade da TPU no GKE.

Configurar a rede de cluster para o Cloud TPU de várias frações

Em uma fração de TPU de vários hosts, os dispositivos de TPU se comunicam pelas interconexões de alta velocidade entre chips. No entanto, ao executar jobs de várias frações, as frações de TPU precisam se comunicar pela DCN. As redes de pods padrão do Kubernetes podem restringir esse tráfego. O tipo de máquina ct6e-standard-4t é compatível com várias placas de rede (NICs) físicas. Para ter o melhor desempenho, crie mais duas redes VPC e use o GKE DRANET para conectá-las diretamente aos pods do Ray.

  1. Crie as duas redes VPC adicionais com uma grande unidade máxima de treinamento (MTU):

    gcloud compute networks create ${CLUSTER_NAME}-net-1 \
      --subnet-mode=custom \
      --mtu=8896
    
    gcloud compute networks create ${CLUSTER_NAME}-net-2 \
      --subnet-mode=custom \
      --mtu=8896
    
  2. Crie as sub-redes dedicadas:

    gcloud compute networks subnets create tpu-subnet-1 \
      --network=${CLUSTER_NAME}-net-1 \
      --region=${REGION} \
      --range=10.50.0.0/16
    
    gcloud compute networks subnets create tpu-subnet-2 \
      --network=${CLUSTER_NAME}-net-2 \
      --region=${REGION} \
      --range=10.60.0.0/16
    

Criar um cluster do GKE

É possível configurar o KubeRay em TPUs em um cluster do GKE Autopilot ou Standard. Recomendamos que você use um cluster do Autopilot para ter uma experiência totalmente gerenciada do Kubernetes. Para escolher o modo de operação do GKE mais adequado para suas cargas de trabalho, consulte Sobre os modos de operação do GKE.

Para usar o DRANET gerenciado pelo GKE, seu cluster precisa usar a versão 1.35.2-gke.1842000 ou mais recente no modo Autopilot ou a versão 1.34.1-gke.1829001 ou mais recente no modo Standard. Este tutorial usa a versão 1.35.2-gke.1842000.

Piloto automático

  1. No Cloud Shell, execute este comando:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION \
        --cluster-version=${CLUSTER_VERSION}
    
  2. Para se comunicar com o cluster, configure kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$REGION
    

Padrão

  1. No Cloud Shell, crie um cluster Standard que ative o complemento do operador do Ray executando o seguinte comando:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator,GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --enable-dataplane-v2 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE \
        --cluster-version=${CLUSTER_VERSION}
    

    Esse comando também ativa o GcsFuseCsiDriver, que permite que os pods ativem buckets do Cloud Storage como sistemas de arquivos locais. A criação do cluster pode levar vários minutos.

  2. Para se comunicar com o cluster, configure kubectl:

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    
  3. Crie o primeiro pool de nós de fração de TPU de vários hosts com o DRANET do GKE ativado:

    gcloud container node-pools create v6e-16-0 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    
  4. Crie o segundo pool de nós de fração da TPU:

    gcloud container node-pools create v6e-16-1 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    

O GKE provisiona um pool de nós com quatro VMs de TPU Trillium (v6e), que são configuradas juntas como uma fração de TPU de vários hosts com uma topologia 4x4. Esse pool de nós está pronto para cargas de trabalho de treinamento distribuído.

O cluster do GKE com o operador Ray ativado instala automaticamente o KubeRay e o webhook de TPU do KubeRay no cluster.

Configurar um bucket do Cloud Storage e uma conta de serviço

  1. Crie um bucket do Cloud Storage para checkpoints compartilhados entre os nós de TPU de vários hosts.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Para ativar o acesso ao bucket do Cloud Storage, crie uma conta de serviço do Kubernetes:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Para ativar o acesso ao bucket do Cloud Storage, adicione as vinculações de política do IAM necessárias à conta de serviço:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

criar o script de treinamento;

O script maxtext_multi_slice_trainer.py usa o JaxTrainer do Ray Train para executar um job de treinamento distribuído do MaxText em duas fatias de TPU. O script configura o ambiente de treinamento para oito workers de TPU de vários hosts e executa o job de treinamento do MaxText em cada nó de trabalho. A função train_loop_per_worker encapsula o ponto de entrada principal do MaxText e usa o programador distribuído do Ray para executar o treinador do MaxText em uma fração de TPU de vários hosts:

import os
from absl import app
import logging
from typing import Sequence
import ray
from ray.train.v2.api.config import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer

def train_loop_per_worker(config):
    import maxtext
    from maxtext.trainers.pre_train.train import main as maxtext_main

    argv = config["argv"]
    maxtext_main(argv)

def main(argv: Sequence[str]):
    # Convert the config file path to an absolute path
    argv = list(argv)
    if len(argv) > 1:
        argv[1] = os.path.abspath(argv[1])

    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": argv},
        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=8,
            topology="4x4",
            accelerator_type="TPU-V6E",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "uv": {
                    # maxtext requires some additional deps
                    "packages": ["maxtext[tpu]==0.2.1"],
                    "uv_pip_install_options": ["--resolution=lowest"]
                },
            },
        ),
    )
    result = trainer.fit()
    logging.info("Training complete!")
    ray.shutdown()

if __name__ == "__main__":
    app.run(main)

O script anterior define uma instância JaxTrainer que solicita oito workers e uma topologia de 4x4. Internamente, o Ray provisiona um SlicePlacementGroup nas duas frações de TPU e ajuda a garantir que os workers do Ray Train sejam executados atomicamente nas duas frações, com um worker por host.

Treine o modelo

  1. O manifesto ray-cluster.tpu-multi-slice.yaml no diretório atual define o recurso personalizado do RayCluster. Este manifesto inclui o DRANET ResourceClaimTemplate para provisionar os dispositivos de rede para DRANET e Multislice do GKE:

    apiVersion: resource.k8s.io/v1
    kind: ResourceClaimTemplate
    metadata:
      name: two-netdev
    spec:
      spec:
        devices:
          requests:
          - name: req-netdev
            exactly:
              deviceClassName: netdev.google.com
              allocationMode: ExactCount
              count: 2
    ---
    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: rayproject/ray:nightly-py312-tpu
                imagePullPolicy: Always
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
            nodeSelector:
              iam.gke.io/gke-metadata-server-enabled: "true"
      workerGroupSpecs:
        - replicas: 2
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: 
            metrics-export-port: "8082"
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              resourceClaims:
              - name: netdev
                resourceClaimTemplateName: two-netdev
              containers:
                - name: ray-worker
                  image: rayproject/ray:nightly-py312-tpu
                  imagePullPolicy: Always
                  resources:
                    claims:
                    - name: netdev
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: MEGASCALE_NUM_SLICES
                      value: "2"
                    - name: MEGASCALE_PORT
                      value: "9915"
                    - name: JAX_PLATFORMS
                      value: tpu,cpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                    - name: LIBTPU_INIT_ARGS
                      value: "--xla_tpu_scoped_vmem_limit_kib=122880 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --megascale_grpc_interface_prefixes=eth1,eth2,lo"
                  securityContext:
                    privileged: true
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
              nodeSelector:
                iam.gke.io/gke-metadata-server-enabled: "true"
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
    

    A especificação RayCluster anterior cria um grupo de workers da TPU com oito workers (numOfHosts: 4) por réplica e duas réplicas. Cada worker solicita quatro chips de TPU (google.com/tpu: "4"). Cada worker é programado em um nó Trillium de TPU (tpu-v6e-slice), que faz parte da mesma fração de vários hosts colocada. O KubeRay escalona todos os quatro workers em uma fração de forma atômica. As variáveis de ambiente JAX necessárias, bem como as afinidades de pod para programação, são inicializadas pelo GKE usando um webhook mutável.

  2. Para criar o RayCluster, aplique o manifesto:

    envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -
    
  3. Verifique se o cluster está pronto e em execução:

    kubectl get rayclusters maxtext-tpu-cluster
    

    A saída será semelhante a esta:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY         GPUS   STATUS   AGE
    maxtext-tpu-cluster   8                 8                   72     1579277216Ki   0      ready    2m11s
    
  4. Para acessar o painel do Ray pelo serviço de cabeçalho do Ray, estabeleça uma sessão de encaminhamento de portas:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. Verifique se o RayCluster pode ser acessado no seu ambiente local:

    ray list nodes --address http://localhost:8265
    

    A saída será semelhante a esta:

    ray list nodes --address http://localhost:8265
    2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
    2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads.
    
    ======== List: 2026-04-21 10:20:20.945431 ========
    Stats:
    ------------------------------
    Total: 9
    
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP     IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                   LABELS
    0  4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1  10.68.9.5   False           ALIVE                     10.68.9.5    CPU: 8.0                          ray.io/accelerator-type: TPU-V6E
                                                                                                                                    TPU: 4.0                          ray.io/node-group: tpu-group
                                                                                                                                    accelerator_type:TPU-V6E: 1.0     ray.io/node-id: 4f0e4d742...
                                                                                                                                    memory: 186.265 GiB               ray.io/tpu-pod-type: v6e-16
                                                                                                                                    node:10.68.9.5: 1.0               ray.io/tpu-slice-name: tpu-group-0
                                                                                                                                    object_store_memory: 186.265 GiB  ray.io/tpu-topology: 4x4
                                                                                                                                    tpu-group-0: 1.0                  ray.io/tpu-worker-id: '1'
    ...
    6  ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938  10.68.2.9   True            ALIVE                     10.68.2.9    CPU: 8.0                          ray.io/node-group: headgroup
                                                                                                                                    memory: 16.000 GiB                ray.io/node-id: ce7056807...
                                                                                                                                    node:10.68.2.9: 1.0
                                                                                                                                    node:__internal_head__: 1.0
                                                                                                                                    object_store_memory: 4.765 GiB
    ...
    
  6. Faça o download do arquivo de configuração base do MaxText. Esse arquivo é obrigatório para que o script de treinamento defina os hiperparâmetros padrão do modelo:

    curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.yml
    
  7. Envie o script JaxTrainer para o RayCluster e verifique se o RayJob foi concluído:

Llama 3 8B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=4 \
      max_target_length=4096 \
      model_name=llama3-8b \
      steps=100 \
      ici_fsdp_parallelism=4 \
      ici_tensor_parallelism=4 \
      run_name=rayjob-multi-slice

Llama 3 70 B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=2 \
      max_target_length=4096 \
      model_name=llama3-70b \
      steps=100 \
      ici_tensor_parallelism=4 \
      ici_fsdp_parallelism=4 \
      dcn_fsdp_parallelism=2 \
      dcn_data_parallelism=1 \
      remat_policy=full \
      run_name=rayjob-multi-slice-70b-fsdp

O comando anterior envia o script Python, que chama o código do JaxTrainer Ray para o RayCluster. O comando ray job submit inclui alguns argumentos específicos do MaxText para transmitir à configuração do modelo.

No terminal, você vai ver uma saída semelhante a esta para o job do Llama 3 70B:

[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]

------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------

Executar treinamento elástico do Multislice em VMs spot

Ao usar aceleradores muito procurados, como TPUs, o uso de VMs spot pode reduzir significativamente os custos. No entanto, as VMs spot podem ser interrompidas inesperadamente.

O Ray Train oferece suporte ao treinamento elástico, que permite que seu job dimensione dinamicamente o número de frações de TPU participantes para cima ou para baixo sem falhar. Se uma fração for interrompida, o Ray vai pausar o loop de treinamento, esperar que os workers restantes se reorganizem, restaurar do checkpoint mais recente do MaxText e retomar o treinamento com a pegada menor.

Para ativar o treinamento elástico, mude o parâmetro num_workers no seu ScalingConfig de um número inteiro estático para uma tupla que representa (minimum_workers, maximum_workers). Além disso, adicione um FailureConfig(max_failures=3) ao RunConfig, que instrui o Ray Train a tentar novamente o loop de treinamento até três vezes em vez de falhar completamente o job quando um worker é desalojado.

Atualizar o script do Ray Train

  1. O script maxtext_elastic_trainer.py no diretório atual ativa o treinamento elástico. Observe que ele define num_workers=(4,8), que instrui o Ray a continuar se pelo menos uma fração de 16 chips (quatro workers) estiver disponível, mas a escalonar verticalmente para duas frações (oito workers) se possível. Ele inclui um FailureConfig para ativar o treinamento elástico, definir o número de novas tentativas e ajudar a garantir que o job sobreviva a remoções:

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig, FailureConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        import maxtext
        from maxtext.trainers.pre_train.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        # Convert the config file path to an absolute path
        argv = list(argv)
        if len(argv) > 1:
            argv[1] = os.path.abspath(argv[1])
    
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                # This tells Ray to scale the number of workers between 4 and 8 (i.e. 1 to 2 TPU slices).
                num_workers=(4,8),
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                # Define a FailureConfig to enable fault tolerance by automatically restarting failed workers.
                failure_config=FailureConfig(max_failures=3),
                worker_runtime_env={
                    "uv": {
                        # maxtext requires some additional deps
                        "packages": ["maxtext[tpu]==0.2.1"],
                        "uv_pip_install_options": ["--resolution=lowest"]
                    },
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
    
  2. Envie o job usando a CLI do Ray Job. Forneça um run_name exclusivo para que os pontos de verificação não entrem em conflito com execuções anteriores.

    ray job submit \
      --address http://localhost:8265 \
      --working-dir . \
      --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
      -- python maxtext_elastic_trainer.py \
          base.yml \
          base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=4 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-elastic-8b
    
  3. Para simular uma interrupção ou preempção de nó durante o treinamento, exclua um pod.

    kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
    

O terminal registra uma falha do worker, mas o controlador de orquestração mantém o job ativo e retoma automaticamente do ponto de verificação /data/rayjob-elastic-8b/checkpoints depois que a topologia mínima estiver disponível.

Como o MaxText recalcula dinamicamente a malha de dispositivos ao retomar, não é necessário escrever uma lógica personalizada para processar o refragmentação de pontos de verificação quando a topologia diminui. O verificador de pontos de verificação do Orbax do JAX vai refragmentar automaticamente os pesos salvos no novo layout físico antes de continuar o loop de treinamento. A saída a seguir mostra o controlador do Ray Train detectando recursos de TPU recém-disponíveis no cluster e realizando uma operação de escalonamento de uma fração (quatro trabalhadores) para duas frações (oito trabalhadores) durante o treinamento.

...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048   20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8

Limpar

Para evitar cobranças na sua conta do Google Cloud pelos recursos usados neste tutorial, exclua o projeto que contém os recursos ou mantenha o projeto e exclua os recursos individuais.

  1. Exclua o RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. Exclua o cluster do GKE:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Exclua o bucket do Cloud Storage:

    gsutil rm -r gs://${GS_BUCKET}
    

A seguir