LLM mithilfe von TPUs in GKE mit JAX optimieren

In dieser Anleitung wird beschrieben, wie Sie ein Large Language Model (LLM) mit Tensor Processing Units (TPUs) in Google Kubernetes Engine (GKE) mit JAX abstimmen. Mit der Feinabstimmung können Sie ein Foundation Model wie Gemma 3 an eine bestimmte Domain oder Aufgabe anpassen. Durch diesen Prozess werden die Präzision und Genauigkeit des Modells verbessert, indem seine Parameter mit Ihrem eigenen spezialisierten Dataset aktualisiert werden.

Dieser Leitfaden ist ein guter Ausgangspunkt, wenn Sie bei der Feinabstimmung Ihrer KI/ML-Arbeitslasten die detaillierte Kontrolle, Anpassung, Skalierbarkeit, Robustheit, Übertragbarkeit und Kosteneffizienz von verwaltetem Kubernetes benötigen.

Hintergrund

Wenn Sie TPUs in GKE mit Jax zum Feinabstimmen eines LLM verwenden, können Sie eine robuste, produktionsreife Lösung zum Feinabstimmen mit allen Vorteilen von verwaltetem Kubernetes erstellen.

Gemma

Gemma ist eine Reihe offen verfügbarer, einfacher und auf generativer KI/ML basierender multimodaler Modelle, die unter einer offenen Lizenz veröffentlicht wurden. Diese KI-Modelle können in Ihren Anwendungen, Geräten, Mobilgeräten oder gehosteten Diensten ausgeführt werden. Gemma 3 führt Multimodalität ein und unterstützt Vision-Language-Eingaben und Textausgaben. Das Modell kann Kontextfenster mit bis zu 128.000 Tokens verarbeiten und unterstützt über 140 Sprachen. Gemma 3 bietet auch verbesserte Mathematik-, Schlussfolgerungs- und Chatfunktionen, einschließlich strukturierter Ausgaben und Funktionsaufrufen.

Sie können die Gemma-Modelle zur Textgenerierung verwenden. Sie können diese Modelle jedoch auch für spezielle Aufgaben optimieren.

Weitere Informationen finden Sie in der Gemma-Dokumentation.

TPUs

TPUs sind anwendungsspezifische integrierte Schaltkreise (Application-Specific Integrated Circuits, ASICs), die von Google speziell entwickelt wurden, um das maschinelle Lernen und die KI-Modelle zu beschleunigen, die mit Frameworks wie TensorFlow, PyTorch und JAX erstellt wurden.

Bevor Sie TPUs in GKE verwenden, sollten Sie den folgenden Lernpfad durcharbeiten:

  1. Informationen zur aktuellen Verfügbarkeit von TPU-Versionen finden Sie unter Cloud TPU-Systemarchitektur.
  2. TPUs in GKE

JAX

JAX ist ein leistungsstarkes Framework für maschinelles Lernen, das für die Verwendung mit TPUs und GPUs entwickelt wurde. JAX bietet eine API zum Erstellen und Trainieren von Machine-Learning-Modellen.

Weitere Informationen finden Sie im JAX-Repository.

Ziele

Diese Anleitung umfasst die folgenden Schritte:

  1. Erstellen Sie einen GKE Autopilot- oder Standardcluster mit der empfohlenen TPU-Topologie anhand der Modelleigenschaften. In dieser Anleitung führen Sie das Fine-Tuning für Knotenpools mit einem einzelnen Host durch.
  2. Fügen Sie einem Cloud Storage-Bucket Daten hinzu und stellen Sie eine Verbindung zum Container über Cloud Storage FUSE her.
  3. Stellen Sie den Job zur Feinabstimmung des LLM in GKE bereit.
  4. Job zum Feinabstimmen überwachen und Logs ansehen

