Treinar um LLM usando JAX, Ray Train e TPU Trillium no GKE

Neste tutorial, você vai aprender a treinar o modelo de linguagem grande (LLM) Llama 3 8B no Google Kubernetes Engine (GKE) usando MaxText, Ray Train e TPUs.

Este tutorial oferece um tutorial completo, de ponta a ponta, desde a configuração da infraestrutura em nuvem necessária até o envio e a execução bem-sucedida da carga de trabalho de treinamento em TPUs de vários hosts.

Este tutorial é destinado a administradores e operadores de plataforma e especialistas em dados e IA que querem aprender a treinar modelos grandes em uma fração de TPU distribuída e com vários hosts.

Contexto

A combinação de GKE, KubeRay, MaxText e TPUs oferece uma plataforma poderosa e escalonável para treinamento de modelo em grande escala. Esta seção descreve as principais tecnologias usadas neste guia:

JAX

O JAX é uma biblioteca Python para computação de matrizes e transformação de programas orientada a aceleradores, projetada para computação numérica de alto desempenho e machine learning em grande escala.

O JAX oferece um sistema extensível para transformar funções numéricas como jax.grad, jax.jit e jax.vmap, usando o compilador XLA para criar código altamente otimizado que é dimensionado de maneira eficiente em aceleradores como GPUs e TPUs. O principal poder do JAX está na capacidade de composição, que permite aos usuários combinar essas transformações para criar programas numéricos complexos e de alto desempenho para execução distribuída.

MaxText

O MaxText é um modelo de linguagem grande (LLM) de código aberto e alto desempenho projetado para escalonabilidade e personalização. O MaxText é criado com base no JAX e otimizado para ser executado com eficiência na Cloud TPU e GPUs.

TPUs

As Unidades de Processamento de Tensor (TPUs) são aceleradores personalizados criados pelo Google para otimizar as cargas de trabalho de machine learning. Ao contrário das CPUs de uso geral ou das GPUs de processamento paralelo, as TPUs são altamente especializadas para os cálculos de matriz e tensor massivos na base do aprendizado profundo, o que as torna eficientes nessa tarefa específica. A principal vantagem das TPUs é o desempenho em escala.

Este tutorial usa a TPU Trillium, que é a sexta geração de TPUs. Para mais informações, consulte Benefícios de usar o Trillium do TPU.

KubeRay

O KubeRay é um operador do Kubernetes que oferece uma maneira unificada de implantar, gerenciar e monitorar aplicativos do Ray no Kubernetes. O operador KubeRay é instalado e gerenciado pelo complemento Ray no GKE, que é a maneira recomendada de implantar e gerenciar clusters do Ray no GKE.

Objetivos

Este tutorial mostra como fazer o seguinte:

  1. Configure um cluster do GKE com um pool de nós de TPU de vários hosts.
  2. Configure o KubeRay para gerenciar o ambiente de treinamento distribuído.
  3. Crie uma imagem do Docker personalizada que contenha as dependências do MaxText, do Ray e do JAX.
  4. Crie um script de treinamento em Python que use o JaxTrainer do Ray Train para orquestrar o loop de treinamento do MaxText na fração de TPU.
  5. Defina um recurso personalizado RayCluster para provisionar os nós principais e de worker com os recursos de TPU necessários.
  6. Envie o job de treinamento para o RayCluster e monitore o progresso dele.
  7. Use o Cloud Storage para armazenar checkpoints do modelo.

Antes de começar

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • Install the Google Cloud CLI.

  • Ao usar um provedor de identidade (IdP) externo, primeiro faça login na gcloud CLI com sua identidade federada.

  • Para inicializar a gcloud CLI, execute o seguinte comando:

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Install the Google Cloud CLI.

  • Ao usar um provedor de identidade (IdP) externo, primeiro faça login na gcloud CLI com sua identidade federada.

  • Para inicializar a gcloud CLI, execute o seguinte comando:

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Grant roles to your user account. Run the following command once for each of the following IAM roles: roles/container.admin, roles/iam.serviceAccountAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Replace the following:

    • PROJECT_ID: Your project ID.
    • USER_IDENTIFIER: The identifier for your user account. For example, myemail@example.com.
    • ROLE: The IAM role that you grant to your user account.
  • Como este tutorial usa a TPU Trillium (v6e), selecione uma região ou zona com disponibilidade. Para mais informações, consulte Cotas do Cloud TPU.

