Ajusta un LLM con TPU en GKE con JAX

En este instructivo, se muestra cómo ajustar un modelo de lenguaje grande (LLM) con unidades de procesamiento tensorial (TPU) en Google Kubernetes Engine (GKE) con JAX. El ajuste fino te permite adaptar un modelo de base, como Gemma 3, a un dominio o una tarea específicos. Este proceso mejora la precisión y la exactitud del modelo, ya que actualiza sus parámetros con tu propio conjunto de datos especializado.

Esta guía es un buen punto de partida si necesitas el control detallado, la personalización, la escalabilidad, la resiliencia, la portabilidad y la rentabilidad de Kubernetes administrado cuando ajustas tus cargas de trabajo de IA/AA.

Fondo

Si usas TPU en GKE con JAX para ajustar un LLM, puedes compilar una solución de ajuste robusta y lista para producción con todos los beneficios de Kubernetes administrado.

Gemma

Gemma es un conjunto de modelos multimodales de IA/AA generativa básicos y de disponibilidad general que se lanzan con una licencia abierta. Estos modelos de IA están disponibles para ejecutarse en tus aplicaciones, hardware, dispositivos móviles o servicios alojados. Gemma 3 introduce la multimodalidad y admite entradas de lenguaje visual y salidas de texto. Maneja ventanas de contexto de hasta 128,000 tokens y admite más de 140 idiomas. Gemma 3 también ofrece capacidades mejoradas de matemáticas, razonamiento y chat, incluidas salidas estructuradas y llamadas a funciones.

Puedes usar los modelos de Gemma para la generación de texto, pero también puedes ajustar estos modelos en el caso de tareas especializadas.

Para obtener más información, consulta la documentación de Gemma.

TPU

Las TPU son circuitos integrados específicos de aplicaciones (ASIC) que Google desarrolló de forma personalizada para acelerar el aprendizaje automático y los modelos de IA que se compilan con frameworks como TensorFlow, PyTorch y JAX.

Antes de usar las TPU en GKE, te recomendamos que completes la siguiente ruta de aprendizaje:

  1. Obtén información sobre la disponibilidad actual de la versión de TPU con la arquitectura del sistema de Cloud TPU.
  2. Obtén más información sobre las TPU en GKE.

JAX

JAX es un framework de aprendizaje automático de alto rendimiento diseñado para usarse con TPU y GPU. JAX proporciona una API para compilar y entrenar modelos de aprendizaje automático.

Para obtener más información, consulta el repositorio de JAX.

Objetivos

En este instructivo, se abarcan los siguientes pasos:

  1. Crea un clúster de GKE en modo Autopilot o Standard con la topología de TPU recomendada según las características del modelo. Durante este instructivo, realizarás el ajuste de los grupos de nodos de un solo host.
  2. Agrega datos a un bucket de Cloud Storage y súbelos al contenedor a través de Cloud Storage FUSE.
  3. Implementa el trabajo de ajuste del LLM en GKE.
  4. Supervisa el trabajo de ajuste y visualiza los registros.

Antes de comenzar

  • Accede a tu cuenta de Google Cloud . Si eres nuevo en Google Cloud, crea una cuenta para evaluar el rendimiento de nuestros productos en situaciones reales. Los clientes nuevos también obtienen $300 en créditos gratuitos para ejecutar, probar y, además, implementar cargas de trabajo.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • Asegúrate de tener los siguientes roles en el proyecto: roles/container.admin,roles/iam.serviceAccountAdmin,roles/storage.admin

    Verifica los roles

    1. En la consola de Google Cloud , dirígete a la página IAM.

      Ir a IAM
    2. Selecciona el proyecto.
    3. En la columna Principal, busca todas las filas que te identifiquen a ti o a un grupo en el que se te incluya. Para saber en qué grupos estás incluido, comunícate con tu administrador.

    4. Para todas las filas en las que se te especifique o se te incluya, verifica la columna Rol para ver si la lista de roles incluye los roles necesarios.

    Otorga los roles

    1. En la consola de Google Cloud , dirígete a la página IAM.

      Ir a IAM
    2. Selecciona el proyecto.
    3. Haz clic en Otorgar acceso.
    4. En el campo Principales nuevas, ingresa tu identificador de usuario. Esta suele ser la dirección de correo electrónico de una Cuenta de Google.

    5. Haz clic en Seleccionar un rol y, luego, busca el rol.
    6. Para otorgar roles adicionales, haz clic en Agregar otro rol y agrega uno más.
    7. Haz clic en Guardar.
  • Asegúrate de tener suficiente cuota para 16 chips de TPU Trillium (v6e). En este instructivo, usarás una configuración de grupo de nodos que requiere 16 chips y instancias bajo demanda.
  • Asegúrate de tener un repositorio de Docker. Si no tienes uno, crea un repositorio estándar en Artifact Registry.