Hinweis

  • Melden Sie sich in Ihrem Google Cloud -Konto an. Wenn Sie mit Google Cloudnoch nicht vertraut sind, erstellen Sie ein Konto, um die Leistungsfähigkeit unserer Produkte in der Praxis sehen und bewerten zu können. Neukunden erhalten außerdem ein Guthaben von 300 $, um Arbeitslasten auszuführen, zu testen und bereitzustellen.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • Prüfen Sie, ob Sie die folgenden Rollen für das Projekt haben: roles/container.admin,roles/iam.serviceAccountAdmin,roles/storage.admin

    Rollen prüfen

    1. Rufen Sie in der Google Cloud Console die Seite IAM auf.

      IAM aufrufen
    2. Wählen Sie das Projekt aus.
    3. Suchen Sie in der Spalte Hauptkonto nach allen Zeilen, in denen Sie oder eine Gruppe, zu der Sie gehören, angegeben sind. Fragen Sie Ihren Administrator, zu welchen Gruppen Sie gehören.

    4. Prüfen Sie in allen Zeilen, in denen Sie angegeben oder enthalten sind, die Spalte Rolle, um zu sehen, ob die Liste der Rollen die erforderlichen Rollen enthält.

    Rollen zuweisen

    1. Rufen Sie in der Google Cloud Console die Seite IAM auf.

      IAM aufrufen
    2. Wählen Sie das Projekt aus.
    3. Klicken Sie auf Zugriffsrechte erteilen.
    4. Geben Sie im Feld Neue Hauptkonten Ihre Nutzer-ID ein. Das ist in der Regel die E‑Mail-Adresse eines Google-Kontos.

    5. Klicken Sie auf Rolle auswählen und suchen Sie nach der Rolle.
    6. Klicken Sie auf Weitere Rolle hinzufügen, wenn Sie weitere Rollen zuweisen möchten.
    7. Klicken Sie auf Speichern.
  • Prüfen Sie, ob Sie ein ausreichendes Kontingent für 16 TPU Trillium (v6e)-Chips haben. In dieser Anleitung verwenden Sie eine Knotenpoolkonfiguration, für die 16 Chips und On-Demand-Instanzen erforderlich sind.
  • Prüfen Sie, ob Sie ein Docker-Repository haben. Wenn Sie keines haben, erstellen Sie ein Standard-Repository in Artifact Registry.

Umgebung vorbereiten

In dieser Anleitung verwenden Sie Cloud Shell zum Verwalten von Ressourcen, die auf Google Cloudgehostet werden. Die Software, die Sie für diese Anleitung benötigen, ist in Cloud Shell vorinstalliert, einschließlich kubectl und Google Cloud CLI.

So richten Sie Ihre Umgebung mit Cloud Shell ein:

  1. Starten Sie in der Google Cloud Console eine Cloud Shell-Sitzung und klicken Sie auf Cloud Shell aktivieren (Symbol für die Cloud Shell-Aktivierung). Dadurch wird im unteren Bereich der Google Cloud Console eine Sitzung gestartet.

  2. Legen Sie die Standardumgebungsvariablen fest:

    gcloud config set project PROJECT_ID
    gcloud config set billing/quota_project PROJECT_ID
    export PROJECT_ID=$(gcloud config get project)
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=CONTROL_PLANE_LOCATION
    export ZONE=ZONE
    export GCS_BUCKET_NAME=BUCKET_NAME
    

    Ersetzen Sie die folgenden Werte:

    • PROJECT_ID: Ihre Google Cloud Projekt-ID.
    • CLUSTER_NAME: der Name Ihres GKE-Cluster.
    • CONTROL_PLANE_LOCATION: die Compute Engine-Region, in der sich Ihr GKE-Cluster und Ihre TPU-Knoten befinden. Die Region muss Zonen enthalten, in denen TPU Trillium-Maschinentypen (v6e) verfügbar sind.
    • ZONE: eine Zone in der ausgewählten Region CONTROL_PLANE_LOCATION, in der TPU Trillium (v6e)-Maschinentypen verfügbar sind. Führen Sie den folgenden Befehl aus, um Zonen aufzulisten, in denen TPU Trillium (v6e) verfügbar ist:

        gcloud compute accelerator-types list --filter="name~ct6e" --format="value(zone)"
      
    • BUCKET_NAME: Der Name des Cloud Storage-Bucket, der Ihre Trainingsdaten enthält.

  3. Klonen Sie das Beispiel-Repository:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    cd kubernetes-engine-samples
    
  4. Wechseln Sie zum Arbeitsverzeichnis:

    cd ai-ml/llm-training-jax-tpu-gemma3
    

Google Cloud -Ressourcen erstellen und konfigurieren

In diesem Abschnitt erstellen und konfigurieren Sie Google Cloud -Ressourcen.

GKE-Cluster erstellen

Sie können ein LLM auf TPUs in einem GKE-Cluster im Autopilot- oder Standardmodus abstimmen. Für eine vollständig verwaltete Kubernetes-Umgebung empfehlen wir die Verwendung eines Autopilot-Clusters. Informationen zum Auswählen des GKE-Betriebsmodus, der für Ihre Arbeitslasten am besten geeignet ist, finden Sie unter GKE-Betriebsmodus auswählen.

Autopilot

Einen GKE Autopilot-Cluster erstellen, der die Workload Identity-Föderation für GKE verwendet und für den Cloud Storage FUSE aktiviert ist.

gcloud container clusters create-auto ${CLUSTER_NAME} \
    --location=${REGION}

Die Erstellung eines Clusters kann einige Minuten dauern.