Preparar o ambiente

Neste tutorial, você vai usar o Cloud Shell. O Cloud Shell vem pré-instalado com as ferramentas de linha de comando gcloud, helm e kubectl que são usadas neste tutorial.

  1. Acesse o console doGoogle Cloud .

  2. Na parte de cima da janela do console do Google Cloud , clique no botão Ativar Cloud Shell Botão "Ativar shell".

    Uma sessão do Cloud Shell é aberta em um novo frame no consoleGoogle Cloud e exibe um prompt de linha de comando.

  3. Crie e ative um ambiente virtual Python:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  4. Instale a CLI do Ray e outras dependências:

    pip install "ray[default]==2.49.1"
    
  5. Configure as variáveis de ambiente a seguir:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY
    

    Substitua:

    • GS_BUCKET: o nome do bucket do Cloud Storage.
    • KSA_NAME: o nome da conta de serviço do Kubernetes.
    • CLUSTER_NAME: o nome do novo cluster;
    • REGION: a região em que sua capacidade de TPU Trillium está disponível.
    • ZONE: a zona em que sua capacidade de TPU Trillium está disponível. Para mais informações, consulte Disponibilidade da TPU no GKE.
    • ARTIFACT_REGISTRY: o nome do repositório do Artifact Registry.

Criar um cluster do GKE

É possível configurar o KubeRay em TPUs em um cluster do GKE Autopilot ou Standard. Recomendamos que você use um cluster do Autopilot para ter uma experiência totalmente gerenciada do Kubernetes. Para escolher o modo de operação do GKE mais adequado para suas cargas de trabalho, consulte Sobre os modos de operação do GKE.

Piloto automático

  1. No Cloud Shell, execute este comando:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION
    
  2. Para se comunicar com o cluster, configure kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    

Padrão

  1. No Cloud Shell, crie um cluster Standard que ative o complemento do operador do Ray executando o seguinte comando:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator \
        --addons GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE
    

    Esse comando também ativa o GcsFuseCsiDriver, que permite que os pods ativem buckets do Cloud Storage como sistemas de arquivos locais. A criação do cluster pode levar vários minutos.

  2. Para se comunicar com o cluster, configure kubectl:

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=LOCATION
    
  3. Crie um pool de nós de fração de TPU com vários hosts:

    gcloud container node-pools create v6e-16 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4
    

O GKE provisiona um pool de nós com quatro VMs de TPU Trillium (v6e), que são configuradas juntas como uma fração de TPU de vários hosts, com uma topologia 4x4, pronta para cargas de trabalho de treinamento distribuído.

O cluster do GKE com o operador do Ray ativado instala automaticamente o KubeRay e o webhook de TPU do KubeRay no cluster.

Configurar um bucket do Cloud Storage e uma conta de serviço

  1. Crie um bucket do Cloud Storage para checkpoints compartilhados entre os nós de TPU de vários hosts.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Para ativar o acesso ao bucket do Cloud Storage, crie uma conta de serviço do Kubernetes:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Para ativar o acesso ao bucket do Cloud Storage, adicione as vinculações de política do IAM necessárias à conta de serviço:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

criar o script de treinamento;

