Perfeziona un LLM utilizzando le TPU su GKE con JAX

Questo tutorial mostra come eseguire il fine tuning di un modello linguistico di grandi dimensioni (LLM) utilizzando le unità di elaborazione tensoriale (TPU) su Google Kubernetes Engine (GKE) con JAX. Il fine-tuning ti consente di adattare un foundation model come Gemma 3 a un'attività o un dominio specifico. Questo processo migliora la precisione e l'accuratezza del modello aggiornandone i parametri con il tuo set di dati specializzato.

Questa guida è un buon punto di partenza se hai bisogno del controllo granulare, della personalizzazione, della scalabilità, della resilienza, della portabilità e dell'economicità di Kubernetes gestito durante la messa a punto dei tuoi workload AI/ML.

Sfondo

Utilizzando le TPU su GKE con Jax per il fine-tuning di un LLM, puoi creare una soluzione di fine-tuning affidabile e pronta per la produzione con tutti i vantaggi di Kubernetes gestito.

Gemma

Gemma è un insieme di modelli multimodali di AI/ML generativa, leggeri e disponibili apertamente, rilasciati con una licenza aperta. Questi modelli di AI sono disponibili per l'esecuzione in applicazioni, hardware, dispositivi mobili o servizi ospitati. Gemma 3 introduce la multimodalità e supporta l'input di visione-linguaggio e gli output di testo. Gestisce finestre contestuali fino a 128.000 token e supporta oltre 140 lingue. Gemma 3 offre anche funzionalità migliorate per matematica, ragionamento e chat, tra cui output strutturati e chiamate di funzioni.

Puoi utilizzare i modelli Gemma per la generazione di testo oppure puoi anche ottimizzarli per attività specializzate.

Per saperne di più, consulta la documentazione di Gemma.

TPU

Le TPU sono circuiti integrati specifici per le applicazioni (ASIC) che Google ha sviluppato su misura per accelerare i modelli di machine learning e AI creati utilizzando framework come TensorFlow, PyTorch e JAX.

Prima di utilizzare le TPU in GKE, ti consigliamo di completare il seguente percorso di apprendimento:

  1. Scopri di più sulla disponibilità della versione TPU attuale con l'architettura di sistema di Cloud TPU.
  2. Scopri di più sulle TPU in GKE.

JAX

JAX è un framework di machine learning ad alte prestazioni progettato per essere utilizzato con TPU e GPU. JAX fornisce un'API per la creazione e l'addestramento di modelli di machine learning.

Per saperne di più, consulta il repository JAX.

Obiettivi

Questo tutorial illustra i seguenti passaggi:

  1. Crea un cluster GKE Autopilot o Standard con la topologia TPU consigliata, in base alle caratteristiche del modello. Durante questo tutorial, esegui il fine tuning sui pool di nodi single-host.
  2. Aggiungi dati a un bucket Cloud Storage e montalo nel container tramite Cloud Storage FUSE.
  3. Esegui il deployment del job di perfezionamento dell'LLM su GKE.
  4. Monitora il job di perfezionamento e visualizza i log.

Prima di iniziare

  • Accedi al tuo account Google Cloud . Se non conosci Google Cloud, crea un account per valutare le prestazioni dei nostri prodotti in scenari reali. I nuovi clienti ricevono anche 300 $di crediti senza costi per l'esecuzione, il test e il deployment dei workload.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • Assicurati di disporre dei seguenti ruoli nel progetto: roles/container.admin,roles/iam.serviceAccountAdmin,roles/storage.admin

    Controlla i ruoli

    1. Nella console Google Cloud vai alla pagina IAM.

      Vai a IAM
    2. Seleziona il progetto.
    3. Nella colonna Entità, trova tutte le righe che identificano te o un gruppo di cui fai parte. Per scoprire a quali gruppi appartieni, contatta il tuo amministratore.

    4. Per tutte le righe che ti specificano o ti includono, controlla la colonna Ruolo per verificare se l'elenco dei ruoli include i ruoli richiesti.

    Concedi i ruoli

    1. Nella console Google Cloud vai alla pagina IAM.

      Vai a IAM
    2. Seleziona il progetto.
    3. Fai clic su Concedi l'accesso.
    4. Nel campo Nuove entità, inserisci il tuo identificatore dell'utente. In genere si tratta dell'indirizzo email di un Account Google.

    5. Fai clic su Seleziona un ruolo, quindi cerca il ruolo.
    6. Per concedere altri ruoli, fai clic su Aggiungi un altro ruolo e aggiungi ogni ruolo successivo.
    7. Fai clic su Salva.
  • Assicurati di disporre di una quota sufficiente per 16 chip TPU Trillium (v6e). In questo tutorial utilizzi una configurazione pool di nodi che richiede 16 chip e istanze on demand.
  • Assicurati di avere un repository Docker. Se non ne hai uno, crea un repository standard in Artifact Registry.

