Ajustar um LLM usando TPUs no GKE com o JAX

Neste tutorial, você vai aprender a ajustar um modelo de linguagem grande (LLM) usando Unidades de Processamento de Tensor (TPUs) no Google Kubernetes Engine (GKE) com JAX. Com o ajuste fino, é possível adaptar um modelo de fundação, como o Gemma 3, a um domínio ou tarefa específica. Esse processo melhora a precisão e a acurácia do modelo atualizando os parâmetros com seu próprio conjunto de dados especializado.

Este guia é um bom ponto de partida se você precisar do controle granular, da personalização, da escalonabilidade, da resiliência, da portabilidade e da economia do Kubernetes gerenciado ao ajustar suas cargas de trabalho de IA/ML.

Contexto

Ao usar TPUs no GKE com o Jax para ajustar um LLM, é possível criar uma solução de ajuste robusta e pronta para produção com todos os benefícios do Kubernetes gerenciado.

Gemma

O Gemma é um conjunto de modelos multimodais de IA generativa, leve e abertamente disponíveis, lançados sob licença aberta. Esses modelos de IA estão disponíveis para execução em aplicativos, hardware, dispositivos móveis ou serviços hospedados. O Gemma 3 apresenta a multimodalidade e oferece suporte a entradas de linguagem de visão e saídas de texto. Ele processa janelas de contexto de até 128.000 tokens e é compatível com mais de 140 idiomas. O Gemma 3 também oferece recursos aprimorados de matemática, raciocínio e chat, incluindo saídas estruturadas e chamadas de função.

É possível usar os modelos Gemma para geração de texto ou ajustá-los para tarefas especializadas.

Para mais informações, consulte a documentação do Gemma.

TPUs

TPUs são circuitos integrados de aplicação específica (ASICs) desenvolvidos especialmente pelo Google para acelerar modelos de machine learning e de IA criados com o uso de frameworks como TensorFlow, PyTorch e JAX.

Antes de usar TPUs no GKE, recomendamos que você conclua o seguinte programa de aprendizado:

  1. Saiba mais sobre a disponibilidade atual da versão da TPU com a arquitetura do sistema do Cloud TPU.
  2. Saiba mais sobre TPUs no GKE.

JAX

O JAX é um framework de machine learning de alto desempenho projetado para ser usado com TPUs e GPUs. O JAX oferece uma API para criar e treinar modelos de machine learning.

Para saber mais, consulte o repositório JAX (em inglês).

Objetivos

Este tutorial inclui as etapas a seguir:

  1. Crie um cluster do Autopilot ou do GKE Standard com a topologia de TPU recomendada com base nas características do modelo. Neste tutorial, você vai fazer o ajuste fino em pools de nós de host único.
  2. Adicione dados a um bucket do Cloud Storage e faça a montagem no contêiner usando o Cloud Storage FUSE.
  3. Implante o job de ajuste fino do LLM no GKE.
  4. Monitore o job de ajuste refinado e confira os registros.

Antes de começar

  • Faça login na sua conta do Google Cloud . Se você começou a usar o Google Cloud, crie uma conta para avaliar o desempenho de nossos produtos em situações reais. Clientes novos também recebem US$ 300 em créditos para executar, testar e implantar cargas de trabalho.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • Verifique se você tem os seguintes papéis no projeto: roles/container.admin,roles/iam.serviceAccountAdmin,roles/storage.admin

    Verificar os papéis

    1. No console do Google Cloud , acesse a página IAM.

      Acessar IAM
    2. Selecione o projeto.
    3. Na coluna Principal, encontre todas as linhas que identificam você ou um grupo no qual você está incluído. Para saber em quais grupos você está incluído, entre em contato com o administrador.

    4. Em todas as linhas que especificam ou incluem você, verifique a coluna Papel para ver se a lista de papéis inclui os papéis necessários.

    Conceder os papéis

    1. No console do Google Cloud , acesse a página IAM.

      Acessar IAM
    2. Selecione o projeto.
    3. Clique em Conceder acesso.
    4. No campo Novos principais, digite seu identificador de usuário. Normalmente, é o endereço de e-mail de uma Conta do Google.

    5. Clique em Selecionar um papel e pesquise o papel.
    6. Para conceder outros papéis, adicione-os clicando em Adicionar outro papel.
    7. Clique em Salvar.
  • Verifique se você tem cota suficiente para 16 chips TPU Trillium (v6e). Neste tutorial, você vai usar uma configuração de pool de nós que exige 16 chips e instâncias sob demanda.
  • Verifique se você tem um repositório do Docker. Se você não tiver um, crie um repositório padrão no Artifact Registry.