Standard

  1. Erstellen Sie einen regionalen GKE Standard-Cluster, der die Workload Identity-Föderation für GKE verwendet und für den Cloud Storage FUSE aktiviert ist.

    gcloud container clusters create ${CLUSTER_NAME} \
        --enable-ip-alias \
        --addons GcsFuseCsiDriver \
        --machine-type=n2-standard-4 \
        --num-nodes=2 \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --location=${REGION}
    

    Die Erstellung eines Clusters kann einige Minuten dauern.

  2. So erstellen Sie einen Knotenpool mit einem Host:

    gcloud container node-pools create jax-tpu-nodepool \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-1t \
        --num-nodes=1 \
        --location=${REGION} \
        --node-locations=${ZONE} \
        --workload-metadata=GKE_METADATA
    

GKE erstellt einen TPU Trillium-Knotenpool mit einer 1x1-Topologie und einem Knoten. Das Flag --workload-metadata=GKE_METADATA konfiguriert den Knotenpool für die Verwendung des GKE-Metadatenservers.

JobSet installieren

  1. Konfigurieren Sie kubectl für die Kommunikation mit Ihrem Cluster:

    gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION}
    
  2. Installieren Sie die neueste veröffentlichte Version von JobSet:

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    Ersetzen Sie JOBSET_VERSION durch die neueste veröffentlichte Version von JobSet. Beispiel: v0.11.0

  3. JobSet-Installation prüfen:

    kubectl get pods -n jobset-system
    

    Die Ausgabe sieht etwa so aus:

    NAME                                         READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-6c56668494-l4dhc   1/1     Running   0          4m45s
    

    Möglicherweise müssen Sie weitere Knoten hinzufügen, wenn JobSet auf Ressourcen wartet.

Cloud Storage FUSE konfigurieren

Zum Feinabstimmen des LLM müssen Sie Trainingsdaten bereitstellen. In dieser Anleitung verwenden Sie das TinyStories-Dataset von Hugging Face. Dieses Dataset enthält von GPT-3.5 und GPT-4 synthetisch generierte Kurzgeschichten mit einem begrenzten Wortschatz.