Prepara l'ambiente

In questo tutorial utilizzerai Cloud Shell per gestire le risorse ospitate su Google Cloud. Cloud Shell è preinstallato con il software necessario per questo tutorial, tra cui kubectl e Google Cloud CLI.

Per configurare l'ambiente con Cloud Shell:

  1. Nella console Google Cloud , avvia una sessione Cloud Shell e fai clic su Icona di attivazione di Cloud Shell Attiva Cloud Shell. Questa azione avvia una sessione nel riquadro inferiore della console Google Cloud .

  2. Imposta le variabili di ambiente predefinite:

    gcloud config set project PROJECT_ID
    gcloud config set billing/quota_project PROJECT_ID
    export PROJECT_ID=$(gcloud config get project)
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=CONTROL_PLANE_LOCATION
    export ZONE=ZONE
    export GCS_BUCKET_NAME=BUCKET_NAME
    

    Sostituisci i seguenti valori:

    • PROJECT_ID: il tuo Google Cloud ID progetto.
    • CLUSTER_NAME: il nome del tuo cluster GKE.
    • CONTROL_PLANE_LOCATION: la regione di Compute Engine in cui si trovano il cluster GKE e i nodi TPU. La regione deve contenere zone in cui sono disponibili i tipi di macchina TPU Trillium (v6e).
    • ZONE: una zona all'interno della regione CONTROL_PLANE_LOCATION selezionata in cui sono disponibili i tipi di macchine TPU Trillium (v6e). Per elencare le zone in cui sono disponibili le TPU TPU Trillium (v6e), esegui il comando seguente:

        gcloud compute accelerator-types list --filter="name~ct6e" --format="value(zone)"
      
    • BUCKET_NAME: il nome del bucket Cloud Storage che contiene i dati di addestramento.

  3. Clona il repository di esempio:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    cd kubernetes-engine-samples
    
  4. Vai alla directory di lavoro:

    cd ai-ml/llm-training-jax-tpu-gemma3
    

Creare e configurare risorse Google Cloud

In questa sezione, creerai e configurerai le risorse Google Cloud .

Crea un cluster GKE

Puoi eseguire il fine tuning di un LLM sulle TPU in un cluster GKE Autopilot o Standard. Ti consigliamo di utilizzare un cluster Autopilot per un'esperienza Kubernetes completamente gestita. Per scegliere la modalità operativa GKE più adatta ai tuoi workload, consulta Scegliere una modalità operativa GKE.

Autopilot

Crea un cluster GKE Autopilot che utilizzi Workload Identity Federation for GKE e in cui sia abilitato Cloud Storage FUSE.

gcloud container clusters create-auto ${CLUSTER_NAME} \
    --location=${REGION}

La creazione del cluster potrebbe richiedere diversi minuti.

Standard

  1. Crea un cluster GKE Standard regionale che utilizzi Workload Identity Federation for GKE e in cui sia abilitato Cloud Storage FUSE.

    gcloud container clusters create ${CLUSTER_NAME} \
        --enable-ip-alias \
        --addons GcsFuseCsiDriver \
        --machine-type=n2-standard-4 \
        --num-nodes=2 \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --location=${REGION}
    

    La creazione del cluster potrebbe richiedere diversi minuti.

  2. Crea un pool di nodi a singolo host:

    gcloud container node-pools create jax-tpu-nodepool \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-1t \
        --num-nodes=1 \
        --location=${REGION} \
        --node-locations=${ZONE} \
        --workload-metadata=GKE_METADATA
    