Prepare el entorno

En este instructivo, usarás Cloud Shell para administrar recursos alojados en Google Cloud. Cloud Shell tiene preinstalado el software que necesitas para este instructivo, incluidos kubectl y la Google Cloud CLI.

Para configurar tu entorno con Cloud Shell, sigue estos pasos:

  1. En la Google Cloud consola, inicia una sesión de Cloud Shell y haz clic en Ícono de activación de Cloud Shell Activar Cloud Shell. Esta acción inicia una sesión en el panel inferior de la consola de Google Cloud .

  2. Configura las variables de entorno predeterminadas:

    gcloud config set project PROJECT_ID
    gcloud config set billing/quota_project PROJECT_ID
    export PROJECT_ID=$(gcloud config get project)
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=CONTROL_PLANE_LOCATION
    export ZONE=ZONE
    export GCS_BUCKET_NAME=BUCKET_NAME
    

    Reemplaza los siguientes valores:

    • PROJECT_ID: Es el Google Cloud ID del proyecto.
    • CLUSTER_NAME: Es el nombre del clúster de GKE.
    • CONTROL_PLANE_LOCATION: Es la región de Compute Engine en la que se encuentran tu clúster de GKE y los nodos TPU. La región debe contener zonas en las que estén disponibles los tipos de máquinas de TPU Trillium (v6e).
    • ZONE: Es una zona dentro de la región CONTROL_PLANE_LOCATION seleccionada en la que están disponibles los tipos de máquinas de TPU Trillium (v6e). Para enumerar las zonas en las que hay TPU Trillium (v6e) disponibles, ejecuta el siguiente comando:

        gcloud compute accelerator-types list --filter="name~ct6e" --format="value(zone)"
      
    • BUCKET_NAME: Es el nombre del bucket de Cloud Storage que contiene tus datos de entrenamiento.

  3. Clona el repositorio de ejemplo:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    cd kubernetes-engine-samples
    
  4. Navega hasta el directorio de trabajo:

    cd ai-ml/llm-training-jax-tpu-gemma3
    

Crea y configura recursos de Google Cloud

En esta sección, crearás y configurarás recursos de Google Cloud .

Crea un clúster de GKE

Puedes ajustar un LLM en TPU en un clúster de GKE Autopilot o Standard. Te recomendamos que uses un clúster de Autopilot para una experiencia de Kubernetes completamente administrada. Para elegir el modo de operación de GKE que se adapte mejor a tus cargas de trabajo, consulta Elige un modo de operación de GKE.

Autopilot

Crea un clúster de GKE Autopilot que use Workload Identity Federation para GKE y tenga habilitado Cloud Storage FUSE.

gcloud container clusters create-auto ${CLUSTER_NAME} \
    --location=${REGION}

La creación del clúster puede tomar varios minutos.

Estándar

  1. Crea un clúster de GKE Estándar regional que use Workload Identity Federation for GKE y tenga habilitado Cloud Storage FUSE.

    gcloud container clusters create ${CLUSTER_NAME} \
        --enable-ip-alias \
        --addons GcsFuseCsiDriver \
        --machine-type=n2-standard-4 \
        --num-nodes=2 \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --location=${REGION}
    

    La creación del clúster puede tomar varios minutos.

  2. Crea un grupo de nodos de host único:

    gcloud container node-pools create jax-tpu-nodepool \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-1t \
        --num-nodes=1 \
        --location=${REGION} \
        --node-locations=${ZONE} \
        --workload-metadata=GKE_METADATA
    