In diesem Abschnitt wird beschrieben, wie Sie Cloud Storage FUSE so konfigurieren, dass Daten aus einem Cloud Storage-Bucket gelesen werden.

  1. Dataset herunterladen:

    wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
    
  2. Laden Sie die Daten in einen neuen Cloud Storage-Bucket hoch:

    gcloud storage buckets create gs://${GCS_BUCKET_NAME} \
        --location=${REGION} \
        --enable-hierarchical-namespace \
        --uniform-bucket-level-access
    gcloud storage cp TinyStories-train.txt gs://${GCS_BUCKET_NAME}
    
  3. Damit Ihre Arbeitslast Daten über Cloud Storage FUSE lesen kann, erstellen Sie ein Kubernetes-Dienstkonto (KSA) und fügen Sie die erforderlichen Berechtigungen hinzu. Führen Sie das Skript permissionsetup.sh aus:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    #!/bin/bash
    
    # --- Configuration Variables ---
    # Kubernetes Service Account details
    export KSA_NAME="jaxserviceaccout"
    export NAMESPACE="default"
    
    # Google Cloud IAM Service Account details
    export GSA_NAME="<GSA_NAME>"
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    export  GSA_DESCRIPTION="GKE Service Account to read GCS bucket for ${KSA_NAME}"
    
    # GCS Bucket details
    export GCS_BUCKET_NAME="<GCS_BUCKET_NAME>" # <--- IMPORTANT: Update this to your bucket name
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    export WI_MEMBER="serviceAccount:${PROJECT_ID}.svc.id.goog[${NAMESPACE}/${KSA_NAME}]"
    
    # --- Check if PROJECT_ID is set ---
    if [ -z "${PROJECT_ID}" ]; then
      echo "Error: PROJECT_ID is not set. Please set it using 'gcloud config set project YOUR_PROJECT_ID'"
      exit 1
    fi
    
    echo "--- Configuration ---"
    echo "KSA_NAME:      ${KSA_NAME}"
    echo "NAMESPACE:     ${NAMESPACE}"
    echo "GSA_NAME:      ${GSA_NAME}"
    echo "PROJECT_ID:    ${PROJECT_ID}"
    echo "GSA_EMAIL:     ${GSA_EMAIL}"
    echo "GCS_BUCKET_NAME:   ${GCS_BUCKET_NAME}"
    echo "WI_MEMBER:     ${WI_MEMBER}"
    echo "--------------------"
    read -p "Press enter to continue..."
    
    # --- Command Execution ---
    
    echo "[1/5] Creating Google Cloud IAM Service Account (GSA): ${GSA_NAME}"
    gcloud iam service-accounts create "${GSA_NAME}" \
        --project="${PROJECT_ID}" \
        --description="${GSA_DESCRIPTION}" \
        --display-name="${GSA_NAME}"
    
    echo "[2/5] Granting GSA '${GSA_EMAIL}' read access (roles/storage.objectViewer) to bucket 'gs://${GCS_BUCKET_NAME}'"
    gcloud storage buckets add-iam-policy-binding "gs://${GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectViewer" \
        --project="${PROJECT_ID}"
    
    echo "[3/5] Creating Kubernetes Service Account (KSA): ${KSA_NAME} in namespace ${NAMESPACE}"
    kubectl create serviceaccount "${KSA_NAME}" --namespace "${NAMESPACE}"
    
    echo "[4/5] Allowing KSA to impersonate GSA (Workload Identity Binding): ${GSA_EMAIL}"
    gcloud iam service-accounts add-iam-policy-binding "${GSA_EMAIL}" \
        --role roles/iam.workloadIdentityUser \
        --member "${WI_MEMBER}" \
        --project="${PROJECT_ID}"
    
    echo "[5/5] Annotating KSA '${KSA_NAME}' to link with GSA '${GSA_EMAIL}'"
    kubectl annotate serviceaccount "${KSA_NAME}" \
        --namespace "${NAMESPACE}" \
        iam.gke.io/gcp-service-account="${GSA_EMAIL}"
    
    echo "--- Setup Complete ---"
    echo "Pods in namespace '${NAMESPACE}' using serviceAccount '${KSA_NAME}' can now authenticate as '${GSA_EMAIL}' and have read access to 'gs://${GCS_BUCKET_NAME}'."
    

    Nachdem Sie dieses Skript ausgeführt haben, werden die folgenden Ressourcen in IhremGoogle Cloud -Projekt und GKE-Cluster konfiguriert:

    • Ein neues IAM-Dienstkonto mit dem Namen gcs-fuse-sa wird in Ihrem Projekt erstellt.
    • Dem erstellten Google Cloud Dienstkonto (GSA) (gcs-fuse-sa) wird die Rolle roles/storage.objectViewer für den Cloud Storage-Bucket zugewiesen, der durch ${GCS_BUCKET_NAME} angegeben wird. Mit dieser Berechtigung kann die GSA Objekte aus dem Bucket lesen.
    • Ein neues KSA mit dem Namen jaxserviceaccount wird im Namespace default in Ihrem GKE-Cluster erstellt.
    • Die IAM-Richtlinie des GSA wird aktualisiert, um dem KSA die Rolle roles/iam.workloadIdentityUser zuzuweisen. Mit dieser Berechtigung kann das KSA die Identität des GSA übernehmen.
    • Das KSA wird annotiert, um es mit dem GSA zu verknüpfen. Diese Annotation teilt GKE mit, welches GSA das KSA mithilfe von Workload Identity annehmen soll.

      Alle Pods, die im Namespace default Ihres GKE-Clusters ausgeführt werden und das Dienstkonto jaxserviceaccount verwenden, können sich jetzt als gcs-fuse-sa-GSA authentifizieren. Diese Pods haben Lesezugriff auf die im Bucket gs://${GCS_BUCKET_NAME} gespeicherten Objekte. Das ist wichtig, damit der Job zum Feinabstimmen über Cloud Storage FUSE auf das Dataset zugreifen kann.

Feinabstimmungsskript erstellen

In diesem Abschnitt sehen Sie sich das Trainingsskript an, mit dem ein Gemma 3-Modell feinabgestimmt wird. In diesem Script wird Gemma3Tokenizer verwendet.

Sehen Sie sich das folgende Gemma3LLMTrain.py-Feinabstimmungsskript an:

# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import grain.python as pygrain
import jax
import jax.numpy as jnp
import optax
import pandas as pd
import time
import argparse

from dataclasses import dataclass
from functools import partial
from gemma import gm
from flax.training import train_state
from jax.sharding import Mesh, PartitionSpec, NamedSharding

jax.distributed.initialize()
print("Global device count:", jax.device_count())
print("jax version:", jax.__version__)

tokenizer = gm.text.Gemma3Tokenizer()
num_epochs = 1
learning_rate = 2e-5

@dataclass
class TextDataset:
    data: list
    maxlen: int
    tokenizer: gm.text.Gemma3Tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        encoding = self.tokenizer.encode(self.data[idx])[:self.maxlen]  # Tokenize and truncate
        return encoding + [0] * (self.maxlen - len(encoding))  # Pad to maxlen