GKE crea un pool di nodi TPU Trillium con una topologia 1x1 e un nodo. Il flag --workload-metadata=GKE_METADATA configura il pool di nodi in modo che utilizzi il server metadati GKE.

Installare JobSet

  1. Configura kubectl per comunicare con il cluster:

    gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION}
    
  2. Installa l'ultima versione rilasciata di JobSet:

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    Sostituisci JOBSET_VERSION con l'ultima versione rilasciata di JobSet. Ad esempio: v0.11.0.

  3. Verifica l'installazione di JobSet:

    kubectl get pods -n jobset-system
    

    L'output è simile al seguente:

    NAME                                         READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-6c56668494-l4dhc   1/1     Running   0          4m45s
    

    Potresti dover aggiungere altri nodi se JobSet è in attesa di risorse.

Configurare Cloud Storage FUSE

Per ottimizzare l'LLM, devi fornire dati di addestramento. In questo tutorial, utilizzerai il set di dati TinyStories di Hugging Face. Questo set di dati contiene racconti brevi generati sinteticamente da GPT-3.5 e GPT-4, che utilizzano un vocabolario limitato.

Questa sezione descrive i passaggi per configurare Cloud Storage FUSE in modo da leggere i dati da un bucket Cloud Storage.

  1. Scarica il set di dati:

    wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
    
  2. Carica i dati in un nuovo bucket Cloud Storage:

    gcloud storage buckets create gs://${GCS_BUCKET_NAME} \
        --location=${REGION} \
        --enable-hierarchical-namespace \
        --uniform-bucket-level-access
    gcloud storage cp TinyStories-train.txt gs://${GCS_BUCKET_NAME}
    
  3. Per consentire al tuo workload di leggere i dati tramite Cloud Storage FUSE, crea un account di servizio Kubernetes (KSA) e aggiungi le autorizzazioni richieste. Esegui lo script permissionsetup.sh:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    #!/bin/bash
    
    # --- Configuration Variables ---
    # Kubernetes Service Account details
    export KSA_NAME="jaxserviceaccout"
    export NAMESPACE="default"
    
    # Google Cloud IAM Service Account details
    export GSA_NAME="<GSA_NAME>"
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    export  GSA_DESCRIPTION="GKE Service Account to read GCS bucket for ${KSA_NAME}"
    
    # GCS Bucket details
    export GCS_BUCKET_NAME="<GCS_BUCKET_NAME>" # <--- IMPORTANT: Update this to your bucket name
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    export WI_MEMBER="serviceAccount:${PROJECT_ID}.svc.id.goog[${NAMESPACE}/${KSA_NAME}]"
    
    # --- Check if PROJECT_ID is set ---
    if [ -z "${PROJECT_ID}" ]; then
      echo "Error: PROJECT_ID is not set. Please set it using 'gcloud config set project YOUR_PROJECT_ID'"
      exit 1
    fi
    
    echo "--- Configuration ---"
    echo "KSA_NAME:      ${KSA_NAME}"
    echo "NAMESPACE:     ${NAMESPACE}"
    echo "GSA_NAME:      ${GSA_NAME}"
    echo "PROJECT_ID:    ${PROJECT_ID}"
    echo "GSA_EMAIL:     ${GSA_EMAIL}"
    echo "GCS_BUCKET_NAME:   ${GCS_BUCKET_NAME}"
    echo "WI_MEMBER:     ${WI_MEMBER}"
    echo "--------------------"
    read -p "Press enter to continue..."
    
    # --- Command Execution ---
    
    echo "[1/5] Creating Google Cloud IAM Service Account (GSA): ${GSA_NAME}"
    gcloud iam service-accounts create "${GSA_NAME}" \
        --project="${PROJECT_ID}" \
        --description="${GSA_DESCRIPTION}" \
        --display-name="${GSA_NAME}"
    
    echo "[2/5] Granting GSA '${GSA_EMAIL}' read access (roles/storage.objectViewer) to bucket 'gs://${GCS_BUCKET_NAME}'"
    gcloud storage buckets add-iam-policy-binding "gs://${GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectViewer" \
        --project="${PROJECT_ID}"
    
    echo "[3/5] Creating Kubernetes Service Account (KSA): ${KSA_NAME} in namespace ${NAMESPACE}"
    kubectl create serviceaccount "${KSA_NAME}" --namespace "${NAMESPACE}"
    
    echo "[4/5] Allowing KSA to impersonate GSA (Workload Identity Binding): ${GSA_EMAIL}"
    gcloud iam service-accounts add-iam-policy-binding "${GSA_EMAIL}" \
        --role roles/iam.workloadIdentityUser \
        --member "${WI_MEMBER}" \
        --project="${PROJECT_ID}"
    
    echo "[5/5] Annotating KSA '${KSA_NAME}' to link with GSA '${GSA_EMAIL}'"
    kubectl annotate serviceaccount "${KSA_NAME}" \
        --namespace "${NAMESPACE}" \
        iam.gke.io/gcp-service-account="${GSA_EMAIL}"
    
    echo "--- Setup Complete ---"
    echo "Pods in namespace '${NAMESPACE}' using serviceAccount '${KSA_NAME}' can now authenticate as '${GSA_EMAIL}' and have read access to 'gs://${GCS_BUCKET_NAME}'."
    

    Dopo aver eseguito questo script, le seguenti risorse vengono configurate nel tuo progettoGoogle Cloud e nel cluster GKE:

    • Nel tuo progetto viene creato un nuovo account di servizio IAM denominato gcs-fuse-sa.
    • Al service account Google (GSA) creato Google Cloud (gcs-fuse-sa) viene concesso il ruolo roles/storage.objectViewer sul bucket Cloud Storage specificato da ${GCS_BUCKET_NAME}. Questa autorizzazione consente al GSA di leggere gli oggetti dal bucket.
    • Nel tuo cluster GKE viene creato un nuovo KSA denominato jaxserviceaccount nello spazio dei nomi default.
    • Il criterio IAM del service account Google è aggiornato per concedere il ruolo roles/iam.workloadIdentityUser al service account Kubernetes. Questa autorizzazione consente al KSA di rappresentare il GSA.
    • Il KSA è annotato per essere collegato al GSA. Questa annotazione indica a GKE quale account di servizio Google deve essere rappresentato dal KSA utilizzando Workload Identity.

      Qualsiasi pod in esecuzione nello spazio dei nomi default del cluster GKE che utilizza il account di servizio jaxserviceaccount ora potrà autenticarsi come GSA gcs-fuse-sa. Questi pod avranno accesso in lettura agli oggetti archiviati nel bucket gs://${GCS_BUCKET_NAME}, il che è essenziale per consentire al job di perfezionamento di accedere al set di dati utilizzando Cloud Storage FUSE.