Prepare o ambiente

Neste tutorial, você vai usar o Cloud Shell para gerenciar recursos hospedados em Google Cloud. O Cloud Shell vem pré-instalado com o software necessário para este tutorial, incluindo kubectl e a Google Cloud CLI.

Para configurar o ambiente com o Cloud Shell, siga estas etapas:

  1. No console do Google Cloud , inicie uma sessão do Cloud Shell e clique em Ícone de ativação do Cloud Shell Ativar o Cloud Shell. Essa ação inicia uma sessão no painel inferior do console Google Cloud .

  2. Defina as variáveis de ambiente padrão:

    gcloud config set project PROJECT_ID
    gcloud config set billing/quota_project PROJECT_ID
    export PROJECT_ID=$(gcloud config get project)
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=CONTROL_PLANE_LOCATION
    export ZONE=ZONE
    export GCS_BUCKET_NAME=BUCKET_NAME
    

    Substitua os seguintes valores:

    • PROJECT_ID: o Google Cloud ID do projeto.
    • CLUSTER_NAME: o nome do cluster do GKE.
    • CONTROL_PLANE_LOCATION: a região do Compute Engine em que o cluster do GKE e os nós da TPU estão localizados. A região precisa conter zonas em que os tipos de máquina de TPU Trillium (v6e) estão disponíveis.
    • ZONE: uma zona na região CONTROL_PLANE_LOCATION selecionada em que os tipos de máquina TPU Trillium (v6e) estão disponíveis. Para listar as zonas em que as TPUs Trillium (v6e) estão disponíveis, execute o seguinte comando:

        gcloud compute accelerator-types list --filter="name~ct6e" --format="value(zone)"
      
    • BUCKET_NAME: o nome do bucket do Cloud Storage que contém seus dados de treinamento.

  3. Clone o repositório de amostra:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    cd kubernetes-engine-samples
    
  4. Navegue até o diretório de trabalho:

    cd ai-ml/llm-training-jax-tpu-gemma3
    

Criar e configurar recursos Google Cloud

Nesta seção, você cria e configura recursos Google Cloud .

Criar um cluster do GKE

É possível ajustar um LLM em TPUs em um cluster do GKE Autopilot ou Standard. Recomendamos que você use um cluster do Autopilot para ter uma experiência totalmente gerenciada do Kubernetes. Para escolher o modo de operação do GKE mais adequado para suas cargas de trabalho, consulte Escolher um modo de operação do GKE.

Piloto automático

Crie um cluster do GKE Autopilot que use a Federação de Identidade da Carga de Trabalho para GKE e tenha o Cloud Storage FUSE ativado.

gcloud container clusters create-auto ${CLUSTER_NAME} \
    --location=${REGION}

A criação do cluster pode levar vários minutos.

Padrão

  1. Crie um cluster regional do GKE Standard que use a Federação de Identidade da Carga de Trabalho para GKE e tenha o Cloud Storage FUSE ativado.

    gcloud container clusters create ${CLUSTER_NAME} \
        --enable-ip-alias \
        --addons GcsFuseCsiDriver \
        --machine-type=n2-standard-4 \
        --num-nodes=2 \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --location=${REGION}
    

    A criação do cluster pode levar vários minutos.

  2. Crie um pool de nós de host único:

    gcloud container node-pools create jax-tpu-nodepool \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-1t \
        --num-nodes=1 \
        --location=${REGION} \
        --node-locations=${ZONE} \
        --workload-metadata=GKE_METADATA
    