O script a seguir usa o JaxTrainer do Ray Train para executar um job de treinamento distribuído do MaxText. O script configura o ambiente de treinamento para um pool de nós de fração de TPU de vários hosts e executa o job de treinamento do MaxText em cada nó de worker. A função train_loop_per_worker encapsula o ponto de entrada principal do MaxText e usa o programador distribuído do Ray para executar o treinador do MaxText em uma fração de TPU de vários hosts.

  1. Salve o script Python a seguir como maxtext_ray_trainer.py:

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        from MaxText.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                num_workers=4,
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                worker_runtime_env={
                    "env_vars": {
                        "JAX_PLATFORMS": "tpu",
                        "ENABLE_PJRT_COMPATIBILITY": "true",
                        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                    }
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
  2. Para hospedar a imagem personalizada, crie um repositório do Artifact Registry:

    gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \
        --repository-format=docker --location=${REGION} && \
    gcloud auth configure-docker ${REGION}-docker.pkg.dev
    
  3. Para criar uma imagem que inclua dependências do Ray e do MaxText para treinamento, crie um Dockerfile:

    # Start from a Ray base image which includes JaxTrainer API.
    # Maxtext with TPU requires Python 3.12.
    FROM rayproject/ray:2.49.1-py312
    
    USER root
    RUN groupadd -r ray 2>/dev/null || true && usermod -g ray ray
    
    RUN sudo apt-get update -y \
      && sudo apt-get install --no-install-recommends -y git \
      && sudo rm -rf /var/lib/apt/lists/*
    
    WORKDIR /app
    
    # Clone the Maxtext repo and build from source, installing TPU dependencies.
    RUN git clone https://github.com/AI-Hypercomputer/maxtext.git
    
    RUN pip install --no-cache-dir uv
    
    RUN cd maxtext && \
        uv pip install --no-cache --system -e .[tpu] --resolution=lowest && \
        install_maxtext_github_deps
    
    # Copy the Ray Maxtext trainer to run on the remote container.
    COPY maxtext_ray_trainer.py .
    
    RUN chown -R ray:ray .
    ENV PYTHONPATH=/app/maxtext/src:/app/maxtext:/app
    USER ray
  4. Crie, atribua uma tag e envie a imagem do Docker para o Artifact Registry:

    export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest
    gcloud builds submit --tag ${DOCKER_IMAGE}
    

Treine o modelo

  1. Salve o seguinte manifesto de amostra como maxtext-tpu-cluster.yaml:

    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: ${DOCKER_IMAGE}
                imagePullPolicy: IfNotPresent
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: gcs-fuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs"
      workerGroupSpecs:
        - replicas: 1
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: {}
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              containers:
                - name: ray-worker
                  image: ${DOCKER_IMAGE}
                  imagePullPolicy: IfNotPresent
                  resources:
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: JAX_PLATFORMS
                      value: tpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: gcs-fuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs"
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4

    A especificação do RayCluster anterior cria um grupo de workers de TPU com quatro workers (numOfHosts: 4) por réplica. Cada worker solicita quatro chips de TPU (google.com/tpu: "4"). Os workers serão programados em um nó que executa TPU Trillium (tpu-v6e-slice), e isso faz parte da mesma fração de vários hosts colocados. O KubeRay escalona todos os quatro workers atomicamente, e as variáveis de ambiente JAX necessárias, bem como as afinidades de pod para programação, são inicializadas pelo GKE por um webhook de mutação.

  2. Para configurar os valores necessários no arquivo YAML, crie o RayCluster usando envsubst:

    envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -
    
  3. Verifique se o cluster está pronto e em execução:

    kubectl get rayclusters maxtext-tpu-cluster
    

    A saída será semelhante a esta:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY        GPUS   STATUS   AGE
    maxtext-tpu-cluster   4                 4                   40     798027216Ki   0      ready    11m
    
  4. Para acessar o painel do Ray pelo serviço de cabeçalho do Ray, estabeleça uma sessão de encaminhamento de portas:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. Verifique se o RayCluster pode ser acessado no ambiente local:

    ray list nodes --address http://localhost:8265
    

    A saída será semelhante a esta:

    ======== List: 2025-09-13 03:53:16.988269 ========
    Stats:
    ------------------------------
    Total: 5
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP    IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                  LABELS
    0  92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56  10.84.0.9
    (...)
    
  6. Envie o script JaxTrainer para o RayCluster e verifique se o RayJob foi concluído:

    ray job submit \
      --address http://localhost:8265 \
      -- python /app/maxtext_ray_trainer.py \
          /app/maxtext/src/MaxText/configs/base.yml \
           base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=1 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-8b-4096-tp4-4x4
    

    O comando anterior envia o script Python, que chama o código do JaxTrainer Ray para o RayCluster. O comando ray job submit inclui alguns argumentos específicos do MaxText para transmitir à configuração do modelo.

    No terminal, você vai encontrar uma saída semelhante a esta:

    (RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster]
    
    ------------------------------------------
    Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded
    ------------------------------------------
    

Limpar

Para evitar cobranças na sua conta do Google Cloud pelos recursos usados neste tutorial, exclua o projeto que contém os recursos ou mantenha o projeto e exclua os recursos individuais.

  1. Exclua o RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. Exclua o cluster do GKE:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Exclua o bucket do Cloud Storage:

    gsutil rm -r gs://${GS_BUCKET}
    
  4. Exclua o repositório do Artifact Registry:

    gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
    

A seguir