Crea lo script di ottimizzazione

In questa sezione, esplorerai lo script di addestramento che esegue un'operazione di fine tuning su un modello Gemma 3. Questo script utilizza Gemma3Tokenizer.

Esamina il seguente script di Gemma3LLMTrain.py ottimizzazione:

# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import grain.python as pygrain
import jax
import jax.numpy as jnp
import optax
import pandas as pd
import time
import argparse

from dataclasses import dataclass
from functools import partial
from gemma import gm
from flax.training import train_state
from jax.sharding import Mesh, PartitionSpec, NamedSharding

jax.distributed.initialize()
print("Global device count:", jax.device_count())
print("jax version:", jax.__version__)

tokenizer = gm.text.Gemma3Tokenizer()
num_epochs = 1
learning_rate = 2e-5

@dataclass
class TextDataset:
    data: list
    maxlen: int
    tokenizer: gm.text.Gemma3Tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        encoding = self.tokenizer.encode(self.data[idx])[:self.maxlen]  # Tokenize and truncate
        return encoding + [0] * (self.maxlen - len(encoding))  # Pad to maxlen

def load_and_preprocess_data(file_path, batch_size, maxlen, datacount, tokenizer):

    with open(file_path, 'r') as f:
      text = f.read()

    stories = text.split('<|endoftext|>')
    stories = [story for story in stories if story.strip()][:datacount]
    df = pd.DataFrame({'text': stories})
    data = df['text'].dropna().tolist()
    dataset = TextDataset(data, maxlen, tokenizer)

    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=False,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )

    dataloader = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )

    return dataloader