def load_and_preprocess_data(file_path, batch_size, maxlen, datacount, tokenizer):

    with open(file_path, 'r') as f:
      text = f.read()

    stories = text.split('<|endoftext|>')
    stories = [story for story in stories if story.strip()][:datacount]
    df = pd.DataFrame({'text': stories})
    data = df['text'].dropna().tolist()
    dataset = TextDataset(data, maxlen, tokenizer)

    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=False,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )

    dataloader = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )

    return dataloader

def generate_text(model, params, tokenizer, prompt):
    sampler = gm.text.Sampler(
        model=model,
        params=params,
        tokenizer=tokenizer,
    )
    print("Generating response for: " + prompt)
    out = sampler.sample(prompt, max_new_tokens=32)
    print("Reponse: \n" + out + "\n")
    return out

prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))

@partial(jax.jit, donate_argnums=(0,))
def train_step(state, batch):
    """Performs one supervised fine-tuning step."""

    def loss_fn(params):
        # Run the forward pass. The model returns logits.
        logits = state.apply_fn({'params': params}, batch[0]).logits

        # Calculate the cross-entropy loss.
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch[1]
        ).mean()

        return loss

    # Compute gradients
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)

    # Update the model state
    state = state.apply_gradients(grads=grads)

    metrics = {'loss': loss}
    return state, metrics

def train_model(state, text_dl, num_epochs, sharding):
    batchCount = 0
    start_time = time.time()
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch in text_dl:
            if len(batch) % len(jax.devices()) != 0:
              continue  # skip the remaining elements
            input_batch = jnp.array(jnp.array(batch).T)
            target_batch = prep_target_batch(input_batch)
            state, metrics = train_step(state, jax.device_put((input_batch, target_batch), sharding))

            if batchCount % 10 == 0:
                print(f"Loss after batch {batchCount}: {metrics['loss']}")
            batchCount += 1

    end_time = time.time()
    print(f"Completed training model. Total time for training {end_time - start_time} seconds \n")
    return state