O GKE cria um pool de nós de TPU Trillium com uma topologia 1x1 e um nó. A flag --workload-metadata=GKE_METADATA configura o pool de nós para usar o servidor de metadados do GKE.

Instalar JobSet

  1. Configure kubectl para se comunicar com o cluster:

    gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION}
    
  2. Instale a versão mais recente do JobSet:

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    Substitua JOBSET_VERSION pela versão mais recente do JobSet. Por exemplo, v0.11.0.

  3. Verifique a instalação do JobSet:

    kubectl get pods -n jobset-system
    

    O resultado será o seguinte:

    NAME                                         READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-6c56668494-l4dhc   1/1     Running   0          4m45s
    

    Talvez seja necessário adicionar mais nós se o JobSet estiver aguardando recursos.

Configurar o Cloud Storage FUSE

Para fazer o ajuste fino do LLM, você precisa fornecer dados de treinamento. Neste tutorial, você vai usar o conjunto de dados TinyStories do Hugging Face. Esse conjunto de dados contém contos gerados sinteticamente pelo GPT-3.5 e pelo GPT-4 que usam um vocabulário limitado.

Nesta seção, abordamos as etapas para configurar o Cloud Storage FUSE e ler dados de um bucket do Cloud Storage.

  1. Faça o download do conjunto de dados:

    wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
    
  2. Faça o upload dos dados para um novo bucket do Cloud Storage:

    gcloud storage buckets create gs://${GCS_BUCKET_NAME} \
        --location=${REGION} \
        --enable-hierarchical-namespace \
        --uniform-bucket-level-access
    gcloud storage cp TinyStories-train.txt gs://${GCS_BUCKET_NAME}
    
  3. Para permitir que sua carga de trabalho leia dados pelo Cloud Storage FUSE, crie uma conta de serviço do Kubernetes (KSA) e adicione as permissões necessárias. Execute o script permissionsetup.sh:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    #!/bin/bash
    
    # --- Configuration Variables ---
    # Kubernetes Service Account details
    export KSA_NAME="jaxserviceaccout"
    export NAMESPACE="default"
    
    # Google Cloud IAM Service Account details
    export GSA_NAME="<GSA_NAME>"
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    export  GSA_DESCRIPTION="GKE Service Account to read GCS bucket for ${KSA_NAME}"
    
    # GCS Bucket details
    export GCS_BUCKET_NAME="<GCS_BUCKET_NAME>" # <--- IMPORTANT: Update this to your bucket name
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    export WI_MEMBER="serviceAccount:${PROJECT_ID}.svc.id.goog[${NAMESPACE}/${KSA_NAME}]"
    
    # --- Check if PROJECT_ID is set ---
    if [ -z "${PROJECT_ID}" ]; then
      echo "Error: PROJECT_ID is not set. Please set it using 'gcloud config set project YOUR_PROJECT_ID'"
      exit 1
    fi
    
    echo "--- Configuration ---"
    echo "KSA_NAME:      ${KSA_NAME}"
    echo "NAMESPACE:     ${NAMESPACE}"
    echo "GSA_NAME:      ${GSA_NAME}"
    echo "PROJECT_ID:    ${PROJECT_ID}"
    echo "GSA_EMAIL:     ${GSA_EMAIL}"
    echo "GCS_BUCKET_NAME:   ${GCS_BUCKET_NAME}"
    echo "WI_MEMBER:     ${WI_MEMBER}"
    echo "--------------------"
    read -p "Press enter to continue..."
    
    # --- Command Execution ---
    
    echo "[1/5] Creating Google Cloud IAM Service Account (GSA): ${GSA_NAME}"
    gcloud iam service-accounts create "${GSA_NAME}" \
        --project="${PROJECT_ID}" \
        --description="${GSA_DESCRIPTION}" \
        --display-name="${GSA_NAME}"
    
    echo "[2/5] Granting GSA '${GSA_EMAIL}' read access (roles/storage.objectViewer) to bucket 'gs://${GCS_BUCKET_NAME}'"
    gcloud storage buckets add-iam-policy-binding "gs://${GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectViewer" \
        --project="${PROJECT_ID}"
    
    echo "[3/5] Creating Kubernetes Service Account (KSA): ${KSA_NAME} in namespace ${NAMESPACE}"
    kubectl create serviceaccount "${KSA_NAME}" --namespace "${NAMESPACE}"
    
    echo "[4/5] Allowing KSA to impersonate GSA (Workload Identity Binding): ${GSA_EMAIL}"
    gcloud iam service-accounts add-iam-policy-binding "${GSA_EMAIL}" \
        --role roles/iam.workloadIdentityUser \
        --member "${WI_MEMBER}" \
        --project="${PROJECT_ID}"
    
    echo "[5/5] Annotating KSA '${KSA_NAME}' to link with GSA '${GSA_EMAIL}'"
    kubectl annotate serviceaccount "${KSA_NAME}" \
        --namespace "${NAMESPACE}" \
        iam.gke.io/gcp-service-account="${GSA_EMAIL}"
    
    echo "--- Setup Complete ---"
    echo "Pods in namespace '${NAMESPACE}' using serviceAccount '${KSA_NAME}' can now authenticate as '${GSA_EMAIL}' and have read access to 'gs://${GCS_BUCKET_NAME}'."
    

    Depois de executar esse script, os seguintes recursos serão configurados no seu projeto doGoogle Cloud e no cluster do GKE:

    • Uma nova conta de serviço do IAM chamada gcs-fuse-sa é criada no seu projeto.
    • A conta de serviço (GSA) Google Cloud criada (gcs-fuse-sa) recebe o papel roles/storage.objectViewer no bucket do Cloud Storage especificado por ${GCS_BUCKET_NAME}. Essa permissão permite que o GSA leia objetos do bucket.
    • Uma nova KSA chamada jaxserviceaccount é criada no namespace default dentro do cluster do GKE.
    • A política do IAM da GSA é atualizada para conceder o papel roles/iam.workloadIdentityUser à KSA. Essa permissão permite que a KSA se faça passar pela GSA.
    • A KSA é anotada para ser vinculada à GSA. Essa anotação informa ao GKE qual GSA a KSA precisa representar usando a Identidade da carga de trabalho.

      Qualquer pod em execução no namespace default do cluster do GKE que use a conta de serviço jaxserviceaccount agora poderá fazer a autenticação como a GSA gcs-fuse-sa. Esses pods terão acesso de leitura aos objetos armazenados no bucket gs://${GCS_BUCKET_NAME}, o que é essencial para que o job de ajuste fino acesse o conjunto de dados usando o Cloud Storage FUSE.