def generate_text(model, params, tokenizer, prompt):
    sampler = gm.text.Sampler(
        model=model,
        params=params,
        tokenizer=tokenizer,
    )
    print("Generating response for: " + prompt)
    out = sampler.sample(prompt, max_new_tokens=32)
    print("Reponse: \n" + out + "\n")
    return out

prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))

@partial(jax.jit, donate_argnums=(0,))
def train_step(state, batch):
    """Performs one supervised fine-tuning step."""

    def loss_fn(params):
        # Run the forward pass. The model returns logits.
        logits = state.apply_fn({'params': params}, batch[0]).logits

        # Calculate the cross-entropy loss.
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch[1]
        ).mean()

        return loss

    # Compute gradients
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)

    # Update the model state
    state = state.apply_gradients(grads=grads)

    metrics = {'loss': loss}
    return state, metrics

def train_model(state, text_dl, num_epochs, sharding):
    batchCount = 0
    start_time = time.time()
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch in text_dl:
            if len(batch) % len(jax.devices()) != 0:
              continue  # skip the remaining elements
            input_batch = jnp.array(jnp.array(batch).T)
            target_batch = prep_target_batch(input_batch)
            state, metrics = train_step(state, jax.device_put((input_batch, target_batch), sharding))

            if batchCount % 10 == 0:
                print(f"Loss after batch {batchCount}: {metrics['loss']}")
            batchCount += 1

    end_time = time.time()
    print(f"Completed training model. Total time for training {end_time - start_time} seconds \n")
    return state