GKE crea un grupo de nodos TPU Trillium con una topología 1x1 y un nodo. La marca --workload-metadata=GKE_METADATA configura el grupo de nodos para usar el servidor de metadatos de GKE.

Instala JobSet

  1. Configura kubectl para comunicarse con tu clúster:

    gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION}
    
  2. Instala la versión más reciente de JobSet:

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    Reemplaza JOBSET_VERSION por la versión más reciente de JobSet. Por ejemplo, v0.11.0.

  3. Verifica la instalación de JobSet:

    kubectl get pods -n jobset-system
    

    El resultado es similar a lo siguiente:

    NAME                                         READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-6c56668494-l4dhc   1/1     Running   0          4m45s
    

    Es posible que debas agregar más nodos si JobSet está esperando recursos.

Configura Cloud Storage FUSE

Para ajustar el LLM, debes proporcionar datos de entrenamiento. En este instructivo, usarás el conjunto de datos TinyStories de Hugging Face. Este conjunto de datos contiene cuentos cortos generados de forma sintética por GPT-3.5 y GPT-4 que usan un vocabulario limitado.

En esta sección, se describen los pasos para configurar Cloud Storage FUSE de modo que lea datos de un bucket de Cloud Storage.

  1. Descarga el conjunto de datos:

    wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
    
  2. Sube los datos a un bucket nuevo de Cloud Storage:

    gcloud storage buckets create gs://${GCS_BUCKET_NAME} \
        --location=${REGION} \
        --enable-hierarchical-namespace \
        --uniform-bucket-level-access
    gcloud storage cp TinyStories-train.txt gs://${GCS_BUCKET_NAME}
    
  3. Para permitir que tu carga de trabajo lea datos a través de Cloud Storage FUSE, crea una cuenta de servicio de Kubernetes (KSA) y agrega los permisos necesarios. Ejecuta la secuencia de comandos permissionsetup.sh:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    #!/bin/bash
    
    # --- Configuration Variables ---
    # Kubernetes Service Account details
    export KSA_NAME="jaxserviceaccout"
    export NAMESPACE="default"
    
    # Google Cloud IAM Service Account details
    export GSA_NAME="<GSA_NAME>"
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    export  GSA_DESCRIPTION="GKE Service Account to read GCS bucket for ${KSA_NAME}"
    
    # GCS Bucket details
    export GCS_BUCKET_NAME="<GCS_BUCKET_NAME>" # <--- IMPORTANT: Update this to your bucket name
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    export WI_MEMBER="serviceAccount:${PROJECT_ID}.svc.id.goog[${NAMESPACE}/${KSA_NAME}]"
    
    # --- Check if PROJECT_ID is set ---
    if [ -z "${PROJECT_ID}" ]; then
      echo "Error: PROJECT_ID is not set. Please set it using 'gcloud config set project YOUR_PROJECT_ID'"
      exit 1
    fi
    
    echo "--- Configuration ---"
    echo "KSA_NAME:      ${KSA_NAME}"
    echo "NAMESPACE:     ${NAMESPACE}"
    echo "GSA_NAME:      ${GSA_NAME}"
    echo "PROJECT_ID:    ${PROJECT_ID}"
    echo "GSA_EMAIL:     ${GSA_EMAIL}"
    echo "GCS_BUCKET_NAME:   ${GCS_BUCKET_NAME}"
    echo "WI_MEMBER:     ${WI_MEMBER}"
    echo "--------------------"
    read -p "Press enter to continue..."
    
    # --- Command Execution ---
    
    echo "[1/5] Creating Google Cloud IAM Service Account (GSA): ${GSA_NAME}"
    gcloud iam service-accounts create "${GSA_NAME}" \
        --project="${PROJECT_ID}" \
        --description="${GSA_DESCRIPTION}" \
        --display-name="${GSA_NAME}"
    
    echo "[2/5] Granting GSA '${GSA_EMAIL}' read access (roles/storage.objectViewer) to bucket 'gs://${GCS_BUCKET_NAME}'"
    gcloud storage buckets add-iam-policy-binding "gs://${GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectViewer" \
        --project="${PROJECT_ID}"
    
    echo "[3/5] Creating Kubernetes Service Account (KSA): ${KSA_NAME} in namespace ${NAMESPACE}"
    kubectl create serviceaccount "${KSA_NAME}" --namespace "${NAMESPACE}"
    
    echo "[4/5] Allowing KSA to impersonate GSA (Workload Identity Binding): ${GSA_EMAIL}"
    gcloud iam service-accounts add-iam-policy-binding "${GSA_EMAIL}" \
        --role roles/iam.workloadIdentityUser \
        --member "${WI_MEMBER}" \
        --project="${PROJECT_ID}"
    
    echo "[5/5] Annotating KSA '${KSA_NAME}' to link with GSA '${GSA_EMAIL}'"
    kubectl annotate serviceaccount "${KSA_NAME}" \
        --namespace "${NAMESPACE}" \
        iam.gke.io/gcp-service-account="${GSA_EMAIL}"
    
    echo "--- Setup Complete ---"
    echo "Pods in namespace '${NAMESPACE}' using serviceAccount '${KSA_NAME}' can now authenticate as '${GSA_EMAIL}' and have read access to 'gs://${GCS_BUCKET_NAME}'."
    

    Después de ejecutar esta secuencia de comandos, se configuran los siguientes recursos en tu proyectoGoogle Cloud y clúster de GKE:

    • Se crea una nueva cuenta de servicio de IAM llamada gcs-fuse-sa en tu proyecto.
    • A la cuenta de servicio (GSA) creada Google Cloud (gcs-fuse-sa) se le otorga el rol de roles/storage.objectViewer en el bucket de Cloud Storage especificado por ${GCS_BUCKET_NAME}. Este permiso permite que la GSA lea objetos del bucket.
    • Se crea un nuevo KSA llamado jaxserviceaccount en el espacio de nombres default dentro de tu clúster de GKE.
    • Se actualiza la política de IAM de la GSA para otorgar el rol roles/iam.workloadIdentityUser a la KSA. Este permiso permite que la KSA actúe en nombre de la GSA.
    • La KSA se anota para vincularla a la GSA. Esta anotación le indica a GKE qué GSA debe suplantar la KSA con Workload Identity.

      Cualquier Pod que se ejecute en el espacio de nombres default de tu clúster de GKE que use la cuenta de servicio jaxserviceaccount ahora podrá autenticarse como la GSA gcs-fuse-sa. Estos Pods tendrán acceso de lectura a los objetos almacenados en el bucket gs://${GCS_BUCKET_NAME}, lo que es fundamental para que el trabajo de ajuste pueda acceder al conjunto de datos con Cloud Storage FUSE.