def run_training(maxlen, batch_size, datacount):
    print(f"Batch size: {batch_size}, Max length: {maxlen}, Data count: {datacount}")
    #Load the training data
    tiny_stories_dl = load_and_preprocess_data('/data/TinyStories-train.txt', batch_size, maxlen, datacount, tokenizer)
    # Get the Gemma3 model
    model = gm.nn.Gemma3_270M()
    # Load the pretrained parameters
    params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_270M_PT)
    # Create an optimizer
    optimizer = optax.adamw(learning_rate=learning_rate)
    # Define sharding for data parallel training
    mesh = Mesh(jax.devices(), ('batch',))
    sharding = NamedSharding(mesh, PartitionSpec('batch', None))

    # Testing out current state of the model
    test_prompt = "Once upon a time, there was a girl named Amy."
    generate_text(model, params, tokenizer, test_prompt)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer
    )

    # Perform post training
    print("Start training model")
    state = train_model(state, tiny_stories_dl, num_epochs, sharding)

    # Final text generation
    generate_text(model, state.params, tokenizer, test_prompt)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Gemma model with custom parameters.')
    parser.add_argument('--maxlen', type=int, default=256, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--datacount', type=int, default=296000, help='Number of data samples to use')
    args = parser.parse_args()

    run_training(maxlen=args.maxlen, batch_size=args.batch_size, datacount=args.datacount)

In diesem Skript gilt Folgendes:

  • Ein Gemma3Tokenizer wandelt Textdaten in Tokens um, die das Modell verarbeiten kann.
  • Die Funktion load_and_preprocess_data liest die Trainingsdaten aus einer Datei, teilt sie in einzelne Geschichten auf und verwendet den Tokenizer, um den Text in gepaddete Sequenzen von Tokens zu konvertieren.
  • Die Funktion generate_text verwendet das Modell, seine Parameter und einen Prompt, um Text zu generieren.
  • Die Funktion train_step definiert eine einzelne Trainingsiteration, die den Forward-Pass, die Verlustberechnung (mit Kreuzentropie), die Gradientenberechnung und die Parameteraktualisierungen umfasst.
  • Die Funktion train_model durchläuft das Dataset für eine bestimmte Anzahl von Epochen und ruft die Funktion train_step für jeden Batch auf.
  • Die Funktion run_training orchestriert den gesamten Prozess zum Laden von Daten, Initialisieren des Gemma 3-Modells (Gemma3_270M) und des Optimierers, Laden vortrainierter Parameter, Einrichten von Data Sharding für die parallele Verarbeitung, Ausführen einer Testgenerierung, Ausführen der Trainingsschleife und Ausführen einer finalen Textgenerierung, um die Auswirkungen des Feinabstimmens zu demonstrieren.
  • Das Skript verwendet die argparse-Bibliothek, um Befehlszeilenargumente für die Parameter maxlen, batch_size und datacount zu akzeptieren.

Nachdem Sie das Feinabstimmungsskript untersucht haben, können Sie es in einen Container packen, um es in GKE auszuführen.

Feinabstimmungsskript containerisieren

Bevor Sie das Feinabstimmungsskript in einem GKE-Cluster ausführen, müssen Sie es in einen Container packen. In dieser Anleitung wird ein JAX AI-Image als Basis-Image verwendet.

  1. Öffnen Sie die Datei Dockerfile im selben Verzeichnis wie die Datei Gemma3LLMTrain.py:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    Mit diesem Dockerfile werden die erforderlichen Abhängigkeiten installiert und die Datei Gemma3LLMTrain.py in den Container kopiert.

  2. Erstellen Sie das Docker-Image und übertragen Sie es per Push in ein Image-Repository:

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    Ersetzen Sie REPOSITORY_NAME durch den Namen Ihres Artifact Registry-Repositorys.

  3. Fügen Sie dem Dienstkonto Rollenbindungen hinzu:

    export PROJECT_NUMBER=$(gcloud projects describe $PROJECT_ID --format 'get(projectNumber)')
    gcloud artifacts repositories add-iam-policy-binding ${REPOSITORY} \
        --project=${PROJECT_ID} \
        --location=${REGION} \
        --member="serviceAccount:${PROJECT_NUMBER}-compute@developer.gserviceaccount.com" \
        --role="roles/artifactregistry.reader"
    

Nachdem sich das Image im Repository befindet, können Sie den Job zum Feinabstimmen in einem GKE-Cluster bereitstellen.

LLM-Job für die Feinabstimmung bereitstellen

In diesem Abschnitt erfahren Sie, wie Sie den LLM-Abstimmungsjob in Ihrem GKE-Cluster bereitstellen.

  1. Öffnen Sie das Manifest training_singlehost.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "355120"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. Wenden Sie das Manifest an:

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

GKE erstellt einen Job, der einen Pod auf einem TPU Trillium-Knoten (v6e) startet. In diesem Pod wird das Python-Script zum Feinabstimmen ausgeführt, das über Cloud Storage FUSE auf die Feinabstimmungsdaten aus dem angegebenen Cloud Storage-Bucket zugreift, der unter dem Pfad /data eingebunden ist. Anschließend wird das Gemma-Modell optimiert.

Trainingsjob überwachen

In diesem Abschnitt überwachen Sie den Fortschritt des Fine-Tuning-Jobs und seine Leistung.

Fortschritt der Feinabstimmung ansehen

  1. Listen Sie die Pods auf:

    # Find the Pods
    kubectl get pods
    
  2. Folgen Sie der Logausgabe:

    kubectl logs -f pods/POD_NAME
    

    Ersetzen Sie POD_NAME durch den Namen Ihres Clusters.

    Die Ausgabe sieht etwa so aus:

    Global device count: 1
    Batch size: 128, Max length: 256, Data count: 96000
    I1028 00:12:55.925999 1387 google_auth_provider.cc:181] Running on GCE, using service account ...
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    Amy lived in a small house. The house was in a big field. Amy liked to play in the big field. She
    Start training model
    Loss after batch 0: 10.25
    Loss after batch 10: 4.3125
    .
    .
    .
    Loss after batch 740: 1.41406
    Completed training model. Total time for training 294.6791355609894 seconds
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    She loved to play with her toys. One day, Amy's mom told her that she had to go to the store to
    
  3. Ausgabe analysieren:

    • Die Global device count: 1-Linie gibt die verwendeten TPU-Kerne an.
    • Das Modell generiert vor diesem Feinabstimmungslauf sinnvollen Text, da es aus einem vortrainierten Prüfpunkt geladen wird.
    • Die nach dem Fine-Tuning generierte Ausgabe ähnelt eher dem Beginn einer Kurzgeschichte, was darauf hindeutet, dass das Modell aus dem neuen Dataset lernt.
    • Durch das Feinabstimmen des vollständigen Datasets sollten noch bessere Ergebnisse erzielt werden.

Messwerte beobachten

Sehen Sie sich die Leistung des Feinabstimmungsjobs an, indem Sie die TPU- und CPU-Messwerte prüfen. Wenn Sie Beobachtbarkeitsmesswerte für Ihren Cluster aufrufen möchten, führen Sie die Schritte unter Messwerte für Cluster- und Arbeitslast-Beobachtbarkeit aufrufen aus.

Alternative Konfigurationen für die Feinabstimmung

In diesem Abschnitt werden alternative Konfigurationen für Ihre Feinabstimmungsarbeitslast beschrieben.

Modellauswahl