def run_training(maxlen, batch_size, datacount):
    print(f"Batch size: {batch_size}, Max length: {maxlen}, Data count: {datacount}")
    #Load the training data
    tiny_stories_dl = load_and_preprocess_data('/data/TinyStories-train.txt', batch_size, maxlen, datacount, tokenizer)
    # Get the Gemma3 model
    model = gm.nn.Gemma3_270M()
    # Load the pretrained parameters
    params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_270M_PT)
    # Create an optimizer
    optimizer = optax.adamw(learning_rate=learning_rate)
    # Define sharding for data parallel training
    mesh = Mesh(jax.devices(), ('batch',))
    sharding = NamedSharding(mesh, PartitionSpec('batch', None))

    # Testing out current state of the model
    test_prompt = "Once upon a time, there was a girl named Amy."
    generate_text(model, params, tokenizer, test_prompt)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer
    )

    # Perform post training
    print("Start training model")
    state = train_model(state, tiny_stories_dl, num_epochs, sharding)

    # Final text generation
    generate_text(model, state.params, tokenizer, test_prompt)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Gemma model with custom parameters.')
    parser.add_argument('--maxlen', type=int, default=256, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--datacount', type=int, default=296000, help='Number of data samples to use')
    args = parser.parse_args()

    run_training(maxlen=args.maxlen, batch_size=args.batch_size, datacount=args.datacount)

In questo script, vale quanto segue:

  • Un Gemma3Tokenizer converte i dati di testo in token che il modello può elaborare.
  • La funzione load_and_preprocess_data legge i dati di addestramento da un file, li suddivide in singole storie e utilizza il tokenizer per convertire il testo in sequenze di token con padding.
  • La funzione generate_text accetta il modello, i relativi parametri e un prompt per generare il testo.
  • La funzione train_step definisce una singola iterazione di addestramento che include il passaggio in avanti, il calcolo della perdita (utilizzando l'entropia incrociata), il calcolo del gradiente e gli aggiornamenti dei parametri.
  • La funzione train_model scorre il set di dati per un numero specificato di epoche, che chiama la funzione train_step per ogni batch.
  • La funzione run_training orchestra l'intero processo di caricamento dei dati, inizializzazione del modello Gemma 3 (Gemma3_270M) e dell'ottimizzatore, caricamento dei parametri preaddestrati, configurazione dello sharding dei dati per l'elaborazione parallela, esecuzione di una generazione di test, esecuzione del ciclo di addestramento ed esecuzione di una generazione di testo finale per dimostrare l'effetto del fine tuning.
  • Lo script utilizza la libreria argparse per accettare gli argomenti della riga di comando per i parametri maxlen, batch_size e datacount.

Ora che hai esplorato lo script di perfezionamento, inseriscilo in un container per eseguirlo su GKE.

Containerizza lo script di ottimizzazione

Prima di eseguire lo script di perfezionamento in un cluster GKE, devi inserirlo in un container. Questo tutorial utilizza un'immagine AI JAX come immagine di base.

  1. Apri Dockerfile nella stessa directory del file Gemma3LLMTrain.py:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    Questo Dockerfile installa le dipendenze necessarie e copia il file Gemma3LLMTrain.py nel container.

  2. Crea l'immagine Docker ed eseguine il push in un repository di immagini:

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    Sostituisci REPOSITORY_NAME con il nome del tuo repository Artifact Registry.

  3. Aggiungi associazioni di ruoli al account di servizio:

    export PROJECT_NUMBER=$(gcloud projects describe $PROJECT_ID --format 'get(projectNumber)')
    gcloud artifacts repositories add-iam-policy-binding ${REPOSITORY} \
        --project=${PROJECT_ID} \
        --location=${REGION} \
        --member="serviceAccount:${PROJECT_NUMBER}-compute@developer.gserviceaccount.com" \
        --role="roles/artifactregistry.reader"
    

Con l'immagine nel repository, ora puoi eseguire il deployment del job di perfezionamento in un cluster GKE.

Esegui il deployment del job di ottimizzazione dell'LLM

Questa sezione mostra come eseguire il deployment del job di perfezionamento dell'LLM nel cluster GKE.

  1. Apri il file manifest training_singlehost.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "355120"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. Applica il manifest:

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

GKE crea un job che avvia un pod su un nodo TPU Trillium (v6e). Questo pod esegue lo script di perfezionamento Python, che accede ai dati di perfezionamento dal bucket Cloud Storage specificato montato nel percorso /data utilizzando Cloud Storage FUSE. Lo script esegue quindi il fine tuning del modello Gemma.

Monitora il job di addestramento

In questa sezione, monitori l'avanzamento del job di perfezionamento e le relative prestazioni.

Visualizzare l'avanzamento dell'ottimizzazione

  1. Elenca i pod:

    # Find the Pods
    kubectl get pods
    
  2. Segui l'output del log:

    kubectl logs -f pods/POD_NAME
    

    Sostituisci POD_NAME con il nome del tuo pod.

    L'output è simile al seguente:

    Global device count: 1
    Batch size: 128, Max length: 256, Data count: 96000
    I1028 00:12:55.925999 1387 google_auth_provider.cc:181] Running on GCE, using service account ...
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    Amy lived in a small house. The house was in a big field. Amy liked to play in the big field. She
    Start training model
    Loss after batch 0: 10.25
    Loss after batch 10: 4.3125
    .
    .
    .
    Loss after batch 740: 1.41406
    Completed training model. Total time for training 294.6791355609894 seconds
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    She loved to play with her toys. One day, Amy's mom told her that she had to go to the store to
    
  3. Analizza l'output:

    • La riga Global device count: 1 indica i core TPU utilizzati.
    • Il modello genera un testo ragionevole prima di questa esecuzione di ottimizzazione perché carica da un checkpoint preaddestrato.
    • L'output generato dopo l'ottimizzazione mostra una maggiore somiglianza con l'inizio di un racconto, il che indica che il modello sta apprendendo dal nuovo set di dati.
    • Il perfezionamento sull'intero set di dati dovrebbe produrre risultati ancora più raffinati.

Osservare le metriche

Visualizza il rendimento del job di perfezionamento controllando le metriche di TPU e CPU. Per visualizzare le metriche di osservabilità per il cluster, segui i passaggi descritti in Visualizzare le metriche di osservabilità di cluster e carichi di lavoro.

Configurazioni di ottimizzazione alternative

Questa sezione descrive configurazioni alternative per il tuo workload di fine tuning.

Selezione del modello

Questo tutorial ha utilizzato il modello Gemma3_270M, un modello piccolo che rientra in un pool di nodi TPU Trillium (v6e) a host singolo. Per i modelli più grandi che richiedono più memoria e risorse di calcolo per il perfezionamento, puoi utilizzare configurazioni di pool di nodi multihost o multislice.

Per un elenco completo dei modelli disponibili, consulta la documentazione di Gemma.

Configurazioni del node pool

Questo tutorial ha utilizzato un pool di nodi single-host. Puoi anche creare node pool TPU multi-host o node pool multislice, a seconda delle tue esigenze.

Le seguenti schede mostrano come creare pool di nodi multihost e multislices:

Multi-host

  1. In Cloud Shell, esegui questo comando:

    gcloud container node-pools create jax-tpu-multihost1 \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-4t \
        --num-nodes=2 \
        --tpu-topology=2x4 \
        --location=${REGION} \
        --node-locations=${ZONE}
    

    GKE crea un pool di nodi TPU Trillium con una topologia 2x4 e due nodi.

  2. Apri la definizione del job training_multihost_jobset.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multihost
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                    cloud.google.com/gke-nodepool: jax-tpu-multihost1
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI} 
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
    
  3. Esegui il deployment del job di ottimizzazione:

    envsubst < training_multihost_jobset.yaml | kubectl apply -f -
    