Crea la secuencia de comandos de ajuste

En esta sección, explorarás la secuencia de comandos de entrenamiento que realiza una operación de ajuste de un modelo de Gemma 3. Esta secuencia de comandos usa Gemma3Tokenizer.

Revisa la siguiente secuencia de comandos de ajuste de Gemma3LLMTrain.py:

# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import grain.python as pygrain
import jax
import jax.numpy as jnp
import optax
import pandas as pd
import time
import argparse

from dataclasses import dataclass
from functools import partial
from gemma import gm
from flax.training import train_state
from jax.sharding import Mesh, PartitionSpec, NamedSharding

jax.distributed.initialize()
print("Global device count:", jax.device_count())
print("jax version:", jax.__version__)

tokenizer = gm.text.Gemma3Tokenizer()
num_epochs = 1
learning_rate = 2e-5

@dataclass
class TextDataset:
    data: list
    maxlen: int
    tokenizer: gm.text.Gemma3Tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        encoding = self.tokenizer.encode(self.data[idx])[:self.maxlen]  # Tokenize and truncate
        return encoding + [0] * (self.maxlen - len(encoding))  # Pad to maxlen

def load_and_preprocess_data(file_path, batch_size, maxlen, datacount, tokenizer):

    with open(file_path, 'r') as f:
      text = f.read()

    stories = text.split('<|endoftext|>')
    stories = [story for story in stories if story.strip()][:datacount]
    df = pd.DataFrame({'text': stories})
    data = df['text'].dropna().tolist()
    dataset = TextDataset(data, maxlen, tokenizer)

    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=False,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )

    dataloader = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )

    return dataloader