Criar o script de ajuste refinado

Nesta seção, você vai conhecer o script de treinamento que realiza uma operação de ajuste refinado em um modelo Gemma 3. Esse script usa o Gemma3Tokenizer.

Analise o seguinte script de ajuste refinado Gemma3LLMTrain.py:

# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import grain.python as pygrain
import jax
import jax.numpy as jnp
import optax
import pandas as pd
import time
import argparse

from dataclasses import dataclass
from functools import partial
from gemma import gm
from flax.training import train_state
from jax.sharding import Mesh, PartitionSpec, NamedSharding

jax.distributed.initialize()
print("Global device count:", jax.device_count())
print("jax version:", jax.__version__)

tokenizer = gm.text.Gemma3Tokenizer()
num_epochs = 1
learning_rate = 2e-5

@dataclass
class TextDataset:
    data: list
    maxlen: int
    tokenizer: gm.text.Gemma3Tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        encoding = self.tokenizer.encode(self.data[idx])[:self.maxlen]  # Tokenize and truncate
        return encoding + [0] * (self.maxlen - len(encoding))  # Pad to maxlen

def load_and_preprocess_data(file_path, batch_size, maxlen, datacount, tokenizer):

    with open(file_path, 'r') as f:
      text = f.read()

    stories = text.split('<|endoftext|>')
    stories = [story for story in stories if story.strip()][:datacount]
    df = pd.DataFrame({'text': stories})
    data = df['text'].dropna().tolist()
    dataset = TextDataset(data, maxlen, tokenizer)

    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=False,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )

    dataloader = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )

    return dataloader