In diesem Tutorial wurde das Gemma3_270M-Modell verwendet, ein kleines Modell, das in einen TPU-Trillium-Knotenpool (v6e) mit einem einzelnen Host passt. Für größere Modelle, die mehr Arbeitsspeicher und Rechenressourcen für die Feinabstimmung erfordern, können Sie Knotenpoolkonfigurationen mit mehreren Hosts oder mehreren Slices verwenden.

Eine vollständige Liste der verfügbaren Modelle finden Sie in der Gemma-Dokumentation.

Knotenpoolkonfigurationen

In dieser Anleitung wurde ein Knotenpool mit einem einzelnen Host verwendet. Je nach Bedarf können Sie auch TPU-Slice-Knotenpools mit mehreren Hosts oder Multislice-Knotenpools erstellen.

Auf den folgenden Tabs wird gezeigt, wie Sie Knotenpools mit mehreren Hosts und mehreren Slices erstellen:

Mehrere Hosts

  1. Führen Sie in Cloud Shell den folgenden Befehl aus:

    gcloud container node-pools create jax-tpu-multihost1 \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-4t \
        --num-nodes=2 \
        --tpu-topology=2x4 \
        --location=${REGION} \
        --node-locations=${ZONE}
    

    GKE erstellt einen TPU Trillium-Knotenpool mit einer 2x4-Topologie und zwei Knoten.

  2. Öffnen Sie die Jobdefinition training_multihost_jobset.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multihost
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                    cloud.google.com/gke-nodepool: jax-tpu-multihost1
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI} 
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
    
  3. Stellen Sie den Job für die Feinabstimmung bereit:

    envsubst < training_multihost_jobset.yaml | kubectl apply -f -
    

Multislice

  1. Führen Sie in Cloud Shell den folgenden Befehl aus:

    gcloud container node-pools create jax-tpu-multihost1 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    
    gcloud container node-pools create jax-tpu-multihost2 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    

    GKE erstellt zwei TPU Trillium-Knotenpools. Jeder Knotenpool hat eine 2x4-Topologie und zwei Knoten.

  2. Öffnen Sie die Jobdefinition training_multislice_jobset.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multislice
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI}
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:100"
    
  3. Stellen Sie den Job für die Feinabstimmung bereit:

    envsubst < training_multislice_jobset.yaml | kubectl apply -f -
    

Leistungsanalyse und ‑optimierung

Mit XProf können Sie die Leistung Ihres Machine Learning-Feinabstimmungsprozesses analysieren und optimieren. XProf ist eine Reihe von Tools, mit denen ML-Arbeitslasten, die mit JAX, TensorFlow oder PyTorch/XLA erstellt wurden, profiliert und untersucht werden können. Mit XProf können Sie Ihre Modelle und die Einrichtung des Trainings optimieren, um die Effizienz zu steigern und das Training zu beschleunigen. Dazu werden Ausführungstraces, die Speichernutzung und andere Daten angezeigt.

In diesem Abschnitt führen Sie die folgenden Schritte aus, um die Leistung Ihrer Arbeitslast für das Feinabstimmen mit XProf zu analysieren:

  • Installieren Sie das Paket xprof: Ändern Sie Ihr Trainingsskript, um den XProf-Server zu starten.
  • Ändern Sie Ihr Kubernetes-Jobmanifest, um eine Volume-Bereitstellung für XProf-Logs einzufügen.
  • Gewähren Sie dem Dienstkonto Berechtigungen zum Schreiben von XProf-Logs in einen Cloud Storage-Bucket.
  • Führen Sie XProf in Ihrem Pod aus und richten Sie die Portweiterleitung ein, um auf das XProf-Dashboard zuzugreifen.

XProf-Paket installieren

  1. Wechseln Sie in das Verzeichnis mit den XProf-Beispielen:

      cd ai-ml/llm-training-jax-tpu-gemma3/xprof-enabled
    
  2. Erstellen Sie das Docker-Image und übertragen Sie es per Push in ein Image-Repository:

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training-xp"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    Ersetzen Sie REPOSITORY_NAME durch den Namen Ihres Artifact Registry-Repositorys.

  3. Führen Sie das Skript Dockerfile aus:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain equinox
    RUN pip install xprof
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    Mit diesem Dockerfile werden XProf-Abhängigkeiten installiert.

Kopieren Sie Ihr Fine-Tuning-Skript in den Container.