def generate_text(model, params, tokenizer, prompt):
    sampler = gm.text.Sampler(
        model=model,
        params=params,
        tokenizer=tokenizer,
    )
    print("Generating response for: " + prompt)
    out = sampler.sample(prompt, max_new_tokens=32)
    print("Reponse: \n" + out + "\n")
    return out

prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))

@partial(jax.jit, donate_argnums=(0,))
def train_step(state, batch):
    """Performs one supervised fine-tuning step."""

    def loss_fn(params):
        # Run the forward pass. The model returns logits.
        logits = state.apply_fn({'params': params}, batch[0]).logits

        # Calculate the cross-entropy loss.
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch[1]
        ).mean()

        return loss

    # Compute gradients
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)

    # Update the model state
    state = state.apply_gradients(grads=grads)

    metrics = {'loss': loss}
    return state, metrics

def train_model(state, text_dl, num_epochs, sharding):
    batchCount = 0
    start_time = time.time()
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch in text_dl:
            if len(batch) % len(jax.devices()) != 0:
              continue  # skip the remaining elements
            input_batch = jnp.array(jnp.array(batch).T)
            target_batch = prep_target_batch(input_batch)
            state, metrics = train_step(state, jax.device_put((input_batch, target_batch), sharding))

            if batchCount % 10 == 0:
                print(f"Loss after batch {batchCount}: {metrics['loss']}")
            batchCount += 1

    end_time = time.time()
    print(f"Completed training model. Total time for training {end_time - start_time} seconds \n")
    return state