def generate_text(model, params, tokenizer, prompt):
    sampler = gm.text.Sampler(
        model=model,
        params=params,
        tokenizer=tokenizer,
    )
    print("Generating response for: " + prompt)
    out = sampler.sample(prompt, max_new_tokens=32)
    print("Reponse: \n" + out + "\n")
    return out

prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))

@partial(jax.jit, donate_argnums=(0,))
def train_step(state, batch):
    """Performs one supervised fine-tuning step."""

    def loss_fn(params):
        # Run the forward pass. The model returns logits.
        logits = state.apply_fn({'params': params}, batch[0]).logits

        # Calculate the cross-entropy loss.
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch[1]
        ).mean()

        return loss

    # Compute gradients
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)

    # Update the model state
    state = state.apply_gradients(grads=grads)

    metrics = {'loss': loss}
    return state, metrics

def train_model(state, text_dl, num_epochs, sharding):
    batchCount = 0
    start_time = time.time()
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch in text_dl:
            if len(batch) % len(jax.devices()) != 0:
              continue  # skip the remaining elements
            input_batch = jnp.array(jnp.array(batch).T)
            target_batch = prep_target_batch(input_batch)
            state, metrics = train_step(state, jax.device_put((input_batch, target_batch), sharding))

            if batchCount % 10 == 0:
                print(f"Loss after batch {batchCount}: {metrics['loss']}")
            batchCount += 1

    end_time = time.time()
    print(f"Completed training model. Total time for training {end_time - start_time} seconds \n")
    return state