Multislice

  1. In Cloud Shell, esegui questo comando:

    gcloud container node-pools create jax-tpu-multihost1 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    
    gcloud container node-pools create jax-tpu-multihost2 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    

    GKE crea due node pool TPU Trillium. Ogni node pool ha una topologia 2x4 e due nodi.

  2. Apri la definizione del job training_multislice_jobset.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multislice
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI}
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:100"
    
  3. Esegui il deployment del job di ottimizzazione:

    envsubst < training_multislice_jobset.yaml | kubectl apply -f -
    

Analisi e ottimizzazione del rendimento

Per analizzare e ottimizzare il rendimento della messa a punto del machine learning, puoi utilizzare XProf. XProf è una suite di strumenti che profila e ispeziona i carichi di lavoro ML creati con JAX, TensorFlow o PyTorch/XLA. Mostrando le tracce di esecuzione, l'utilizzo della memoria e altri dati, XProf ti consente di ottimizzare i modelli e la configurazione dell'addestramento per una maggiore efficienza e un addestramento più rapido.

Per analizzare le prestazioni del carico di lavoro di fine tuning utilizzando XProf, completa i seguenti passaggi in questa sezione:

  • Installa il pacchetto xprof. Modifica lo script di addestramento per avviare il server XProf.
  • Modifica il manifest del job Kubernetes in modo da includere un montaggio del volume per i log XProf.
  • Concedi al account di servizio le autorizzazioni per scrivere i log XProf in un bucket Cloud Storage.
  • Esegui XProf all'interno del pod e configura l'inoltro delle porte per accedere alla dashboard XProf.

Installa il pacchetto XProf

  1. Vai alla directory che contiene gli esempi di XProf:

      cd ai-ml/llm-training-jax-tpu-gemma3/xprof-enabled
    
  2. Crea l'immagine Docker ed eseguine il push in un repository di immagini:

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training-xp"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    Sostituisci REPOSITORY_NAME con il nome del tuo repository Artifact Registry.

  3. Esegui lo script Dockerfile:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain equinox
    RUN pip install xprof
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    Questo Dockerfile installa le dipendenze di XProf.

Copia lo script di perfezionamento nel container.