def run_training(maxlen, batch_size, datacount):
    print(f"Batch size: {batch_size}, Max length: {maxlen}, Data count: {datacount}")
    #Load the training data
    tiny_stories_dl = load_and_preprocess_data('/data/TinyStories-train.txt', batch_size, maxlen, datacount, tokenizer)
    # Get the Gemma3 model
    model = gm.nn.Gemma3_270M()
    # Load the pretrained parameters
    params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_270M_PT)
    # Create an optimizer
    optimizer = optax.adamw(learning_rate=learning_rate)
    # Define sharding for data parallel training
    mesh = Mesh(jax.devices(), ('batch',))
    sharding = NamedSharding(mesh, PartitionSpec('batch', None))

    # Testing out current state of the model
    test_prompt = "Once upon a time, there was a girl named Amy."
    generate_text(model, params, tokenizer, test_prompt)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer
    )

    # Perform post training
    print("Start training model")
    state = train_model(state, tiny_stories_dl, num_epochs, sharding)

    # Final text generation
    generate_text(model, state.params, tokenizer, test_prompt)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Gemma model with custom parameters.')
    parser.add_argument('--maxlen', type=int, default=256, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--datacount', type=int, default=296000, help='Number of data samples to use')
    args = parser.parse_args()

    run_training(maxlen=args.maxlen, batch_size=args.batch_size, datacount=args.datacount)

En esta secuencia de comandos, se aplica lo siguiente:

  • Un Gemma3Tokenizer convierte los datos de texto en tokens que el modelo puede procesar.
  • La función load_and_preprocess_data lee los datos de entrenamiento de un archivo, los divide en historias individuales y usa el tokenizador para convertir el texto en secuencias de tokens con padding.
  • La función generate_text toma el modelo, sus parámetros y una instrucción para generar texto.
  • La función train_step define una sola iteración de entrenamiento que incluye el pase hacia adelante, el cálculo de la pérdida (con entropía cruzada), el cálculo del gradiente y las actualizaciones de los parámetros.
  • La función train_model itera el conjunto de datos durante una cantidad especificada de épocas, lo que llama a la función train_step para cada lote.
  • La función run_training coordina todo el proceso para cargar datos, inicializar el modelo Gemma 3 (Gemma3_270M) y el optimizador, cargar parámetros previamente entrenados, configurar la fragmentación de datos para el procesamiento paralelo, ejecutar una generación de prueba, ejecutar el bucle de entrenamiento y realizar una generación de texto final para demostrar el efecto del ajuste.
  • La secuencia de comandos usa la biblioteca argparse para aceptar argumentos de línea de comandos para los parámetros maxlen, batch_size y datacount.

Ahora que exploraste la secuencia de comandos de ajuste, crea un contenedor para ejecutarla en GKE.

Crea un contenedor para la secuencia de comandos de ajuste

Antes de ejecutar la secuencia de comandos de ajuste en un clúster de GKE, debes contenerla. En este instructivo, se usa una imagen generada por IA de JAX como imagen base.

  1. Abre el archivo Dockerfile en el mismo directorio que el archivo Gemma3LLMTrain.py:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    Este Dockerfile instala las dependencias necesarias y copia el archivo Gemma3LLMTrain.py en el contenedor.

  2. Compila la imagen de Docker y envíala a un repositorio de imágenes:

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    Reemplaza REPOSITORY_NAME por el nombre de tu repositorio de Artifact Registry.

  3. Agrega vinculaciones de roles a la cuenta de servicio:

    export PROJECT_NUMBER=$(gcloud projects describe $PROJECT_ID --format 'get(projectNumber)')
    gcloud artifacts repositories add-iam-policy-binding ${REPOSITORY} \
        --project=${PROJECT_ID} \
        --location=${REGION} \
        --member="serviceAccount:${PROJECT_NUMBER}-compute@developer.gserviceaccount.com" \
        --role="roles/artifactregistry.reader"
    

Con la imagen en el repositorio, ahora puedes implementar el trabajo de ajuste en un clúster de GKE.

Implementa el trabajo de ajuste del LLM

En esta sección, se muestra cómo implementar el trabajo de ajuste de LLM en tu clúster de GKE.

  1. Abre el manifiesto training_singlehost.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "355120"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. Aplica el manifiesto

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

GKE crea un trabajo que inicia un Pod en un nodo de TPU Trillium (v6e). Este Pod ejecuta la secuencia de comandos de ajuste fino de Python, que accede a los datos de ajuste fino desde el bucket de Cloud Storage especificado que se encuentra montado en la ruta de acceso /data con Cloud Storage FUSE. Luego, la secuencia de comandos ajusta el modelo de Gemma.

Supervisa el trabajo de entrenamiento

En esta sección, supervisarás el progreso del trabajo de ajuste y su rendimiento.

Cómo ver el progreso del ajuste

  1. Enumera los Pods:

    # Find the Pods
    kubectl get pods
    
  2. Sigue el resultado del registro:

    kubectl logs -f pods/POD_NAME
    

    Reemplaza POD_NAME por el nombre del Pod.

    El resultado es similar a lo siguiente:

    Global device count: 1
    Batch size: 128, Max length: 256, Data count: 96000
    I1028 00:12:55.925999 1387 google_auth_provider.cc:181] Running on GCE, using service account ...
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    Amy lived in a small house. The house was in a big field. Amy liked to play in the big field. She
    Start training model
    Loss after batch 0: 10.25
    Loss after batch 10: 4.3125
    .
    .
    .
    Loss after batch 740: 1.41406
    Completed training model. Total time for training 294.6791355609894 seconds
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    She loved to play with her toys. One day, Amy's mom told her that she had to go to the store to
    
  3. Analiza el resultado:

    • La línea Global device count: 1 indica los núcleos de TPU que se usaron.
    • El modelo genera texto razonable antes de esta ejecución de ajuste porque se carga desde un punto de control ya entrenado.
    • El resultado generado después del ajuste fino muestra más similitud con el comienzo de un cuento, lo que indica que el modelo está aprendiendo del nuevo conjunto de datos.
    • El ajuste en el conjunto de datos completo debería producir resultados aún más refinados.

Observa las métricas

Consulta el rendimiento del trabajo de ajuste verificando las métricas de la TPU y la CPU. Para ver las métricas de observabilidad de tu clúster, sigue los pasos que se indican en Visualiza las métricas de observabilidad de clústeres y cargas de trabajo.