def run_training(maxlen, batch_size, datacount):
    print(f"Batch size: {batch_size}, Max length: {maxlen}, Data count: {datacount}")
    #Load the training data
    tiny_stories_dl = load_and_preprocess_data('/data/TinyStories-train.txt', batch_size, maxlen, datacount, tokenizer)
    # Get the Gemma3 model
    model = gm.nn.Gemma3_270M()
    # Load the pretrained parameters
    params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_270M_PT)
    # Create an optimizer
    optimizer = optax.adamw(learning_rate=learning_rate)
    # Define sharding for data parallel training
    mesh = Mesh(jax.devices(), ('batch',))
    sharding = NamedSharding(mesh, PartitionSpec('batch', None))

    # Testing out current state of the model
    test_prompt = "Once upon a time, there was a girl named Amy."
    generate_text(model, params, tokenizer, test_prompt)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer
    )

    # Perform post training
    print("Start training model")
    state = train_model(state, tiny_stories_dl, num_epochs, sharding)

    # Final text generation
    generate_text(model, state.params, tokenizer, test_prompt)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Gemma model with custom parameters.')
    parser.add_argument('--maxlen', type=int, default=256, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--datacount', type=int, default=296000, help='Number of data samples to use')
    args = parser.parse_args()

    run_training(maxlen=args.maxlen, batch_size=args.batch_size, datacount=args.datacount)

Neste script, o seguinte se aplica:

  • Um Gemma3Tokenizer converte dados de texto em tokens que o modelo pode processar.
  • A função load_and_preprocess_data lê os dados de treinamento de um arquivo, divide em histórias individuais e usa o tokenizador para converter o texto em sequências de tokens com padding.
  • A função generate_text usa o modelo, os parâmetros dele e um comando para gerar texto.
  • A função train_step define uma única iteração de treinamento que inclui a transmissão direta, o cálculo da perda (usando entropia cruzada), a computação do gradiente e as atualizações de parâmetros.
  • A função train_model itera o conjunto de dados por um número especificado de épocas, que chama a função train_step para cada lote.
  • A função run_training organiza todo o processo para carregar dados, inicializar o modelo Gemma 3 (Gemma3_270M) e o otimizador, carregar parâmetros pré-treinados, configurar o sharding de dados para processamento paralelo, executar uma geração de teste, executar o loop de treinamento e realizar uma geração de texto final para demonstrar o efeito do ajuste fino.
  • O script usa a biblioteca argparse para aceitar argumentos de linha de comando para os parâmetros maxlen, batch_size e datacount.

Agora que você já conhece o script de ajuste refinado, crie um contêiner para executá-lo no GKE.

Colocar o script de ajuste fino em um contêiner

Antes de executar o script de ajuste refinado em um cluster do GKE, é necessário criar um contêiner para ele. Este tutorial usa uma imagem de IA do JAX como imagem de base.

  1. Abra o Dockerfile no mesmo diretório que o arquivo Gemma3LLMTrain.py:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    Esse Dockerfile instala as dependências necessárias e copia o arquivo Gemma3LLMTrain.py para o contêiner.

  2. Crie a imagem Docker e envie-a para um repositório de imagens:

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    Substitua REPOSITORY_NAME pelo nome do repositório do Artifact Registry.

  3. Adicione vinculações de papéis à conta de serviço:

    export PROJECT_NUMBER=$(gcloud projects describe $PROJECT_ID --format 'get(projectNumber)')
    gcloud artifacts repositories add-iam-policy-binding ${REPOSITORY} \
        --project=${PROJECT_ID} \
        --location=${REGION} \
        --member="serviceAccount:${PROJECT_NUMBER}-compute@developer.gserviceaccount.com" \
        --role="roles/artifactregistry.reader"
    

Com a imagem no repositório, agora é possível implantar o job de ajuste refinado em um cluster do GKE.

Implantar o job de ajuste fino do LLM

Nesta seção, mostramos como implantar o job de ajuste refinado de LLM no seu cluster do GKE.

  1. Abra o manifesto training_singlehost.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "355120"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. Aplique o manifesto:

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

O GKE cria um job que inicia um pod em um nó TPU Trillium (v6e). Esse pod executa o script de ajuste refinado em Python, que acessa os dados de ajuste refinado do bucket especificado do Cloud Storage montado no caminho /data usando o Cloud Storage FUSE. Em seguida, o script ajusta o modelo do Gemma.

Monitorar o job de treinamento

Nesta seção, você vai monitorar o progresso do job de ajuste fino e a performance dele.

Conferir o progresso do ajuste detalhado

  1. Liste os pods:

    # Find the Pods
    kubectl get pods
    
  2. Siga a saída do registro:

    kubectl logs -f pods/POD_NAME
    

    Substitua POD_NAME pelo nome do pod.

    O resultado será o seguinte:

    Global device count: 1
    Batch size: 128, Max length: 256, Data count: 96000
    I1028 00:12:55.925999 1387 google_auth_provider.cc:181] Running on GCE, using service account ...
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    Amy lived in a small house. The house was in a big field. Amy liked to play in the big field. She
    Start training model
    Loss after batch 0: 10.25
    Loss after batch 10: 4.3125
    .
    .
    .
    Loss after batch 740: 1.41406
    Completed training model. Total time for training 294.6791355609894 seconds
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    She loved to play with her toys. One day, Amy's mom told her that she had to go to the store to
    
  3. Analise a saída:

    • A linha Global device count: 1 indica os núcleos de TPU usados.
    • O modelo gera texto razoável antes dessa execução de ajuste porque ele é carregado de um checkpoint pré-treinado.
    • A saída gerada após o ajuste fino se parece mais com o início de um conto, indicando que o modelo está aprendendo com o novo conjunto de dados.
    • O ajuste detalhado no conjunto de dados completo deve gerar resultados ainda mais refinados.