In questa sezione, crea e applica un manifest di Kubernetes Job che includa i mount del volume necessari per i log XProf.

  1. Apri la definizione del job training_singlehost.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "851200"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
            - name: gcs-fuse-csi-ephemeral2
              mountPath: /xprof
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
          - name: gcs-fuse-csi-ephemeral2
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${XPROF_GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. Applica il manifest:

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

Concedi all'account di servizio le autorizzazioni per scrivere i log XProf

  1. Per consentire all'account di servizio di scrivere e leggere, aggiungi il ruolo "roles/storage.objectUser":

    export GSA_NAME="GSA_NAME" # Same as used in initial setup
    
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    
    # Cloud Storage Bucket details
    export XPROF_GCS_BUCKET_NAME="XPROF_GCS_BUCKET_NAME"
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    
    gcloud storage buckets add-iam-policy-binding "gs://${XPROF_GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectUser" \
        --project="${PROJECT_ID}"
    

    Sostituisci quanto segue:

    • GSA_NAME: il nome del service account Google a cui concedere il ruolo.
    • XPROF_GCS_BUCKET_NAME: il nome del bucket a cui concedere il ruolo.
  2. Esegui XProf all'interno del pod:

    kubectl exec POD_NAME -c training-container -it -- bash # exec into the container
    xprof --port 9001 --logdir /xprof # start xprof
    

    Sostituisci POD_NAME con il nome del tuo pod.

Accedere alla dashboard XProf

  1. Configura port forwarding al server XProf nel pod:

    kubectl port-forward POD_NAME 9001:9001
    
  2. Nella barra degli indirizzi del browser, inserisci quanto segue:

    http://localhost:9001/
    

    Si apre XProf Trace Viewer.

  3. Nella finestra TensorBoard, fai clic su Acquisisci profilo.

  4. Nel campo URL dei servizi del profilo o nome TPU, inserisci localhost:9002.

  5. Per acquisire maggiori dettagli, in Host Trace (TraceMe) Level, seleziona verbose e attiva la registrazione delle tracce Python.

  6. Per visualizzare la dashboard, fai clic su Acquisizione.

    TensorBoard acquisisce il profilo e ti consente di analizzare le prestazioni dello script di addestramento. Il grafico mostra la sequenza temporale di esecuzione per i profili di rendimento di TPU e CPU:

Esempio del visualizzatore di tracce XProf che mostra un grafico della matrice di rendimento

Per ulteriori opzioni di profilazione per analizzare le prestazioni del workload di addestramento, consulta la documentazione di JAX sulla profilazione del calcolo.

Ottimizzazione negli ambienti di produzione

Questo tutorial ti ha mostrato come testare l'addestramento basato su JAX in un ambiente distribuito. Per la messa a punto ottimizzata degli LLM in produzione, utilizza la libreria Maxtext. Se ti interessano i modelli di diffusione, utilizza le implementazioni Maxdiffusion.

Per i workload di addestramento o messa a punto di lunga durata in produzione, configura il checkpointing del workload per ridurre al minimo la perdita di avanzamento in caso di errore. Per saperne di più sulla configurazione del checkpointing multilivello, consulta Addestra modelli di machine learning su larga scala su GKE con il checkpointing multilivello.

Esegui la pulizia

Per evitare che al tuo account Google Cloud vengano addebitati costi relativi alle risorse utilizzate in questo tutorial, elimina il progetto che contiene le risorse oppure mantieni il progetto ed elimina le singole risorse.

Elimina le singole risorse

Per evitare che al tuo account Google Cloud vengano addebitati costi relativi alle risorse utilizzate in questo tutorial, elimina il progetto che contiene le risorse oppure mantieni il progetto ed elimina le singole risorse eseguendo i seguenti comandi:

  1. Elimina le risorse che hai creato in questo tutorial:

    gcloud container clusters delete ${CLUSTER_NAME} --location=${REGION}
    gcloud storage rm --recursive gs://${GCS_BUCKET_NAME}
    gcloud artifacts docker images delete ${IMAGE_URI} --delete-tags
    
  2. Se non hai bisogno dei dati generati da XProf, rimuovi il bucket Cloud Storage utilizzato da XProf:

    gcloud storage rm --recursive gs://${XPROF_GCS_BUCKET_NAME}
    

Passaggi successivi