Configuraciones alternativas de ajuste

En esta sección, se describen configuraciones alternativas para tu carga de trabajo de ajuste.

Selección del modelo

En este instructivo, se usó el modelo Gemma3_270M, que es un modelo pequeño que cabe en un grupo de nodos de TPU Trillium (v6e) de host único. Para los modelos más grandes que requieren más memoria y capacidad de procesamiento para el ajuste, puedes usar configuraciones de grupos de nodos de varios hosts o de varias porciones.

Para obtener una lista completa de los modelos disponibles, consulta la documentación de Gemma.

Configuraciones de grupos de nodos

En este instructivo, se usó un grupo de nodos de host único. También puedes crear grupos de nodos de porciones de TPU de varios hosts o grupos de nodos de multislice, según tus necesidades.

En las siguientes pestañas, se muestra cómo crear grupos de nodos de varios hosts y de varias porciones:

Varios hosts

  1. En Cloud Shell, ejecuta el siguiente comando:

    gcloud container node-pools create jax-tpu-multihost1 \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-4t \
        --num-nodes=2 \
        --tpu-topology=2x4 \
        --location=${REGION} \
        --node-locations=${ZONE}
    

    GKE crea un grupo de nodos TPU Trillium con una topología 2x4 y dos nodos.

  2. Abre la definición del trabajo training_multihost_jobset.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multihost
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                    cloud.google.com/gke-nodepool: jax-tpu-multihost1
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI} 
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
    
  3. Implementa el trabajo de ajuste:

    envsubst < training_multihost_jobset.yaml | kubectl apply -f -
    

Porciones múltiples

  1. En Cloud Shell, ejecuta el siguiente comando:

    gcloud container node-pools create jax-tpu-multihost1 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    
    gcloud container node-pools create jax-tpu-multihost2 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    

    GKE crea dos grupos de nodos TPU Trillium. Cada grupo de nodos tiene una topología 2x4 y dos nodos.

  2. Abre la definición del trabajo training_multislice_jobset.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multislice
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI}
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:100"
    
  3. Implementa el trabajo de ajuste:

    envsubst < training_multislice_jobset.yaml | kubectl apply -f -
    

Análisis y optimización del rendimiento

Para analizar y optimizar el rendimiento del ajuste de tu modelo de aprendizaje automático, puedes usar XProf. XProf es un conjunto de herramientas que genera perfiles y analiza las cargas de trabajo de AA creadas con JAX, TensorFlow o PyTorch/XLA. Al mostrar los registros de ejecución, el uso de memoria y otros datos, XProf te permite ajustar tus modelos y la configuración del entrenamiento para lograr una mayor eficiencia y un entrenamiento más rápido.

Para analizar el rendimiento de tu carga de trabajo de ajuste con XProf, completa los siguientes pasos en esta sección:

  • Instala el paquete xprof. Modifica tu secuencia de comandos de entrenamiento para iniciar el servidor de XProf.
  • Modifica el manifiesto de tu trabajo de Kubernetes para incluir una activación de volumen para los registros de XProf.
  • Otorga permisos a la cuenta de servicio para escribir registros de XProf en un bucket de Cloud Storage.
  • Ejecuta XProf dentro de tu Pod y configura la redirección de puertos para acceder al panel de XProf.

Instala el paquete de XProf

  1. Navega al directorio que contiene las muestras de XProf:

      cd ai-ml/llm-training-jax-tpu-gemma3/xprof-enabled
    
  2. Compila la imagen de Docker y envíala a un repositorio de imágenes:

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training-xp"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    Reemplaza REPOSITORY_NAME por el nombre de tu repositorio de Artifact Registry.

  3. Ejecuta la secuencia de comandos Dockerfile:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain equinox
    RUN pip install xprof
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    Este Dockerfile instala las dependencias de XProf.

Copia tu secuencia de comandos de ajuste fino en el contenedor