Observar métricas

Confira o desempenho do job de ajuste refinado verificando as métricas de TPU e CPU. Para conferir as métricas de observabilidade do cluster, siga as etapas em Consultar as métricas de observabilidade de clusters e cargas de trabalho.

Configurações alternativas de ajuste detalhado

Esta seção descreve configurações alternativas para sua carga de trabalho de ajuste refinado.

Seleção de modelos

Este tutorial usou o modelo Gemma3_270M, que é pequeno e se encaixa em um pool de nós de TPU Trillium (v6e) de host único. Para modelos maiores que exigem mais memória e computação para ajuste detalhado, use configurações de pool de nós de vários hosts ou várias frações.

Para conferir uma lista completa dos modelos disponíveis, consulte a documentação da Gemma.

Configurações do pool de nós

Este tutorial usou um pool de nós de host único. Também é possível criar pools de nós de fração de TPU de vários hosts ou pools de nós de várias frações, dependendo das suas necessidades.

As guias a seguir mostram como criar pools de nós de vários hosts e várias frações:

Vários hosts

  1. No Cloud Shell, execute este comando:

    gcloud container node-pools create jax-tpu-multihost1 \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-4t \
        --num-nodes=2 \
        --tpu-topology=2x4 \
        --location=${REGION} \
        --node-locations=${ZONE}
    

    O GKE cria um pool de nós de TPU Trillium com uma topologia 2x4 e dois nós.

  2. Abra a definição do job training_multihost_jobset.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multihost
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                    cloud.google.com/gke-nodepool: jax-tpu-multihost1
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI} 
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
    
  3. Implante o job de ajuste:

    envsubst < training_multihost_jobset.yaml | kubectl apply -f -
    

Multislice

  1. No Cloud Shell, execute este comando:

    gcloud container node-pools create jax-tpu-multihost1 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    
    gcloud container node-pools create jax-tpu-multihost2 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    

    O GKE cria dois pools de nós de TPU Trillium. Cada pool de nós tem uma topologia 2x4 e dois nós.

  2. Abra a definição do job training_multislice_jobset.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multislice
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI}
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:100"
    
  3. Implante o job de ajuste:

    envsubst < training_multislice_jobset.yaml | kubectl apply -f -
    

Análise e otimização de performance

Para analisar e otimizar a performance do ajuste fino de machine learning, use o XProf. O XProf é um conjunto de ferramentas que cria perfis e inspeciona cargas de trabalho de ML criadas com JAX, TensorFlow ou PyTorch/XLA. Ao mostrar rastreamentos de execução, uso da memória e outros dados, o XProf permite ajustar seus modelos e a configuração de treinamento para melhorar a eficiência e acelerar o treinamento.

Para analisar a performance da sua carga de trabalho de ajuste refinado usando o XProf, siga estas etapas nesta seção:

  • Instale o pacote xprof. Modifique o script de treinamento para iniciar o servidor XProf.
  • Modifique o manifesto do job do Kubernetes para incluir uma montagem de volume para os registros do XProf.
  • Conceda à conta de serviço permissões para gravar registros do XProf em um bucket do Cloud Storage.
  • Execute o XProf no seu pod e configure o encaminhamento de porta para acessar o painel do XProf.

Instalar o pacote XProf

  1. Navegue até o diretório que contém as amostras do XProf:

      cd ai-ml/llm-training-jax-tpu-gemma3/xprof-enabled
    
  2. Crie a imagem Docker e envie-a para um repositório de imagens:

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training-xp"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    Substitua REPOSITORY_NAME pelo nome do repositório do Artifact Registry.

  3. Execute o script Dockerfile:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain equinox
    RUN pip install xprof
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    Esse Dockerfile instala as dependências do XProf.

Copie o script de ajuste refinado para o contêiner.