In diesem Abschnitt erstellen und wenden Sie ein Kubernetes-Jobmanifest an, das die erforderlichen Volume-Bereitstellungen für XProf-Logs enthält.

  1. Öffnen Sie die Jobdefinition training_singlehost.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "851200"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
            - name: gcs-fuse-csi-ephemeral2
              mountPath: /xprof
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
          - name: gcs-fuse-csi-ephemeral2
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${XPROF_GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. Wenden Sie das Manifest an:

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

Dem Dienstkonto Berechtigungen zum Schreiben von XProf-Logs gewähren

  1. Fügen Sie die Rolle "roles/storage.objectUser" hinzu, damit das Dienstkonto schreiben und lesen kann:

    export GSA_NAME="GSA_NAME" # Same as used in initial setup
    
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    
    # Cloud Storage Bucket details
    export XPROF_GCS_BUCKET_NAME="XPROF_GCS_BUCKET_NAME"
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    
    gcloud storage buckets add-iam-policy-binding "gs://${XPROF_GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectUser" \
        --project="${PROJECT_ID}"
    

    Ersetzen Sie Folgendes:

    • GSA_NAME: Der Name des Google-Dienstkontos, dem die Rolle zugewiesen werden soll.
    • XPROF_GCS_BUCKET_NAME: Der Name des Buckets, dem die Rolle zugewiesen werden soll.
  2. Führen Sie XProf in Ihrem Pod aus:

    kubectl exec POD_NAME -c training-container -it -- bash # exec into the container
    xprof --port 9001 --logdir /xprof # start xprof
    

    Ersetzen Sie POD_NAME durch den Namen Ihres Clusters.

Auf das XProf-Dashboard zugreifen

  1. Richten Sie die Portweiterleitung zum XProf-Server im Pod ein:

    kubectl port-forward POD_NAME 9001:9001
    
  2. Geben Sie Folgendes in die Adressleiste Ihres Browsers ein:

    http://localhost:9001/
    

    Der XProf Trace Viewer wird geöffnet.

  3. Klicken Sie im TensorBoard-Fenster auf Profil erfassen.

  4. Geben Sie im Feld Profildienst-URL(s) oder TPU-Name localhost:9002 ein.

  5. Wenn Sie weitere Details erfassen möchten, wählen Sie unter Host Trace (TraceMe) Level die Option verbose aus und aktivieren Sie das Python-Trace-Logging.

  6. Klicken Sie auf Aufnehmen, um das Dashboard aufzurufen.

    TensorBoard erfasst das Profil und ermöglicht es Ihnen, die Leistung des Trainingsskripts zu analysieren. Das Diagramm zeigt die Ausführungszeitachse für TPU- und CPU-Leistungsprofile:

Ein Beispiel für den XProf-Trace-Viewer, der ein Leistungsmatrixdiagramm zeigt

Weitere Profilerstellungsoptionen zur Analyse der Leistung Ihres Trainings-Workloads finden Sie in der JAX-Dokumentation unter Profiling computation.

Abstimmung in Produktionsumgebungen

In dieser Anleitung haben Sie gelernt, wie Sie das JAX-basierte Training in einer verteilten Umgebung testen. Für die optimierte LLM-Feinabstimmung in der Produktion verwenden Sie die Maxtext-Bibliothek. Wenn Sie sich für Diffusionsmodelle interessieren, verwenden Sie Maxdiffusion-Implementierungen.

Für Produktionsarbeitslasten mit langer Trainings- oder Feinabstimmungsdauer sollten Sie die Prüfpunktausführung von Arbeitslasten einrichten, um den Fortschrittsverlust bei einem Fehler zu minimieren. Weitere Informationen zum Einrichten von mehrstufigen Prüfpunkten finden Sie unter Große Modelle für maschinelles Lernen in GKE mit mehrstufigen Prüfpunkten trainieren.

Bereinigen

Damit Ihrem Google Cloud-Konto die in dieser Anleitung verwendeten Ressourcen nicht in Rechnung gestellt werden, löschen Sie entweder das Projekt, das die Ressourcen enthält, oder Sie behalten das Projekt und löschen die einzelnen Ressourcen.

Einzelne Ressourcen löschen

Damit Ihrem Google Cloud -Konto die in dieser Anleitung verwendeten Ressourcen nicht in Rechnung gestellt werden, können Sie entweder das Projekt löschen, das die Ressourcen enthält, oder das Projekt beibehalten und die einzelnen Ressourcen mit den folgenden Befehlen löschen:

  1. Löschen Sie die in dieser Anleitung erstellten Ressourcen:

    gcloud container clusters delete ${CLUSTER_NAME} --location=${REGION}
    gcloud storage rm --recursive gs://${GCS_BUCKET_NAME}
    gcloud artifacts docker images delete ${IMAGE_URI} --delete-tags
    
  2. Wenn Sie die von XProf generierten Daten nicht benötigen, entfernen Sie den von XProf verwendeten Cloud Storage-Bucket:

    gcloud storage rm --recursive gs://${XPROF_GCS_BUCKET_NAME}
    

Nächste Schritte