En esta sección, crearás y aplicarás un manifiesto de trabajo de Kubernetes que incluya los activadores de volúmenes necesarios para los registros de XProf.

  1. Abre la definición del trabajo training_singlehost.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "851200"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
            - name: gcs-fuse-csi-ephemeral2
              mountPath: /xprof
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
          - name: gcs-fuse-csi-ephemeral2
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${XPROF_GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. Aplica el manifiesto

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

Otorga permisos a la cuenta de servicio para escribir registros de XProf

  1. Para permitir que la cuenta de servicio escriba y lea, agrega el rol "roles/storage.objectUser":

    export GSA_NAME="GSA_NAME" # Same as used in initial setup
    
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    
    # Cloud Storage Bucket details
    export XPROF_GCS_BUCKET_NAME="XPROF_GCS_BUCKET_NAME"
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    
    gcloud storage buckets add-iam-policy-binding "gs://${XPROF_GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectUser" \
        --project="${PROJECT_ID}"
    

    Reemplaza lo siguiente:

    • GSA_NAME: Es el nombre de la cuenta de servicio de Google a la que se le otorgará el rol.
    • XPROF_GCS_BUCKET_NAME: Es el nombre del bucket al que se le otorgará el rol.
  2. Ejecuta XProf dentro de tu Pod:

    kubectl exec POD_NAME -c training-container -it -- bash # exec into the container
    xprof --port 9001 --logdir /xprof # start xprof
    

    Reemplaza POD_NAME por el nombre del Pod.

Accede al panel de XProf

  1. Configura la redirección de puertos al servidor de XProf en el Pod:

    kubectl port-forward POD_NAME 9001:9001
    
  2. En la barra de direcciones del navegador, ingresa lo siguiente:

    http://localhost:9001/
    

    Se abrirá el Visor de registros de XProf.

  3. En la ventana de TensorBoard, haz clic en Capture profile.

  4. En el campo URLs del servicio de perfil o nombre de TPU, ingresa localhost:9002.

  5. Para capturar más detalles, en el nivel de registro de Host Trace (TraceMe), selecciona verbose y habilita el registro de seguimiento de Python.

  6. Para ver el panel, haz clic en Capturar.

    TensorBoard captura el perfil y te permite analizar el rendimiento de la secuencia de comandos de entrenamiento. En el gráfico, se muestra la línea de tiempo de ejecución de los perfiles de rendimiento de la CPU y la TPU:

Un ejemplo del lector de registros de XProf que muestra un gráfico de matriz de rendimiento

Para obtener más opciones de generación de perfiles para analizar el rendimiento de tu carga de trabajo de entrenamiento, consulta la documentación de JAX sobre Generación de perfiles de procesamiento.

Ajuste en entornos de producción

En este instructivo, se mostró cómo probar el entrenamiento basado en JAX en un entorno distribuido. Para optimizar el ajuste del LLM en producción, usa la biblioteca Maxtext. Si te interesan los modelos de difusión, usa las implementaciones de Maxdiffusion.

Para las cargas de trabajo de entrenamiento o ajuste de larga duración en producción, configura puntos de control de la carga de trabajo para minimizar la pérdida de progreso durante una falla. Si deseas obtener más información para configurar el registro de puntos de control de varios niveles, consulta Cómo entrenar modelos de aprendizaje automático a gran escala en GKE con el registro de puntos de control de varios niveles.

Realiza una limpieza

Para evitar que se apliquen cargos a tu cuenta de Google Cloud por los recursos usados en este instructivo, borra el proyecto que contiene los recursos o conserva el proyecto y borra los recursos individuales.

Borra los recursos individuales

Para evitar que se apliquen cargos a tu Google Cloud cuenta por los recursos usados en este instructivo, borra el proyecto que contiene los recursos o conserva el proyecto y borra los recursos individuales ejecutando los siguientes comandos:

  1. Borra los recursos que creaste en este instructivo:

    gcloud container clusters delete ${CLUSTER_NAME} --location=${REGION}
    gcloud storage rm --recursive gs://${GCS_BUCKET_NAME}
    gcloud artifacts docker images delete ${IMAGE_URI} --delete-tags
    
  2. Si no necesitas los datos que genera XProf, quita el bucket de Cloud Storage que usa XProf:

    gcloud storage rm --recursive gs://${XPROF_GCS_BUCKET_NAME}
    

¿Qué sigue?