Nesta seção, crie e aplique um manifesto de job do Kubernetes que inclua as montagens de volume necessárias para os registros do XProf.

  1. Abra a definição do job training_singlehost.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "851200"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
            - name: gcs-fuse-csi-ephemeral2
              mountPath: /xprof
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
          - name: gcs-fuse-csi-ephemeral2
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${XPROF_GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. Aplique o manifesto:

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

Conceder permissões à conta de serviço para gravar registros do XProf

  1. Para permitir que a conta de serviço grave e leia, adicione o papel "roles/storage.objectUser":

    export GSA_NAME="GSA_NAME" # Same as used in initial setup
    
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    
    # Cloud Storage Bucket details
    export XPROF_GCS_BUCKET_NAME="XPROF_GCS_BUCKET_NAME"
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    
    gcloud storage buckets add-iam-policy-binding "gs://${XPROF_GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectUser" \
        --project="${PROJECT_ID}"
    

    Substitua:

    • GSA_NAME: o nome da conta de serviço do Google a que será concedida a função.
    • XPROF_GCS_BUCKET_NAME: o nome do bucket a que o papel será concedido.
  2. Execute o XProf no seu pod:

    kubectl exec POD_NAME -c training-container -it -- bash # exec into the container
    xprof --port 9001 --logdir /xprof # start xprof
    

    Substitua POD_NAME pelo nome do pod.

Acessar o painel do XProf

  1. Configure o encaminhamento de portas para o servidor XProf no pod:

    kubectl port-forward POD_NAME 9001:9001
    
  2. Na barra de endereço do navegador, digite:

    http://localhost:9001/
    

    O visualizador de rastreamento do XProf é aberto.

  3. Na janela do TensorBoard, clique em Capturar perfil.

  4. No campo URLs de serviço do perfil ou nome da TPU, insira localhost:9002.

  5. Para capturar mais detalhes, em Nível de rastreamento do host (TraceMe), selecione detalhado e ative o registro de rastreamento do Python.

  6. Para acessar o painel, clique em Capturar.

    O TensorBoard captura o perfil e permite analisar o desempenho do script de treinamento. O gráfico mostra a linha do tempo de execução dos perfis de performance da TPU e da CPU:

Um exemplo do visualizador de rastreamento do XProf que mostra um gráfico de matriz de performance

Para mais opções de criação de perfil para analisar o desempenho da sua carga de trabalho de treinamento, consulte a documentação do JAX sobre Criação de perfil de computação.

Ajuste refinado em ambientes de produção

Neste tutorial, mostramos como testar o treinamento baseado em JAX em um ambiente distribuído. Para um ajuste fino otimizado do LLM em produção, use a biblioteca MaxText. Se você tiver interesse em modelos de difusão, use as implementações do Maxdiffusion.

Para cargas de trabalho de treinamento ou ajuste refinado de longa duração em produção, configure a criação de checkpoints de carga de trabalho para minimizar a perda de progresso durante uma falha. Para saber mais sobre a configuração de pontos de verificação de vários níveis, consulte Treinar modelos de machine learning em grande escala no GKE com pontos de verificação de vários níveis.

Limpar

Para evitar cobranças na sua conta do Google Cloud pelos recursos usados no tutorial, exclua o projeto que os contém ou mantenha o projeto e exclua os recursos individuais.

Excluir recursos individuais

Para evitar cobranças na sua conta do Google Cloud pelos recursos usados neste tutorial, exclua o projeto que contém os recursos ou mantenha o projeto e exclua os recursos individuais executando os seguintes comandos:

  1. Exclua os recursos criados neste tutorial:

    gcloud container clusters delete ${CLUSTER_NAME} --location=${REGION}
    gcloud storage rm --recursive gs://${GCS_BUCKET_NAME}
    gcloud artifacts docker images delete ${IMAGE_URI} --delete-tags
    
  2. Se você não precisar dos dados gerados pelo XProf, remova o bucket do Cloud Storage usado por ele:

    gcloud storage rm --recursive gs://${XPROF_GCS_BUCKET_NAME}
    

A seguir